From: Ron Dempster (rdempste) Date: Tue, 22 Mar 2022 19:06:38 +0000 (+0000) Subject: Pull request #3279: Multi-tenant with reconcile inspectors and reputation with reload... X-Git-Tag: 3.1.26.0~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=79faa2fea69149d5ddcfc93128bb93aff7a0ede1;p=thirdparty%2Fsnort3.git Pull request #3279: Multi-tenant with reconcile inspectors and reputation with reload command Merge in SNORT/snort3 from ~RDEMPSTE/snort3:reputation to master Squashed commit of the following: commit fb9b349ce3fc2612c4f0bdae6f1e03a511bf9cf7 Author: Ron Dempster (rdempste) Date: Tue Mar 22 11:06:13 2022 -0400 framework: update base API version to 13 commit 877c1e7dcc63499301a8868880831b27ff9bcabe Author: Ron Dempster (rdempste) Date: Fri Mar 11 07:32:55 2022 -0500 appid: sum stats at tterm and null the thread local stats pointer after delete commit d23843bb934a4072c1c15458f9ddf17a95d1d269 Author: Ron Dempster (rdempste) Date: Tue Mar 8 10:16:45 2022 -0500 main: add the control connection to the analyzer command and a method to log a message to both console and the remote connection commit aaf890c670f013e8af21c8db345139314084d13e Author: Ron Dempster (rdempste) Date: Sat Mar 5 13:18:39 2022 -0500 main: fix and reenable the distill_verdict unit test commit edc81969f10a390a4a1e6e355906566405778583 Author: Ron Dempster (rdempste) Date: Tue Mar 8 09:37:46 2022 -0500 managers: add get_inspector unit tests commit 393507e0e4182033f7f726e710516ffc68e95d1d Author: Ron Dempster (rdempste) Date: Fri Feb 25 12:22:24 2022 -0500 policy_selectors: add a method to select policies based on DAQ_FlowStats_t commit c85bb3a7b2225efda3e0ade20267746a989f7e01 Author: Ron Dempster (rdempste) Date: Mon Feb 14 12:39:59 2022 -0500 appid: make appid a global inspector commit 046846e765831debe98886fdf1ce57382db96c75 Author: Ron Dempster (rdempste) Date: Fri Feb 11 10:12:40 2022 -0500 managers: add a faster get_inspectors method commit 3470d1cb7dfdee60af067f15bba29694e4646ed3 Author: Ron Dempster (rdempste) Date: Fri Jan 14 10:22:17 2022 -0500 inspector, main, inspector_manager: add support for thread local data in inspectors and commands updating reload_id commit 3d9c2556dbb39220ca26d61e4f2e6e2477b55a22 Author: Ron Dempster (rdempste) Date: Tue Dec 7 15:43:49 2021 -0500 reputation: add a command to reload repuation data commit c74d98a34b089d0b86db78cac78c6aaa793c2853 Author: Ron Dempster (rdempste) Date: Tue Dec 21 08:22:14 2021 -0500 flow: make service a shared pointer to handle reload properly commit 6750746d83d0c82ff3ebe552be43f8d36797c29b Author: Ron Dempster (rdempste) Date: Thu Dec 16 07:59:30 2021 -0500 managers: move inspection policies into the corresponding network policy --- diff --git a/src/control/control.h b/src/control/control.h index 6bc60750d..a211b46a5 100644 --- a/src/control/control.h +++ b/src/control/control.h @@ -61,11 +61,11 @@ public: void shutdown(); SO_PUBLIC bool is_local() const { return local; } + SO_PUBLIC bool respond(const char* format, va_list& ap); SO_PUBLIC bool respond(const char* format, ...) __attribute__((format (printf, 2, 3))); SO_PUBLIC static ControlConn* query_from_lua(const lua_State*); private: - bool respond(const char* format, va_list& ap); bool show_prompt(); void touch(); diff --git a/src/flow/flow.cc b/src/flow/flow.cc index dd30e5052..916a5b0d9 100644 --- a/src/flow/flow.cc +++ b/src/flow/flow.cc @@ -24,6 +24,7 @@ #include "flow.h" #include "detection/detection_engine.h" +#include "flow/flow_key.h" #include "flow/ha.h" #include "flow/session.h" #include "framework/data_bus.h" @@ -118,6 +119,8 @@ void Flow::term() delete stash; stash = nullptr; } + + service.reset(); } inline void Flow::clean() @@ -211,6 +214,7 @@ void Flow::reset(bool do_cleanup) stash->reset(); deferred_trust.clear(); + service.reset(); constexpr size_t offset = offsetof(Flow, context_chain); // FIXIT-L need a struct to zero here to make future proof @@ -326,12 +330,33 @@ void Flow::free_flow_data(uint32_t proto) void Flow::free_flow_data() { + NetworkPolicy* np = get_network_policy(); + InspectionPolicy* ip = get_inspection_policy(); + IpsPolicy* ipsp = get_ips_policy(); + + unsigned t_reload_id = SnortConfig::get_thread_reload_id(); + if (reload_id == t_reload_id) + { + ::set_network_policy(network_policy_id); + ::set_inspection_policy(inspection_policy_id); + ::set_ips_policy(SnortConfig::get_conf(), ips_policy_id); + } + else + { + _daq_pkt_hdr pkthdr = {}; + pkthdr.address_space_id = key->addressSpaceId; + select_default_policy(pkthdr, SnortConfig::get_conf()); + } while (flow_data) { FlowData* tmp = flow_data; flow_data = flow_data->next; delete tmp; } + + set_network_policy(np); + set_inspection_policy(ip); + set_ips_policy(ipsp); } void Flow::call_handlers(Packet* p, bool eof) @@ -546,12 +571,21 @@ bool Flow::is_direction_aborted(bool from_client) const return (session_flags & SSNFLAG_ABORT_CLIENT); } -void Flow::set_service(Packet* pkt, const char* new_service) +void Flow::set_service(Packet* pkt, std::shared_ptr new_service) { + if (!new_service.use_count()) + return clear_service(pkt); + service = new_service; DataBus::publish(FLOW_SERVICE_CHANGE_EVENT, pkt); } +void Flow::clear_service(Packet* pkt) +{ + service.reset(); + DataBus::publish(FLOW_SERVICE_CHANGE_EVENT, pkt); +} + void Flow::swap_roles() { std::swap(flowstats.client_pkts, flowstats.server_pkts); diff --git a/src/flow/flow.h b/src/flow/flow.h index bce283bb3..9b13f8d75 100644 --- a/src/flow/flow.h +++ b/src/flow/flow.h @@ -27,9 +27,12 @@ // state. Inspector state is stored in FlowData, and Flow manages a list // of FlowData items. -#include +#include +#include #include +#include + #include "detection/ips_context_chain.h" #include "flow/deferred_trust.h" #include "flow/flow_data.h" @@ -197,7 +200,10 @@ public: void set_mpls_layer_per_dir(Packet*); Layer get_mpls_layer_per_dir(bool); void swap_roles(); - void set_service(Packet* pkt, const char* new_service); + void set_service(Packet*, std::shared_ptr new_service); + void clear_service(Packet*); + bool has_service() const + { return 0 != service.use_count(); } bool get_attr(const std::string& key, int32_t& val); bool get_attr(const std::string& key, std::string& val); void set_attr(const std::string& key, const int32_t& val); @@ -281,26 +287,20 @@ public: void set_client(Inspector* ins) { + if (ssn_client) + ssn_client->rem_ref(); ssn_client = ins; - ssn_client->add_ref(); - } - - void clear_client() - { - ssn_client->rem_ref(); - ssn_client = nullptr; + if (ssn_client) + ssn_client->add_ref(); } void set_server(Inspector* ins) { + if (ssn_server) + ssn_server->rem_ref(); ssn_server = ins; - ssn_server->add_ref(); - } - - void clear_server() - { - ssn_server->rem_ref(); - ssn_server = nullptr; + if (ssn_server) + ssn_server->add_ref(); } void set_clouseau(Inspector* ins) @@ -405,8 +405,8 @@ public: // FIXIT-M privatize if possible // fields are organized by initialization and size to minimize // void space and allow for memset of tail end of struct - // these fields are const after initialization DeferredTrust deferred_trust; + std::shared_ptr service; // Anything before this comment is not zeroed during construction const FlowKey* key; @@ -441,10 +441,10 @@ public: // FIXIT-M privatize if possible Inspector* gadget; // service handler Inspector* assistant_gadget; Inspector* data; - const char* service; uint64_t expire_time; + unsigned network_policy_id; unsigned inspection_policy_id; unsigned ips_policy_id; unsigned reload_id; diff --git a/src/flow/flow_control.cc b/src/flow/flow_control.cc index 611f5c9a9..6324ca993 100644 --- a/src/flow/flow_control.cc +++ b/src/flow/flow_control.cc @@ -435,14 +435,24 @@ unsigned FlowControl::process(Flow* flow, Packet* p) if ( flow->flow_state != Flow::FlowState::SETUP ) { - const SnortConfig* sc = SnortConfig::get_conf(); - set_inspection_policy(sc, flow->inspection_policy_id); - set_ips_policy(sc, flow->ips_policy_id); + unsigned reload_id = SnortConfig::get_thread_reload_id(); + if (flow->reload_id != reload_id) + { + flow->network_policy_id = get_network_policy()->policy_id; + if (flow->flow_state == Flow::FlowState::INSPECT) + DataBus::publish(FLOW_STATE_RELOADED_EVENT, p, flow); + } + else + { + set_inspection_policy(flow->inspection_policy_id); + set_ips_policy(p->context->conf, flow->ips_policy_id); + } p->filtering_state = flow->filtering_state; } else { + flow->network_policy_id = get_network_policy()->policy_id; if (PacketTracer::is_active()) PacketTracer::log("Session: new snort session\n"); diff --git a/src/flow/test/flow_cache_test.cc b/src/flow/test/flow_cache_test.cc index aa9af717b..2be521139 100644 --- a/src/flow/test/flow_cache_test.cc +++ b/src/flow/test/flow_cache_test.cc @@ -27,6 +27,7 @@ #include "flow/flow_control.h" #include "detection/detection_engine.h" +#include "main/policy.h" #include "main/snort_config.h" #include "main/snort_debug.h" #include "managers/inspector_manager.h" @@ -67,7 +68,12 @@ void Active::set_drop_reason(char const*) { } Packet::Packet(bool) { } Packet::~Packet() = default; uint32_t Packet::get_flow_geneve_vni() const { return 0; } -Flow::Flow() { memset(this, 0, sizeof(*this)); } +Flow::Flow() +{ + constexpr size_t offset = offsetof(Flow, key); + // FIXIT-L need a struct to zero here to make future proof + memset((uint8_t*)this+offset, 0, sizeof(*this)-offset); +} Flow::~Flow() = default; DetectionEngine::DetectionEngine() = default; ExpectCache::~ExpectCache() = default; @@ -77,14 +83,14 @@ void Flow::term() { } void Flow::flush(bool) { } void Flow::reset(bool) { } void Flow::free_flow_data() { } -void set_network_policy(const SnortConfig*, unsigned) { } void DataBus::publish(const char*, const uint8_t*, unsigned, Flow*) { } void DataBus::publish(const char*, Packet*, Flow*) { } const SnortConfig* SnortConfig::get_conf() { return nullptr; } void Flow::set_client_initiate(Packet*) { } void Flow::set_direction(Packet*) { } -void set_inspection_policy(const SnortConfig*, unsigned) { } -void set_ips_policy(const SnortConfig*, unsigned) { } +void set_network_policy(unsigned) { } +void set_inspection_policy(unsigned) { } +void set_ips_policy(const snort::SnortConfig*, unsigned) { } void Flow::set_mpls_layer_per_dir(Packet*) { } void DetectionEngine::disable_all(Packet*) { } void Stream::drop_traffic(const Packet*, char) { } @@ -101,6 +107,11 @@ void snort::TraceApi::filter(const Packet&) {} namespace snort { +NetworkPolicy* get_network_policy() { return nullptr; } +InspectionPolicy* get_inspection_policy() { return nullptr; } +IpsPolicy* get_ips_policy() { return nullptr; } +unsigned SnortConfig::get_thread_reload_id() { return 0; } + namespace layer { const vlan::VlanTagHdr* get_vlan_layer(const Packet* const) { return nullptr; } diff --git a/src/flow/test/flow_control_test.cc b/src/flow/test/flow_control_test.cc index 6174ed748..6512e53ac 100644 --- a/src/flow/test/flow_control_test.cc +++ b/src/flow/test/flow_control_test.cc @@ -27,6 +27,7 @@ #include "flow/flow_control.h" #include "detection/detection_engine.h" +#include "main/policy.h" #include "main/snort_config.h" #include "managers/inspector_manager.h" #include "packet_io/active.h" @@ -81,15 +82,15 @@ bool FlowCache::prune_one(PruneReason, bool) { return true; } unsigned FlowCache::delete_flows(unsigned) { return 0; } unsigned FlowCache::timeout(unsigned, time_t) { return 1; } void Flow::init(PktType) { } -void set_network_policy(const SnortConfig*, unsigned) { } void DataBus::publish(const char*, const uint8_t*, unsigned, Flow*) { } void DataBus::publish(const char*, Packet*, Flow*) { } const SnortConfig* SnortConfig::get_conf() { return nullptr; } void FlowCache::unlink_uni(Flow*) { } void Flow::set_client_initiate(Packet*) { } void Flow::set_direction(Packet*) { } -void set_inspection_policy(const SnortConfig*, unsigned) { } -void set_ips_policy(const SnortConfig*, unsigned) { } +void set_network_policy(unsigned) { } +void set_inspection_policy(unsigned) { } +void set_ips_policy(const snort::SnortConfig*, unsigned) { } void Flow::set_mpls_layer_per_dir(Packet*) { } void DetectionEngine::disable_all(Packet*) { } void Stream::drop_traffic(const Packet*, char) { } @@ -101,6 +102,11 @@ Flow* HighAvailabilityManager::import(Packet&, FlowKey&) { return nullptr; } namespace snort { +NetworkPolicy* get_network_policy() { return nullptr; } +InspectionPolicy* get_inspection_policy() { return nullptr; } +IpsPolicy* get_ips_policy() { return nullptr; } +unsigned SnortConfig::get_thread_reload_id() { return 0; } + namespace layer { const vlan::VlanTagHdr* get_vlan_layer(const Packet* const) { return nullptr; } diff --git a/src/flow/test/flow_stash_test.cc b/src/flow/test/flow_stash_test.cc index 0bffa3979..9b7c64a42 100644 --- a/src/flow/test/flow_stash_test.cc +++ b/src/flow/test/flow_stash_test.cc @@ -111,7 +111,7 @@ void DataBus::_subscribe(const char* key, DataHandler* h) void DataBus::_unsubscribe(const char*, DataHandler*) {} -void DataBus::_publish(const char* key, DataEvent& e, Flow* f) +void DataBus::_publish(const char* key, DataEvent& e, Flow* f) const { auto v = map.find(key); diff --git a/src/flow/test/flow_test.cc b/src/flow/test/flow_test.cc index eace4f931..77fe3c8e5 100644 --- a/src/flow/test/flow_test.cc +++ b/src/flow/test/flow_test.cc @@ -29,6 +29,7 @@ #include "flow/ha.h" #include "framework/inspector.h" #include "framework/data_bus.h" +#include "main/policy.h" #include "main/snort_config.h" #include "protocols/ip.h" #include "protocols/layer.h" @@ -58,6 +59,21 @@ void FlowStash::reset() {} void DetectionEngine::onload(Flow*) {} +void set_network_policy(unsigned) { } +void set_inspection_policy(unsigned) { } +void set_ips_policy(const snort::SnortConfig*, unsigned) { } +void select_default_policy(const _daq_pkt_hdr&, const SnortConfig*) { } +namespace snort +{ +NetworkPolicy* get_network_policy() { return nullptr; } +InspectionPolicy* get_inspection_policy() { return nullptr; } +IpsPolicy* get_ips_policy() { return nullptr; } +void set_network_policy(NetworkPolicy*) { } +void set_inspection_policy(InspectionPolicy*) { } +void set_ips_policy(IpsPolicy*) { } +unsigned SnortConfig::get_thread_reload_id() { return 0; } +} + Packet* DetectionEngine::set_next_packet(Packet*, Flow*) { return nullptr; } IpsContext* DetectionEngine::get_context() { return nullptr; } diff --git a/src/framework/base_api.h b/src/framework/base_api.h index b3d21230f..1b646581e 100644 --- a/src/framework/base_api.h +++ b/src/framework/base_api.h @@ -29,7 +29,7 @@ // this is the current version of the base api // must be prefixed to subtype version -#define BASE_API_VERSION 12 +#define BASE_API_VERSION 13 // set options to API_OPTIONS to ensure compatibility #ifndef API_OPTIONS diff --git a/src/framework/data_bus.cc b/src/framework/data_bus.cc index 12bea2593..debd6a880 100644 --- a/src/framework/data_bus.cc +++ b/src/framework/data_bus.cc @@ -107,6 +107,12 @@ void DataBus::subscribe_network(const char* key, DataHandler* h) get_network_data_bus()._subscribe(key, h); } +// for subscribers that need to receive events regardless of active inspection policy +void DataBus::subscribe_global(const char* key, DataHandler* h, SnortConfig& sc) +{ + sc.global_dbus->_subscribe(key, h); +} + void DataBus::unsubscribe(const char* key, DataHandler* h) { get_data_bus()._unsubscribe(key, h); @@ -117,9 +123,16 @@ void DataBus::unsubscribe_network(const char* key, DataHandler* h) get_network_data_bus()._unsubscribe(key, h); } +void DataBus::unsubscribe_global(const char* key, DataHandler* h, SnortConfig& sc) +{ + sc.global_dbus->_unsubscribe(key, h); +} + // notify subscribers of event void DataBus::publish(const char* key, DataEvent& e, Flow* f) { + SnortConfig::get_conf()->global_dbus->_publish(key, e, f); + NetworkPolicy* ni = get_network_policy(); ni->dbus._publish(key, e, f); @@ -176,7 +189,7 @@ void DataBus::_unsubscribe(const char* key, DataHandler* h) } // notify subscribers of event -void DataBus::_publish(const char* key, DataEvent& e, Flow* f) +void DataBus::_publish(const char* key, DataEvent& e, Flow* f) const { auto v = map.find(key); diff --git a/src/framework/data_bus.h b/src/framework/data_bus.h index 59d901e30..3b7f8c711 100644 --- a/src/framework/data_bus.h +++ b/src/framework/data_bus.h @@ -38,6 +38,7 @@ namespace snort { class Flow; struct Packet; +struct SnortConfig; class DataEvent { @@ -101,10 +102,12 @@ public: // FIXIT-L ideally these would not be static or would take an inspection policy* static void subscribe(const char* key, DataHandler*); static void subscribe_network(const char* key, DataHandler*); + static void subscribe_global(const char* key, DataHandler*, SnortConfig&); // FIXIT-L these should be called during cleanup static void unsubscribe(const char* key, DataHandler*); static void unsubscribe_network(const char* key, DataHandler*); + static void unsubscribe_global(const char* key, DataHandler*, SnortConfig&); // runtime methods static void publish(const char* key, DataEvent&, Flow* = nullptr); @@ -116,7 +119,7 @@ public: private: void _subscribe(const char* key, DataHandler*); void _unsubscribe(const char* key, DataHandler*); - void _publish(const char* key, DataEvent&, Flow*); + void _publish(const char* key, DataEvent&, Flow*) const; private: DataMap map; @@ -147,6 +150,9 @@ private: // A flow has entered the setup state #define FLOW_STATE_SETUP_EVENT "flow.state_setup" +// The policy has changed for the flow +#define FLOW_STATE_RELOADED_EVENT "flow.reloaded" + // A new flow is created on this packet #define STREAM_ICMP_NEW_FLOW_EVENT "stream.icmp_new_flow" #define STREAM_IP_NEW_FLOW_EVENT "stream.ip_new_flow" diff --git a/src/framework/inspector.cc b/src/framework/inspector.cc index 9d103c67c..3f905a711 100644 --- a/src/framework/inspector.cc +++ b/src/framework/inspector.cc @@ -27,6 +27,19 @@ #include "protocols/packet.h" #include "stream/stream_splitter.h" +namespace snort +{ +class ThreadSpecificData +{ +public: + explicit ThreadSpecificData(unsigned max) + { data.resize(max); } + ~ThreadSpecificData() = default; + + std::vector data; +}; +} + using namespace snort; //------------------------------------------------------------------------- @@ -34,14 +47,11 @@ using namespace snort; //------------------------------------------------------------------------- unsigned THREAD_LOCAL Inspector::slot = 0; -unsigned Inspector::max_slots = 1; Inspector::Inspector() { unsigned max = ThreadConfig::get_instance_max(); - assert(slot < max); ref_count = new std::atomic_uint[max]; - for ( unsigned i = 0; i < max; ++i ) ref_count[i] = 0; } @@ -125,6 +135,25 @@ void Inspector::add_global_ref() void Inspector::rem_global_ref() { --ref_count[0]; } +void Inspector::allocate_thread_storage() +{ + if (!thread_specific_data.use_count()) + thread_specific_data = std::make_shared(ThreadConfig::get_instance_max()); +} + +void Inspector::copy_thread_storage(Inspector* ins) +{ + assert(!thread_specific_data.use_count()); + if (ins->thread_specific_data.use_count()) + thread_specific_data = ins->thread_specific_data; +} + +void Inspector::set_thread_specific_data(void* tsd) +{ thread_specific_data->data[slot] = tsd; } + +void* Inspector::get_thread_specific_data() const +{ return thread_specific_data->data[slot]; } + static const char* InspectorTypeNames[IT_MAX] = { "passive", diff --git a/src/framework/inspector.h b/src/framework/inspector.h index 96d453c90..f598ecaf5 100644 --- a/src/framework/inspector.h +++ b/src/framework/inspector.h @@ -25,6 +25,9 @@ // in different ways. These correspond to Snort 2X preprocessors. #include +#include +#include +#include #include "framework/base_api.h" #include "main/thread.h" @@ -59,6 +62,8 @@ struct InspectApi; // api for class //------------------------------------------------------------------------- +class ThreadSpecificData; + class SO_PUBLIC Inspector { public: @@ -164,8 +169,15 @@ public: virtual bool can_start_tls() const { return false; } + void allocate_thread_storage(); + void set_thread_specific_data(void*); + void* get_thread_specific_data() const; + void copy_thread_storage(Inspector*); + + virtual void install_reload_handler(SnortConfig*) + { } + public: - static unsigned max_slots; static THREAD_LOCAL unsigned slot; protected: @@ -174,6 +186,7 @@ protected: private: const InspectApi* api = nullptr; + std::shared_ptr thread_specific_data; std::atomic_uint* ref_count; SnortProtocolId snort_protocol_id = 0; // FIXIT-E Use std::string to avoid storing a pointer to external std::string buffers diff --git a/src/framework/policy_selector.h b/src/framework/policy_selector.h index ab0e5e46f..b3c2a1d53 100644 --- a/src/framework/policy_selector.h +++ b/src/framework/policy_selector.h @@ -29,6 +29,7 @@ #include "framework/counts.h" #include "main/snort_types.h" +struct _daq_flow_stats; struct _daq_pkt_hdr; namespace snort @@ -81,7 +82,8 @@ public: } const PolicySelectorApi* get_api() const { return api; } - virtual bool select_default_policies(const _daq_pkt_hdr*, const SnortConfig*) = 0; + virtual bool select_default_policies(const _daq_pkt_hdr&, const SnortConfig*) = 0; + virtual bool select_default_policies(const _daq_flow_stats&, const SnortConfig*) = 0; virtual void show() const = 0; protected: diff --git a/src/framework/test/data_bus_test.cc b/src/framework/test/data_bus_test.cc index b9f1a9444..cd409cef3 100644 --- a/src/framework/test/data_bus_test.cc +++ b/src/framework/test/data_bus_test.cc @@ -41,7 +41,7 @@ NetworkPolicy::~NetworkPolicy() = default; namespace snort { SnortConfig::SnortConfig(snort::SnortConfig const*, const char*) -{ } +{ global_dbus = new DataBus(); } const SnortConfig* SnortConfig::get_conf() { @@ -58,7 +58,7 @@ SnortConfig* SnortConfig::get_main_conf() } SnortConfig::~SnortConfig() -{ } +{ delete global_dbus; } NetworkPolicy* get_network_policy() { @@ -135,6 +135,26 @@ TEST_GROUP(data_bus) } }; +TEST(data_bus, subscribe_global) +{ + UTestHandler h; + DataBus::subscribe_global(DB_UTEST_EVENT, &h, snort_conf); + + UTestEvent event(100); + DataBus::publish(DB_UTEST_EVENT, event); + CHECK(100 == h.evt_msg); + + UTestEvent event1(200); + DataBus::publish(DB_UTEST_EVENT, event1); + CHECK(200 == h.evt_msg); + + DataBus::unsubscribe_global(DB_UTEST_EVENT, &h, snort_conf); + + UTestEvent event2(300); + DataBus::publish(DB_UTEST_EVENT, event2); + CHECK(200 == h.evt_msg); // unsubscribed! +} + TEST(data_bus, subscribe_network) { UTestHandler* h = new UTestHandler(); diff --git a/src/hash/test/ghash_test.cc b/src/hash/test/ghash_test.cc index 808d0b82c..b207dd576 100644 --- a/src/hash/test/ghash_test.cc +++ b/src/hash/test/ghash_test.cc @@ -38,6 +38,9 @@ using namespace snort; static SnortConfig my_config; THREAD_LOCAL SnortConfig* snort_conf = &my_config; +DataBus::DataBus() = default; +DataBus::~DataBus() = default; + // run_flags is used indirectly from HashFnc class by calling SnortConfig::static_hash() SnortConfig::SnortConfig(const SnortConfig* const, const char*) { snort_conf->run_flags = 0;} diff --git a/src/hash/test/xhash_test.cc b/src/hash/test/xhash_test.cc index b4c4987c2..350e84afb 100644 --- a/src/hash/test/xhash_test.cc +++ b/src/hash/test/xhash_test.cc @@ -38,6 +38,9 @@ using namespace snort; static SnortConfig my_config; THREAD_LOCAL SnortConfig* snort_conf = &my_config; +DataBus::DataBus() = default; +DataBus::~DataBus() = default; + // run_flags is used indirectly from HashFnc class by calling SnortConfig::static_hash() SnortConfig::SnortConfig(const SnortConfig* const, const char*) { snort_conf->run_flags = 0;} diff --git a/src/hash/test/zhash_test.cc b/src/hash/test/zhash_test.cc index 8af9b44d9..0629283a8 100644 --- a/src/hash/test/zhash_test.cc +++ b/src/hash/test/zhash_test.cc @@ -63,6 +63,9 @@ bool FlowHashKeyOps::key_compare(const void* k1, const void* k2, size_t len) static SnortConfig my_config; THREAD_LOCAL SnortConfig *snort_conf = &my_config; +DataBus::DataBus() = default; +DataBus::~DataBus() = default; + // run_flags is used indirectly from HashFnc class by calling SnortConfig::static_hash() SnortConfig::SnortConfig(const SnortConfig* const, const char*) { snort_conf->run_flags = 0;} diff --git a/src/helpers/test/hyper_search_test.cc b/src/helpers/test/hyper_search_test.cc index 9634ba188..1fd57b154 100644 --- a/src/helpers/test/hyper_search_test.cc +++ b/src/helpers/test/hyper_search_test.cc @@ -50,6 +50,9 @@ static ScratchAllocator* scratcher = nullptr; static unsigned s_parse_errors = 0; +DataBus::DataBus() = default; +DataBus::~DataBus() = default; + SnortConfig::SnortConfig(const SnortConfig* const, const char*) { state = &s_state; diff --git a/src/host_tracker/host_cache_module.cc b/src/host_tracker/host_cache_module.cc index 44d4a62c3..cca05b71d 100644 --- a/src/host_tracker/host_cache_module.cc +++ b/src/host_tracker/host_cache_module.cc @@ -371,7 +371,7 @@ bool HostCacheModule::end(const char* fqn, int, SnortConfig* sc) if ( memcap and !strcmp(fqn, HOST_CACHE_NAME) ) { if ( Snort::is_reloading() ) - sc->register_reload_resource_tuner(new HostCacheReloadTuner(memcap)); + sc->register_reload_handler(new HostCacheReloadTuner(memcap)); else host_cache.set_max_size(memcap); } diff --git a/src/host_tracker/host_cache_module.h b/src/host_tracker/host_cache_module.h index 5144c9561..7fdc182aa 100644 --- a/src/host_tracker/host_cache_module.h +++ b/src/host_tracker/host_cache_module.h @@ -25,7 +25,7 @@ #include "framework/module.h" #include "main/snort.h" -#include "main/snort_config.h" +#include "main/reload_tuner.h" #include "host_cache.h" diff --git a/src/host_tracker/test/host_cache_module_test.cc b/src/host_tracker/test/host_cache_module_test.cc index 05ed3740b..3e46d2da6 100644 --- a/src/host_tracker/test/host_cache_module_test.cc +++ b/src/host_tracker/test/host_cache_module_test.cc @@ -72,7 +72,7 @@ void LogMessage(const char* format,...) } time_t packet_time() { return 0; } bool Snort::is_reloading() { return false; } -void SnortConfig::register_reload_resource_tuner(ReloadResourceTuner* rrt) { delete rrt; } +void SnortConfig::register_reload_handler(ReloadResourceTuner* rrt) { delete rrt; } } // end of namespace snort void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { } diff --git a/src/ips_options/test/ips_regex_test.cc b/src/ips_options/test/ips_regex_test.cc index e84e187fc..6f18b8128 100644 --- a/src/ips_options/test/ips_regex_test.cc +++ b/src/ips_options/test/ips_regex_test.cc @@ -56,6 +56,9 @@ THREAD_LOCAL SnortConfig* snort_conf = &s_conf; static std::vector s_state; static ScratchAllocator* scratcher = nullptr; +DataBus::DataBus() = default; +DataBus::~DataBus() = default; + SnortConfig::SnortConfig(const SnortConfig* const, const char*) { state = &s_state; diff --git a/src/log/messages.cc b/src/log/messages.cc index b8c224d22..e621259fd 100644 --- a/src/log/messages.cc +++ b/src/log/messages.cc @@ -26,7 +26,6 @@ #include #include -#include #include #include "main/snort_config.h" @@ -189,6 +188,13 @@ static void WriteLogMessage(FILE* fh, bool prefer_fh, const char* format, va_lis } // print an info message to stdout or syslog +void LogMessage(const char* format, va_list& ap) +{ + if ( SnortConfig::log_quiet() ) + return; + WriteLogMessage(stdout, false, format, ap); +} + void LogMessage(const char* format,...) { if ( SnortConfig::log_quiet() ) diff --git a/src/log/messages.h b/src/log/messages.h index e2bdaf89b..9d8241ef4 100644 --- a/src/log/messages.h +++ b/src/log/messages.h @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -61,6 +62,7 @@ SO_PUBLIC void ParseError(const char*, ...) __attribute__((format (printf, 1, 2) SO_PUBLIC void ReloadError(const char*, ...) __attribute__((format (printf, 1, 2))); [[noreturn]] SO_PUBLIC void ParseAbort(const char*, ...) __attribute__((format (printf, 1, 2))); +SO_PUBLIC void LogMessage(const char*, va_list& ap); SO_PUBLIC void LogMessage(const char*, ...) __attribute__((format (printf, 1, 2))); SO_PUBLIC void LogMessage(FILE*, const char*, ...) __attribute__((format (printf, 2, 3))); SO_PUBLIC void WarningMessage(const char*, ...) __attribute__((format (printf, 1, 2))); diff --git a/src/loggers/alert_csv.cc b/src/loggers/alert_csv.cc index 118ea3454..e786c5b5e 100644 --- a/src/loggers/alert_csv.cc +++ b/src/loggers/alert_csv.cc @@ -343,8 +343,8 @@ static void ff_server_pkts(const Args& a) static void ff_service(const Args& a) { const char* svc = "unknown"; - if ( a.pkt->flow and a.pkt->flow->service ) - svc = a.pkt->flow->service; + if ( a.pkt->flow and a.pkt->flow->has_service() ) + svc = a.pkt->flow->service->c_str(); TextLog_Puts(csv_log, svc); } diff --git a/src/loggers/alert_json.cc b/src/loggers/alert_json.cc index 5bc6422f3..bb1d13599 100644 --- a/src/loggers/alert_json.cc +++ b/src/loggers/alert_json.cc @@ -473,8 +473,8 @@ static bool ff_service(const Args& a) { const char* svc = "unknown"; - if ( a.pkt->flow and a.pkt->flow->service ) - svc = a.pkt->flow->service; + if ( a.pkt->flow and a.pkt->flow->has_service() ) + svc = a.pkt->flow->service->c_str(); print_label(a, "service"); TextLog_Quote(json_log, svc); diff --git a/src/main.cc b/src/main.cc index 2bd7ecdb4..561212dd3 100644 --- a/src/main.cc +++ b/src/main.cc @@ -310,6 +310,8 @@ void snort::main_broadcast_command(AnalyzerCommand* ac, ControlConn* ctrlcon) ac = get_command(ac, ctrlcon); debug_logf(snort_trace, TRACE_MAIN, nullptr, "Broadcasting %s command\n", ac->stringify()); + if (ac->need_update_reload_id()) + SnortConfig::get_main_conf()->update_reload_id(); for (unsigned idx = 0; idx < max_pigs; ++idx) { @@ -422,7 +424,6 @@ int main_reload_config(lua_State* L) } PluginManager::reload_so_plugins_cleanup(sc); - sc->update_reload_id(); SnortConfig::set_conf(sc); TraceApi::thread_reinit(sc->trace_config); proc_stats.conf_reloads++; @@ -469,7 +470,6 @@ int main_reload_policy(lua_State* L) send_response(ctrlcon, "== reload failed\n"); return 0; } - sc->update_reload_id(); SnortConfig::set_conf(sc); proc_stats.policy_reloads++; @@ -515,7 +515,6 @@ int main_reload_module(lua_State* L) send_response(ctrlcon, "== reload failed\n"); return 0; } - sc->update_reload_id(); SnortConfig::set_conf(sc); proc_stats.policy_reloads++; diff --git a/src/main/CMakeLists.txt b/src/main/CMakeLists.txt index 97746692d..ebcd871db 100644 --- a/src/main/CMakeLists.txt +++ b/src/main/CMakeLists.txt @@ -4,6 +4,7 @@ set (INCLUDES analyzer_command.h policy.h reload_tracker.h + reload_tuner.h snort.h snort_config.h snort_debug.h @@ -22,6 +23,8 @@ if ( ENABLE_SHELL ) set ( SHELL_SOURCES ac_shell_cmd.h ac_shell_cmd.cc) endif ( ENABLE_SHELL ) +add_subdirectory(test) + add_library (main OBJECT analyzer.cc analyzer_command.cc @@ -67,3 +70,4 @@ include_directories (${CMAKE_CURRENT_BINARY_DIR}) install (FILES ${INCLUDES} DESTINATION "${INCLUDE_INSTALL_PATH}/main" ) + diff --git a/src/main/ac_shell_cmd.cc b/src/main/ac_shell_cmd.cc index 7b44a58be..89b2ce111 100644 --- a/src/main/ac_shell_cmd.cc +++ b/src/main/ac_shell_cmd.cc @@ -27,7 +27,7 @@ #include "control/control.h" -ACShellCmd::ACShellCmd(ControlConn* ctrlcon, AnalyzerCommand* ac) : ctrlcon(ctrlcon), ac(ac) +ACShellCmd::ACShellCmd(ControlConn* conn, AnalyzerCommand* ac) : AnalyzerCommand(conn), ac(ac) { assert(ac); diff --git a/src/main/ac_shell_cmd.h b/src/main/ac_shell_cmd.h index 47cf77a9c..4a4a6166c 100644 --- a/src/main/ac_shell_cmd.h +++ b/src/main/ac_shell_cmd.h @@ -34,11 +34,12 @@ public: ACShellCmd() = delete; ACShellCmd(ControlConn*, snort::AnalyzerCommand*); bool execute(Analyzer&, void**) override; + bool need_update_reload_id() const override + { return ac->need_update_reload_id(); } const char* stringify() override { return ac->stringify(); } ~ACShellCmd() override; private: - ControlConn* ctrlcon; snort::AnalyzerCommand* ac; }; diff --git a/src/main/analyzer.cc b/src/main/analyzer.cc index 3feca6d4a..46cb06fc0 100644 --- a/src/main/analyzer.cc +++ b/src/main/analyzer.cc @@ -169,6 +169,7 @@ static void process_daq_sof_eof_msg(DAQ_Msg_h msg, DAQ_Verdict& verdict) const DAQ_FlowStats_t *stats = (const DAQ_FlowStats_t*) daq_msg_get_hdr(msg); const char* key; + select_default_policy(*stats, SnortConfig::get_conf()); if (daq_msg_get_type(msg) == DAQ_MSG_TYPE_EOF) { packet_time_update(&stats->eof_timestamp); @@ -390,7 +391,7 @@ void Analyzer::process_daq_pkt_msg(DAQ_Msg_h msg, bool retry) Packet* p = switcher->get_context()->packet; p->context->wire_packet = p; p->context->packet_number = get_packet_number(); - select_default_policy(pkthdr, p->context->conf); + select_default_policy(*pkthdr, p->context->conf); DetectionEngine::reset(); sfthreshold_reset(); @@ -652,8 +653,8 @@ void Analyzer::term() DetectionEngine::idle(); InspectorManager::thread_stop(sc); - ModuleManager::accumulate(); InspectorManager::thread_term(); + ModuleManager::accumulate(); ActionManager::thread_term(); IpsManager::clear_options(sc); @@ -776,6 +777,9 @@ bool Analyzer::handle_command() return false; void* ac_state = nullptr; + if ( ac->need_update_reload_id() ) + SnortConfig::update_thread_reload_id(); + if ( ac->execute(*this, &ac_state) ) add_command_to_completed_queue(ac); else diff --git a/src/main/analyzer.h b/src/main/analyzer.h index bd6c5b0ef..2c2e354da 100644 --- a/src/main/analyzer.h +++ b/src/main/analyzer.h @@ -43,7 +43,6 @@ class Swapper; namespace snort { class AnalyzerCommand; -class ReloadResourceTuner; class SFDAQInstance; struct Packet; struct SnortConfig; diff --git a/src/main/analyzer_command.cc b/src/main/analyzer_command.cc index 726e4f8a3..2b9c48241 100644 --- a/src/main/analyzer_command.cc +++ b/src/main/analyzer_command.cc @@ -35,12 +35,36 @@ #include "analyzer.h" #include "reload_tracker.h" +#include "reload_tuner.h" #include "snort.h" #include "snort_config.h" #include "swapper.h" using namespace snort; +void AnalyzerCommand::log_message(ControlConn* ctrlcon, const char* format, va_list& ap) +{ + LogMessage(format, ap); + if (ctrlcon && !ctrlcon->is_local()) + ctrlcon->respond(format, ap); +} + +void AnalyzerCommand::log_message(ControlConn* ctrlcon, const char* format, ...) +{ + va_list args; + va_start(args, format); + log_message(ctrlcon, format, args); + va_end(args); +} + +void AnalyzerCommand::log_message(const char* format, ...) +{ + va_list args; + va_start(args, format); + log_message(ctrlcon, format, args); + va_end(args); +} + bool ACStart::execute(Analyzer& analyzer, void**) { analyzer.start(); @@ -108,9 +132,6 @@ bool ACResetStats::execute(Analyzer&, void**) ACResetStats::ACResetStats(clear_counter_type_t requested_type_l) : requested_type( requested_type_l) { } -ACSwap::ACSwap(Swapper* ps, ControlConn *ctrlcon) : ps(ps), ctrlcon(ctrlcon) -{ } - bool ACSwap::execute(Analyzer& analyzer, void** ac_state) { if (ps) @@ -180,15 +201,9 @@ ACSwap::~ACSwap() HostAttributesManager::swap_cleanup(); ReloadTracker::end(ctrlcon); - LogMessage("== reload complete\n"); - if (ctrlcon && !ctrlcon->is_local()) - ctrlcon->respond("== reload complete\n"); + log_message("== reload complete\n"); } -ACHostAttributesSwap::ACHostAttributesSwap(ControlConn *ctrlcon) - : ctrlcon(ctrlcon) -{ } - bool ACHostAttributesSwap::execute(Analyzer&, void**) { HostAttributesManager::initialize(); @@ -199,9 +214,7 @@ ACHostAttributesSwap::~ACHostAttributesSwap() { HostAttributesManager::swap_cleanup(); ReloadTracker::end(ctrlcon); - LogMessage("== reload host attributes complete\n"); - if (ctrlcon && !ctrlcon->is_local()) - ctrlcon->respond("== reload host attributes complete\n"); + log_message("== reload host attributes complete\n"); } bool ACDAQSwap::execute(Analyzer& analyzer, void**) diff --git a/src/main/analyzer_command.h b/src/main/analyzer_command.h index f2e4fa08c..41abd5dcf 100644 --- a/src/main/analyzer_command.h +++ b/src/main/analyzer_command.h @@ -20,6 +20,7 @@ #ifndef ANALYZER_COMMANDS_H #define ANALYZER_COMMANDS_H +#include #include #include "main/snort_types.h" @@ -35,14 +36,25 @@ class SFDAQInstance; class AnalyzerCommand { public: + AnalyzerCommand() : AnalyzerCommand(nullptr) + { } + explicit AnalyzerCommand(ControlConn* conn) : ctrlcon(conn) + { } virtual ~AnalyzerCommand() = default; virtual bool execute(Analyzer&, void**) = 0; + virtual bool need_update_reload_id() const + { return false; } virtual const char* stringify() = 0; unsigned get() { return ++ref_count; } unsigned put() { return --ref_count; } + SO_PUBLIC void log_message(const char* format, ...) __attribute__((format (printf, 2, 3))); + SO_PUBLIC static void log_message(ControlConn*, const char* format, ...) __attribute__((format (printf, 2, 3))); SO_PUBLIC static snort::SFDAQInstance* get_daq_instance(Analyzer& analyzer); + ControlConn* ctrlcon; + private: + static void log_message(ControlConn*, const char* format, va_list& ap); unsigned ref_count = 0; }; } @@ -50,12 +62,11 @@ private: class ACGetStats : public snort::AnalyzerCommand { public: - ACGetStats(ControlConn* conn) : ctrlcon(conn) {} + ACGetStats(ControlConn* conn) : AnalyzerCommand(conn) + { } bool execute(Analyzer&, void**) override; const char* stringify() override { return "GET_STATS"; } ~ACGetStats() override; -private: - ControlConn* ctrlcon; }; typedef enum clear_counter_type @@ -69,7 +80,7 @@ typedef enum clear_counter_type TYPE_HA } clear_counter_type_t; -// FIXIT-M Will replace this vector with an unordered map of +// FIXIT-M Will replace this vector with an unordered map of // when // will come up with more granular form of clearing module stats. static std::vector clear_counter_type_string_map @@ -145,25 +156,25 @@ class ACSwap : public snort::AnalyzerCommand { public: ACSwap() = delete; - ACSwap(Swapper* ps, ControlConn* ctrlcon); + ACSwap(Swapper* ps, ControlConn* conn) : AnalyzerCommand(conn), ps(ps) + { } bool execute(Analyzer&, void**) override; + bool need_update_reload_id() const override + { return true; } const char* stringify() override { return "SWAP"; } ~ACSwap() override; private: Swapper *ps; - ControlConn* ctrlcon; }; class ACHostAttributesSwap : public snort::AnalyzerCommand { public: - ACHostAttributesSwap(ControlConn* ctrlcon); + ACHostAttributesSwap(ControlConn* conn) : AnalyzerCommand(conn) + { } bool execute(Analyzer&, void**) override; const char* stringify() override { return "HOST_ATTRIBUTES_SWAP"; } ~ACHostAttributesSwap() override; - -private: - ControlConn* ctrlcon; }; class ACDAQSwap : public snort::AnalyzerCommand @@ -178,9 +189,9 @@ namespace snort { // from main.cc #ifdef REG_TEST -void main_unicast_command(AnalyzerCommand* ac, unsigned target, ControlConn* ctrlcon = nullptr); +void main_unicast_command(AnalyzerCommand*, unsigned target, ControlConn* = nullptr); #endif -SO_PUBLIC void main_broadcast_command(snort::AnalyzerCommand* ac, ControlConn* ctrlcon = nullptr); +SO_PUBLIC void main_broadcast_command(snort::AnalyzerCommand*, ControlConn* = nullptr); } #endif diff --git a/src/main/modules.cc b/src/main/modules.cc index 33ae97078..3fb387d60 100644 --- a/src/main/modules.cc +++ b/src/main/modules.cc @@ -1053,7 +1053,6 @@ class NetworkModule : public Module public: NetworkModule() : Module("network", network_help, network_params) { } bool set(const char*, Value&, SnortConfig*) override; - bool end(const char*, int, SnortConfig*) override; Usage get_usage() const override { return CONTEXT; } @@ -1090,16 +1089,6 @@ bool NetworkModule::set(const char*, Value& v, SnortConfig* sc) return true; } -bool NetworkModule::end(const char*, int idx, SnortConfig* sc) -{ - if (!idx) - { - NetworkPolicy* p = get_network_policy(); - sc->policy_map->set_user_network(p); - } - return true; -} - //------------------------------------------------------------------------- // inspection policy module //------------------------------------------------------------------------- @@ -1177,10 +1166,12 @@ bool InspectionModule::set(const char*, Value& v, SnortConfig* sc) return true; } -bool InspectionModule::end(const char*, int, SnortConfig* sc) +bool InspectionModule::end(const char*, int, SnortConfig*) { InspectionPolicy* p = get_inspection_policy(); - sc->policy_map->set_user_inspection(p); + NetworkPolicy* np = get_network_parse_policy(); + assert(np); + np->set_user_inspection(p); return true; } diff --git a/src/main/policy.cc b/src/main/policy.cc index 440e09746..ee19cd08d 100644 --- a/src/main/policy.cc +++ b/src/main/policy.cc @@ -30,6 +30,7 @@ #include "framework/file_policy.h" #include "framework/policy_selector.h" #include "log/messages.h" +#include "main/thread_config.h" #include "managers/inspector_manager.h" #include "parser/parse_conf.h" #include "parser/vars.h" @@ -45,8 +46,8 @@ using namespace snort; // traffic policy //------------------------------------------------------------------------- -NetworkPolicy::NetworkPolicy(PolicyId id, PolicyId default_inspection_id) - : policy_id(id), default_inspection_policy_id(default_inspection_id) +NetworkPolicy::NetworkPolicy(PolicyId id, PolicyId default_ips_id) + : policy_id(id), default_ips_policy_id(default_ips_id) { init(nullptr, nullptr); } NetworkPolicy::NetworkPolicy(NetworkPolicy* other_network_policy, const char* exclude_name) @@ -55,7 +56,24 @@ NetworkPolicy::NetworkPolicy(NetworkPolicy* other_network_policy, const char* ex NetworkPolicy::~NetworkPolicy() { FilePolicyBase::delete_file_policy(file_policy); + if (cloned) + { + if ( !inspection_policy.empty() ) + { + InspectionPolicy* default_policy = inspection_policy[0]; + default_policy->cloned = true; + delete default_policy; + } + } + else + { + for ( auto p : inspection_policy ) + delete p; + } + InspectorManager::delete_policy(this, cloned); + + inspection_policy.clear(); } void NetworkPolicy::init(NetworkPolicy* other_network_policy, const char* exclude_name) @@ -63,10 +81,27 @@ void NetworkPolicy::init(NetworkPolicy* other_network_policy, const char* exclud file_policy = new FilePolicy; if (other_network_policy) { + for ( unsigned i = 0; i < (other_network_policy->inspection_policy.size()); i++) + { + if ( i == 0 ) + inspection_policy.emplace_back( + new InspectionPolicy(other_network_policy->inspection_policy[i])); + else + inspection_policy.emplace_back(other_network_policy->inspection_policy[i]); + } + user_inspection = other_network_policy->user_inspection; + // Fix references to inspection_policy[0] + for ( auto p : other_network_policy->user_inspection ) + { + if ( p.second == other_network_policy->inspection_policy[0] ) + user_inspection[p.first] = inspection_policy[0]; + } + + dbus.clone(other_network_policy->dbus, exclude_name); policy_id = other_network_policy->policy_id; - default_inspection_policy_id = other_network_policy->default_inspection_policy_id; user_policy_id = other_network_policy->user_policy_id; + default_ips_policy_id = other_network_policy->default_ips_policy_id; min_ttl = other_network_policy->min_ttl; new_ttl = other_network_policy->new_ttl; @@ -86,6 +121,12 @@ FilePolicy* NetworkPolicy::get_file_policy() const void NetworkPolicy::add_file_policy_rule(FileRule& file_rule) { file_policy->insert_file_rule(file_rule); } +InspectionPolicy* NetworkPolicy::get_user_inspection_policy(unsigned user_id) +{ + auto it = user_inspection.find(user_id); + return it == user_inspection.end() ? nullptr : it->second; +} + //------------------------------------------------------------------------- // inspection policy //------------------------------------------------------------------------- @@ -176,6 +217,8 @@ IpsPolicy::~IpsPolicy() PolicyMap::PolicyMap(PolicyMap* other_map, const char* exclude_name) { + unsigned max = ThreadConfig::get_instance_max(); + inspector_tinit_complete = new bool[max]{}; if ( other_map ) clone(other_map, exclude_name); else @@ -183,13 +226,14 @@ PolicyMap::PolicyMap(PolicyMap* other_map, const char* exclude_name) file_id = InspectorManager::create_single_instance_inspector_policy(); flow_tracking = InspectorManager::create_single_instance_inspector_policy(); global_inspector_policy = InspectorManager::create_global_inspector_policy(); - add_shell(new Shell(nullptr, true), true); + add_shell(new Shell(nullptr, true), nullptr); empty_ips_policy = new IpsPolicy(ips_policy.size()); ips_policy.push_back(empty_ips_policy); } set_network_policy(network_policy[0]); - set_inspection_policy(inspection_policy[0]); + set_network_parse_policy(network_policy[0]); + set_inspection_policy(network_policy[0]->get_inspection_policy(0)); set_ips_policy(ips_policy[0]); } @@ -197,17 +241,10 @@ PolicyMap::~PolicyMap() { if ( cloned ) { - if ( !inspection_policy.empty() ) + for (auto np: network_policy) { - InspectionPolicy* default_policy = inspection_policy[0]; - default_policy->cloned = true; - delete default_policy; - } - if ( !network_policy.empty() ) - { - NetworkPolicy* default_policy = network_policy[0]; - default_policy->cloned = true; - delete default_policy; + np->cloned = true; + delete np; } } else @@ -215,9 +252,6 @@ PolicyMap::~PolicyMap() for ( auto p : shells ) delete p; - for ( auto p : inspection_policy ) - delete p; - for ( auto p : ips_policy ) delete p; @@ -232,10 +266,20 @@ PolicyMap::~PolicyMap() InspectorManager::destroy_global_inspector_policy(global_inspector_policy, cloned); shells.clear(); - inspection_policy.clear(); ips_policy.clear(); network_policy.clear(); shell_map.clear(); + delete[] inspector_tinit_complete; +} + +bool PolicyMap::setup_network_policies() +{ + for (auto* np : network_policy) + { + if (!set_user_network(np)) + return false; + } + return true; } void PolicyMap::clone(PolicyMap *other_map, const char* exclude_name) @@ -248,63 +292,45 @@ void PolicyMap::clone(PolicyMap *other_map, const char* exclude_name) ips_policy = other_map->ips_policy; empty_ips_policy = other_map->empty_ips_policy; - for ( unsigned i = 0; i < (other_map->network_policy.size()); i++) - { - if ( i == 0 ) - network_policy.emplace_back(new NetworkPolicy(other_map->network_policy[i], - exclude_name)); - else - network_policy.emplace_back(other_map->network_policy[i]); - } - - for ( unsigned i = 0; i < (other_map->inspection_policy.size()); i++) - { - if ( i == 0 ) - inspection_policy.emplace_back(new InspectionPolicy(other_map->inspection_policy[i])); - else - inspection_policy.emplace_back(other_map->inspection_policy[i]); - } + for ( unsigned i = 0; i < other_map->network_policy.size(); i++) + network_policy.emplace_back(new NetworkPolicy(other_map->network_policy[i], + i ? nullptr : exclude_name)); shell_map = other_map->shell_map; // Fix references to network_policy[0] and inspection_policy[0] for ( auto p : other_map->shell_map ) { - if ( p.second->network == other_map->network_policy[0] ) - shell_map[p.first]->network = network_policy[0]; - if ( p.second->inspection == other_map->inspection_policy[0] ) - shell_map[p.first] = std::make_shared(inspection_policy[0], p.second->ips, - p.second->network); - } - - user_network = other_map->user_network; - // Fix references to network_policy[0] - for ( auto p : other_map->user_network ) - { - if ( p.second == other_map->network_policy[0] ) - user_network[p.first] = network_policy[0]; - } - - user_inspection = other_map->user_inspection; - // Fix references to inspection_policy[0] - for ( auto p : other_map->user_inspection ) - { - if ( p.second == other_map->inspection_policy[0] ) - user_inspection[p.first] = inspection_policy[0]; + for ( unsigned idx = 0; idx < other_map->network_policy.size(); ++idx) + { + if ( p.second->network == other_map->network_policy[idx] ) + { + shell_map[p.first]->network = network_policy[idx]; + shell_map[p.first]->network_parse = network_policy[idx]; + } + if ( p.second->inspection == other_map->network_policy[idx]->inspection_policy[0] ) + shell_map[p.first] = + std::make_shared( + other_map->network_policy[idx]->inspection_policy[0], p.second->ips, + p.second->network, p.second->network); + } } + //user_network = other_map->user_network; user_ips = other_map->user_ips; } InspectionPolicy* PolicyMap::add_inspection_shell(Shell* sh) { - unsigned idx = inspection_policy.size(); - InspectionPolicy* p = new InspectionPolicy(idx); + NetworkPolicy* np = get_network_parse_policy(); + assert(np); + unsigned idx = np->inspection_policy_count(); + InspectionPolicy* ip = new InspectionPolicy(idx); shells.push_back(sh); - inspection_policy.push_back(p); - shell_map[sh] = std::make_shared(p, nullptr, nullptr); + np->inspection_policy.push_back(ip); + shell_map[sh] = std::make_shared(ip, nullptr, nullptr, np); - return p; + return ip; } IpsPolicy* PolicyMap::add_ips_shell(Shell* sh) @@ -314,25 +340,31 @@ IpsPolicy* PolicyMap::add_ips_shell(Shell* sh) shells.push_back(sh); ips_policy.push_back(p); - shell_map[sh] = std::make_shared(nullptr, p, nullptr); + shell_map[sh] = std::make_shared(nullptr, p, nullptr, get_network_parse_policy()); return p; } -std::shared_ptr PolicyMap::add_shell(Shell* sh, bool include_network) +std::shared_ptr PolicyMap::add_shell(Shell* sh, NetworkPolicy* np_in) { shells.push_back(sh); - inspection_policy.push_back(new InspectionPolicy(inspection_policy.size())); - InspectionPolicy* ip = inspection_policy.back(); - NetworkPolicy* new_network_policy = nullptr; - if (include_network) + IpsPolicy* ips = new IpsPolicy(ips_policy.size()); + ips_policy.push_back(ips); + NetworkPolicy* np; + if (!np_in) + { + np_in = np = new NetworkPolicy(network_policy.size(), ips->policy_id); + network_policy.push_back(np); + } + else { - new_network_policy = new NetworkPolicy(network_policy.size(), ip->policy_id); - network_policy.push_back(new_network_policy); + np = np_in; + np_in = nullptr; } - ips_policy.push_back(new IpsPolicy(ips_policy.size())); - return shell_map[sh] = std::make_shared(ip, - ips_policy.back(), new_network_policy); + InspectionPolicy* ip = new InspectionPolicy(np->inspection_policy_count()); + np->inspection_policy.push_back(ip); + return shell_map[sh] = + std::make_shared(ip, ips, np_in, np); } std::shared_ptr PolicyMap::get_policies(Shell* sh) @@ -342,18 +374,39 @@ std::shared_ptr PolicyMap::get_policies(Shell* sh) return pt == shell_map.end() ? nullptr : pt->second; } +NetworkPolicy* PolicyMap::get_user_network(unsigned user_id) +{ + auto it = user_network.find(user_id); + NetworkPolicy* np = (it == user_network.end()) ? nullptr : it->second; + return np; +} + +bool PolicyMap::set_user_network(NetworkPolicy* p) +{ + NetworkPolicy* current_np = get_user_network(p->user_policy_id); + if (current_np && p != current_np) + return false; + user_network[p->user_policy_id] = p; + return true; +} + + //------------------------------------------------------------------------- // policy nav //------------------------------------------------------------------------- -static THREAD_LOCAL NetworkPolicy* s_traffic_policy = nullptr; +static THREAD_LOCAL NetworkPolicy* s_network_policy = nullptr; +static THREAD_LOCAL NetworkPolicy* s_network_parse_policy = nullptr; static THREAD_LOCAL InspectionPolicy* s_inspection_policy = nullptr; static THREAD_LOCAL IpsPolicy* s_detection_policy = nullptr; namespace snort { NetworkPolicy* get_network_policy() -{ return s_traffic_policy; } +{ return s_network_policy; } + +NetworkPolicy* get_network_parse_policy() +{ return s_network_parse_policy; } InspectionPolicy* get_inspection_policy() { return s_inspection_policy; } @@ -361,8 +414,11 @@ InspectionPolicy* get_inspection_policy() IpsPolicy* get_ips_policy() { return s_detection_policy; } +void set_network_parse_policy(NetworkPolicy* p) +{ s_network_parse_policy = p; } + void set_network_policy(NetworkPolicy* p) -{ s_traffic_policy = p; } +{ s_network_policy = p; } void set_inspection_policy(InspectionPolicy* p) { s_inspection_policy = p; } @@ -370,54 +426,56 @@ void set_inspection_policy(InspectionPolicy* p) void set_ips_policy(IpsPolicy* p) { s_detection_policy = p; } -InspectionPolicy* get_user_inspection_policy(const SnortConfig* sc, unsigned policy_id) +InspectionPolicy* get_user_inspection_policy(unsigned policy_id) { - return sc->policy_map->get_user_inspection(policy_id); + NetworkPolicy* np = get_network_policy(); + assert(np); + return np->get_user_inspection_policy(policy_id); } NetworkPolicy* get_default_network_policy(const SnortConfig* sc) { return sc->policy_map->get_network_policy(0); } -InspectionPolicy* get_default_inspection_policy(const SnortConfig* sc) -{ - return - sc->policy_map->get_inspection_policy(get_network_policy()->default_inspection_policy_id); -} - IpsPolicy* get_ips_policy(const SnortConfig* sc, unsigned i) { return sc && i < sc->policy_map->ips_policy_count() ? sc->policy_map->get_ips_policy(i) : nullptr; } -IpsPolicy* get_user_ips_policy(const SnortConfig* sc, unsigned policy_id) +IpsPolicy* get_default_ips_policy(const snort::SnortConfig* sc) { - return sc->policy_map->get_user_ips(policy_id); + NetworkPolicy* np = get_network_policy(); + assert(np); + return np->get_default_ips_policy(sc); } +IpsPolicy* get_user_ips_policy(const SnortConfig* sc, unsigned policy_id) +{ return sc->policy_map->get_user_ips(policy_id); } + IpsPolicy* get_empty_ips_policy(const SnortConfig* sc) -{ - return sc->policy_map->get_empty_ips(); -} +{ return sc->policy_map->get_empty_ips(); } } // namespace snort -void set_network_policy(const SnortConfig* sc, unsigned i) +void set_network_policy(unsigned i) { - PolicyMap* pm = sc->policy_map; + PolicyMap* pm = SnortConfig::get_conf()->policy_map; if ( i < pm->network_policy_count() ) set_network_policy(pm->get_network_policy(i)); } -void set_inspection_policy(const SnortConfig* sc, unsigned i) +void set_inspection_policy(unsigned i) { - PolicyMap* pm = sc->policy_map; - - if ( i < pm->inspection_policy_count() ) - set_inspection_policy(pm->get_inspection_policy(i)); + NetworkPolicy* np = get_network_policy(); + if (np) + { + InspectionPolicy* ip = np->get_inspection_policy(i); + if (ip) + set_inspection_policy(ip); + } } -void set_ips_policy(const SnortConfig* sc, unsigned i) +void set_ips_policy(const snort::SnortConfig* sc, unsigned i) { PolicyMap* pm = sc->policy_map; @@ -441,20 +499,24 @@ void set_policies(const SnortConfig* sc, Shell* sh) void set_default_policy(const SnortConfig* sc) { - set_network_policy(sc->policy_map->get_network_policy(0)); - set_inspection_policy(sc->policy_map->get_inspection_policy(0)); - set_ips_policy(sc->policy_map->get_ips_policy(0)); + NetworkPolicy* np = get_default_network_policy(sc); + set_network_policy(np); + set_inspection_policy(np->get_inspection_policy(0)); + set_ips_policy(get_ips_policy(sc, 0)); } -void select_default_policy(const _daq_pkt_hdr* pkthdr, const SnortConfig* sc) +void select_default_policy(const _daq_pkt_hdr& pkthdr, const SnortConfig* sc) { PolicySelector* ps = sc->policy_map->get_policy_selector(); if (!ps || !ps->select_default_policies(pkthdr, sc)) - { - set_network_policy(sc->policy_map->get_network_policy(0)); - set_inspection_policy(sc->policy_map->get_inspection_policy(0)); - set_ips_policy(sc->policy_map->get_ips_policy(0)); - } + set_default_policy(sc); +} + +void select_default_policy(const _daq_flow_stats& stats, const snort::SnortConfig* sc) +{ + PolicySelector* ps = sc->policy_map->get_policy_selector(); + if (!ps || !ps->select_default_policies(stats, sc)) + set_default_policy(sc); } bool only_inspection_policy() diff --git a/src/main/policy.h b/src/main/policy.h index ebdf7ed7b..5956e2ff7 100644 --- a/src/main/policy.h +++ b/src/main/policy.h @@ -49,6 +49,7 @@ class PolicySelector; struct SnortConfig; } +struct _daq_flow_stats; struct _daq_pkt_hdr; struct PortTable; struct vartable_t; @@ -67,6 +68,50 @@ enum PolicyMode // FIXIT-L split into separate headers +//------------------------------------------------------------------------- +// navigator stuff +//------------------------------------------------------------------------- + +struct InspectionPolicy; +struct IpsPolicy; +struct NetworkPolicy; +class Shell; + +namespace snort +{ +SO_PUBLIC NetworkPolicy* get_network_policy(); +NetworkPolicy* get_network_parse_policy(); +SO_PUBLIC InspectionPolicy* get_inspection_policy(); +SO_PUBLIC IpsPolicy* get_ips_policy(); + +SO_PUBLIC void set_network_policy(NetworkPolicy*); +void set_network_parse_policy(NetworkPolicy*); +SO_PUBLIC void set_inspection_policy(InspectionPolicy*); +SO_PUBLIC void set_ips_policy(IpsPolicy*); + +SO_PUBLIC NetworkPolicy* get_default_network_policy(const snort::SnortConfig*); +// Based on currently set network policy +SO_PUBLIC InspectionPolicy* get_user_inspection_policy(unsigned policy_id); + +SO_PUBLIC IpsPolicy* get_ips_policy(const snort::SnortConfig*, unsigned i = 0); +// Based on currently set network policy +SO_PUBLIC IpsPolicy* get_default_ips_policy(const snort::SnortConfig*); +SO_PUBLIC IpsPolicy* get_user_ips_policy(const snort::SnortConfig*, unsigned policy_id); +SO_PUBLIC IpsPolicy* get_empty_ips_policy(const snort::SnortConfig*); +} + +void set_network_policy(unsigned = 0); +void set_inspection_policy(unsigned = 0); +void set_ips_policy(const snort::SnortConfig*, unsigned = 0); + +void set_policies(const snort::SnortConfig*, Shell*); +void set_default_policy(const snort::SnortConfig*); +void select_default_policy(const _daq_pkt_hdr&, const snort::SnortConfig*); +void select_default_policy(const _daq_flow_stats&, const snort::SnortConfig*); + +bool only_inspection_policy(); +bool only_ips_policy(); + //------------------------------------------------------------------------- // traffic stuff //------------------------------------------------------------------------- @@ -86,19 +131,59 @@ enum DecodeEventFlag DECODE_EVENT_FLAG__DEFAULT = 0x00000001 }; -// Snort ac-split creates the nap (network analysis policy) -// Snort++ breaks the nap into network and inspection +//------------------------------------------------------------------------- +// inspection stuff +//------------------------------------------------------------------------- + +struct InspectionPolicy +{ +public: + InspectionPolicy(PolicyId = 0); + InspectionPolicy(InspectionPolicy* old_inspection_policy); + ~InspectionPolicy(); + + void configure(); + +public: + PolicyId policy_id = 0; + PolicyMode policy_mode = POLICY_MODE__MAX; + uint32_t user_policy_id = 0; + uuid_t uuid{}; + + struct FrameworkPolicy* framework_policy; + snort::DataBus dbus; + bool cloned; + +private: + void init(InspectionPolicy* old_inspection_policy); +}; + +//------------------------------------------------------------------------- +// Network stuff +//------------------------------------------------------------------------- class FilePolicy; class FileRule; +struct IpsPolicy; struct NetworkPolicy { public: - NetworkPolicy(PolicyId = 0, PolicyId default_inspection_id = 0); + NetworkPolicy(PolicyId = 0, PolicyId default_ips_id = 0); NetworkPolicy(NetworkPolicy*, const char*); ~NetworkPolicy(); + InspectionPolicy* get_inspection_policy(unsigned i = 0) + { return i < inspection_policy.size() ? inspection_policy[i] : nullptr; } + unsigned inspection_policy_count() + { return inspection_policy.size(); } + InspectionPolicy* get_user_inspection_policy(unsigned user_id); + void set_user_inspection(InspectionPolicy* p) + { user_inspection[p->user_policy_id] = p; } + + IpsPolicy* get_default_ips_policy(const snort::SnortConfig* sc) + { return snort::get_ips_policy(sc, default_ips_policy_id); } + void add_file_policy_rule(FileRule& file_rule); snort::FilePolicyBase* get_base_file_policy() const; FilePolicy* get_file_policy() const; @@ -125,9 +210,12 @@ public: struct TrafficPolicy* traffic_policy; snort::DataBus dbus; + std::vector inspection_policy; + std::unordered_map user_inspection; + PolicyId policy_id = 0; uint32_t user_policy_id = 0; - PolicyId default_inspection_policy_id = 0; + PolicyId default_ips_policy_id = 0; // minimum possible (allows all but errors to pass by default) uint8_t min_ttl = 1; @@ -142,33 +230,6 @@ private: void init(NetworkPolicy*, const char*); }; -//------------------------------------------------------------------------- -// inspection stuff -//------------------------------------------------------------------------- - -struct InspectionPolicy -{ -public: - InspectionPolicy(PolicyId = 0); - InspectionPolicy(InspectionPolicy* old_inspection_policy); - ~InspectionPolicy(); - - void configure(); - -public: - PolicyId policy_id = 0; - PolicyMode policy_mode = POLICY_MODE__MAX; - uint32_t user_policy_id = 0; - uuid_t uuid{}; - - struct FrameworkPolicy* framework_policy; - snort::DataBus dbus; - bool cloned; - -private: - void init(InspectionPolicy* old_inspection_policy); -}; - //------------------------------------------------------------------------- // detection stuff //------------------------------------------------------------------------- @@ -221,16 +282,17 @@ public: // binding stuff //------------------------------------------------------------------------- -class Shell; - struct PolicyTuple { InspectionPolicy* inspection = nullptr; IpsPolicy* ips = nullptr; NetworkPolicy* network = nullptr; + NetworkPolicy* network_parse = nullptr; - PolicyTuple(InspectionPolicy* ins_pol, IpsPolicy* ips_pol, NetworkPolicy* net_pol) : - inspection(ins_pol), ips(ips_pol), network(net_pol) { } + PolicyTuple(InspectionPolicy* ins_pol, IpsPolicy* ips_pol, NetworkPolicy* net_pol, + NetworkPolicy* net_parse) : + inspection(ins_pol), ips(ips_pol), network(net_pol), network_parse(net_parse) + { } }; struct GlobalInspectorPolicy; @@ -244,33 +306,18 @@ public: InspectionPolicy* add_inspection_shell(Shell*); IpsPolicy* add_ips_shell(Shell*); - std::shared_ptr add_shell(Shell*, bool include_network); + std::shared_ptr add_shell(Shell*, NetworkPolicy*); std::shared_ptr get_policies(Shell* sh); - void clone(PolicyMap *old_map, const char* exclude_name); Shell* get_shell(unsigned i = 0) { return i < shells.size() ? shells[i] : nullptr; } - void set_user_network(NetworkPolicy* p) - { user_network[p->user_policy_id] = p; } - - void set_user_inspection(InspectionPolicy* p) - { user_inspection[p->user_policy_id] = p; } + bool setup_network_policies(); void set_user_ips(IpsPolicy* p) { user_ips[p->user_policy_id] = p; } - NetworkPolicy* get_user_network(unsigned user_id) - { - auto it = user_network.find(user_id); - return it == user_network.end() ? nullptr : it->second; - } - - InspectionPolicy* get_user_inspection(unsigned user_id) - { - auto it = user_inspection.find(user_id); - return it == user_inspection.end() ? nullptr : it->second; - } + NetworkPolicy* get_user_network(unsigned user_id); IpsPolicy* get_user_ips(unsigned user_id) { @@ -280,24 +327,15 @@ public: NetworkPolicy* get_network_policy(unsigned i = 0) { return i < network_policy.size() ? network_policy[i] : nullptr; } - - InspectionPolicy* get_inspection_policy(unsigned i = 0) - { return i < inspection_policy.size() ? inspection_policy[i] : nullptr; } - - IpsPolicy* get_ips_policy(unsigned i = 0) - { return i < ips_policy.size() ? ips_policy[i] : nullptr; } - - IpsPolicy* get_empty_ips() - { return empty_ips_policy; } - unsigned network_policy_count() { return network_policy.size(); } - unsigned inspection_policy_count() - { return inspection_policy.size(); } - + IpsPolicy* get_ips_policy(unsigned i = 0) + { return i < ips_policy.size() ? ips_policy[i] : nullptr; } unsigned ips_policy_count() { return ips_policy.size(); } + IpsPolicy* get_empty_ips() + { return empty_ips_policy; } unsigned shells_count() { return shells.size(); } @@ -328,17 +366,23 @@ public: return (it == std::end(shell_map)) ? nullptr : it->first; } + bool get_inspector_tinit_complete(unsigned instance_id) const + { return inspector_tinit_complete[instance_id]; } + + void set_inspector_tinit_complete(unsigned instance_id, bool val) + { inspector_tinit_complete[instance_id] = val; } + private: + void clone(PolicyMap *old_map, const char* exclude_name); + bool set_user_network(NetworkPolicy* p); + std::vector shells; - std::vector inspection_policy; - std::vector ips_policy; std::vector network_policy; - + std::vector ips_policy; IpsPolicy* empty_ips_policy; std::unordered_map> shell_map; std::unordered_map user_network; - std::unordered_map user_inspection; std::unordered_map user_ips; snort::PolicySelector* selector = nullptr; @@ -346,44 +390,9 @@ private: SingleInstanceInspectorPolicy* flow_tracking; GlobalInspectorPolicy* global_inspector_policy; + bool* inspector_tinit_complete; bool cloned = false; }; -//------------------------------------------------------------------------- -// navigator stuff -//------------------------------------------------------------------------- - -// FIXIT-L may be inlined at some point; on lockdown for now -// FIXIT-L SO_PUBLIC required because SnortConfig::inline_mode(), etc. uses the function -namespace snort -{ -SO_PUBLIC NetworkPolicy* get_network_policy(); -SO_PUBLIC InspectionPolicy* get_inspection_policy(); -SO_PUBLIC IpsPolicy* get_ips_policy(); - -SO_PUBLIC void set_network_policy(NetworkPolicy*); -SO_PUBLIC void set_inspection_policy(InspectionPolicy*); -SO_PUBLIC void set_ips_policy(IpsPolicy*); - -SO_PUBLIC NetworkPolicy* get_default_network_policy(const snort::SnortConfig*); -SO_PUBLIC InspectionPolicy* get_user_inspection_policy(const snort::SnortConfig*, unsigned policy_id); -SO_PUBLIC InspectionPolicy* get_default_inspection_policy(const snort::SnortConfig*); - -SO_PUBLIC IpsPolicy* get_ips_policy(const snort::SnortConfig*, unsigned i = 0); -SO_PUBLIC IpsPolicy* get_user_ips_policy(const snort::SnortConfig*, unsigned policy_id); -SO_PUBLIC IpsPolicy* get_empty_ips_policy(const snort::SnortConfig*); -} - -void set_network_policy(const snort::SnortConfig*, unsigned = 0); -void set_inspection_policy(const snort::SnortConfig*, unsigned = 0); -void set_ips_policy(const snort::SnortConfig*, unsigned = 0); - -void set_policies(const snort::SnortConfig*, Shell*); -void set_default_policy(const snort::SnortConfig*); -void select_default_policy(const _daq_pkt_hdr*, const snort::SnortConfig*); - -bool only_inspection_policy(); -bool only_ips_policy(); - #endif diff --git a/src/main/reload_tuner.h b/src/main/reload_tuner.h new file mode 100644 index 000000000..163547bd4 --- /dev/null +++ b/src/main/reload_tuner.h @@ -0,0 +1,73 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2022-2022 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- + +#ifndef RELOAD_TUNER_H +#define RELOAD_TUNER_H + +namespace snort +{ + +class ReloadResourceTuner +{ +public: + static const unsigned RELOAD_MAX_WORK_PER_PACKET = 3; + // be aggressive when idle as analyzer gets chance once in every second only due to daq timeout + static const unsigned RELOAD_MAX_WORK_WHEN_IDLE = 32767; + + virtual ~ReloadResourceTuner() = default; + + // returns true if resource tuning required, false otherwise + virtual bool tinit() = 0; + + // each of these returns true if resource tuning is complete, false otherwise + virtual bool tune_packet_context() = 0; + virtual bool tune_idle_context() = 0; + +protected: + ReloadResourceTuner() = default; + + unsigned max_work = RELOAD_MAX_WORK_PER_PACKET; + unsigned max_work_idle = RELOAD_MAX_WORK_WHEN_IDLE; +}; + +class ReloadSwapper : public ReloadResourceTuner +{ +public: + virtual ~ReloadSwapper() override = default; + + // each of these returns true if resource tuning is complete, false otherwise + bool tune_packet_context() override + { return true; } + bool tune_idle_context() override + { return true; } + + bool tinit() override + { + tswap(); + return false; + } + + virtual void tswap() = 0; + +protected: + ReloadSwapper() = default; +}; + +} + +#endif diff --git a/src/main/shell.cc b/src/main/shell.cc index 483912eee..4807f5b85 100644 --- a/src/main/shell.cc +++ b/src/main/shell.cc @@ -479,9 +479,10 @@ bool Shell::configure(SnortConfig* sc, bool is_root) set_default_policy(sc); else { + set_network_policy(pt->network); + set_network_parse_policy(pt->network_parse); set_inspection_policy(pt->inspection); set_ips_policy(pt->ips); - set_network_policy(pt->network); } if (!sc->tweaks.empty()) @@ -544,6 +545,7 @@ bool Shell::configure(SnortConfig* sc, bool is_root) current_shells.pop(); + set_network_parse_policy(nullptr); set_default_policy(sc); ModuleManager::set_config(nullptr); loaded = true; @@ -563,6 +565,7 @@ void Shell::install(const char* name, const luaL_Reg* reg) void Shell::execute(const char* cmd, string& rsp) { + set_default_policy(SnortConfig::get_conf()); int err = 0; Lua::ManageStack ms(lua); diff --git a/src/main/snort.cc b/src/main/snort.cc index 8f3b4791c..3d4b2a51c 100644 --- a/src/main/snort.cc +++ b/src/main/snort.cc @@ -159,6 +159,9 @@ void Snort::init(int argc, char** argv) /* Set the global snort_conf that will be used during run time */ SnortConfig::set_conf(sc); + if (!sc->policy_map->setup_network_policies()) + ParseError("Network policy user ids must be unique\n"); + // This call must be immediately after "SnortConfig::set_conf(sc)" // since the first trace call may happen somewhere after this point TraceApi::thread_init(sc->trace_config); @@ -206,6 +209,8 @@ void Snort::init(int argc, char** argv) else if ( SnortConfig::log_verbose() ) InspectorManager::print_config(sc); + InspectorManager::global_init(); + InspectorManager::prepare_inspectors(sc); InspectorManager::prepare_controls(sc); // Must be after InspectorManager::configure() @@ -481,7 +486,8 @@ SnortConfig* Snort::get_reload_config(const char* fname, const char* plugin_path return nullptr; } - InspectorManager::tear_down_removed_inspectors(old, sc); + InspectorManager::reconcile_inspectors(old, sc); + InspectorManager::prepare_inspectors(sc); InspectorManager::prepare_controls(sc); FileService::verify_reload(sc); @@ -551,6 +557,7 @@ SnortConfig* Snort::get_updated_policy( reset_parse_errors(); SnortConfig* sc = new SnortConfig(other_conf, iname); + sc->global_dbus->clone(*other_conf->global_dbus, iname); if ( fname ) { @@ -601,6 +608,8 @@ SnortConfig* Snort::get_updated_policy( return nullptr; } + InspectorManager::reconcile_inspectors(other_conf, sc, true); + InspectorManager::prepare_inspectors(sc); InspectorManager::prepare_controls(sc); other_conf->cloned = true; @@ -614,6 +623,7 @@ SnortConfig* Snort::get_updated_module(SnortConfig* other_conf, const char* name reloading = true; SnortConfig* sc = new SnortConfig(other_conf, name); + sc->global_dbus->clone(*other_conf->global_dbus, name); if ( name ) { @@ -641,6 +651,8 @@ SnortConfig* Snort::get_updated_module(SnortConfig* other_conf, const char* name return nullptr; } + InspectorManager::reconcile_inspectors(other_conf, sc, true); + InspectorManager::prepare_inspectors(sc); InspectorManager::prepare_controls(sc); other_conf->cloned = true; diff --git a/src/main/snort_config.cc b/src/main/snort_config.cc index 99f2436c5..d2c7c7a8c 100644 --- a/src/main/snort_config.cc +++ b/src/main/snort_config.cc @@ -72,6 +72,7 @@ #include "utils/util_cstring.h" #include "analyzer.h" +#include "reload_tuner.h" #include "thread_config.h" using namespace snort; @@ -90,7 +91,12 @@ using namespace snort; #define OUTPUT_U2 "unified2" #define OUTPUT_FAST "alert_fast" -static THREAD_LOCAL const SnortConfig* snort_conf = nullptr; +struct ThreadSnortConfig +{ + const SnortConfig* snort_conf; + unsigned reload_id; +}; +static THREAD_LOCAL ThreadSnortConfig thread_snort_config = {}; uint32_t SnortConfig::warning_flags = 0; uint32_t SnortConfig::logging_flags = 0; @@ -136,20 +142,23 @@ static PolicyMode init_policy_mode(const SnortConfig* sc, PolicyMode mode) static void init_policies(SnortConfig* sc) { - IpsPolicy* ips_policy = nullptr; - InspectionPolicy* inspection_policy = nullptr; + for ( unsigned nidx = 0; nidx < sc->policy_map->network_policy_count(); ++nidx ) + { + NetworkPolicy* network_policy = sc->policy_map->get_network_policy(nidx); + + for ( unsigned idx = 0; idx < network_policy->inspection_policy_count(); ++idx ) + { + InspectionPolicy* inspection_policy = network_policy->get_inspection_policy(idx); + inspection_policy->policy_mode = init_policy_mode(sc, inspection_policy->policy_mode); + } + } for ( unsigned idx = 0; idx < sc->policy_map->ips_policy_count(); ++idx ) { - ips_policy = sc->policy_map->get_ips_policy(idx); + IpsPolicy* ips_policy = get_ips_policy(sc, idx); ips_policy->policy_mode = init_policy_mode(sc, ips_policy->policy_mode); } - for ( unsigned idx = 0; idx < sc->policy_map->inspection_policy_count(); ++idx ) - { - inspection_policy = sc->policy_map->get_inspection_policy(idx); - inspection_policy->policy_mode = init_policy_mode(sc, inspection_policy->policy_mode); - } } void SnortConfig::init(const SnortConfig* const other_conf, ProtocolReference* protocol_reference, @@ -179,6 +188,7 @@ void SnortConfig::init(const SnortConfig* const other_conf, ProtocolReference* p memory = new MemoryConfig(); policy_map = new PolicyMap; thread_config = new ThreadConfig(); + global_dbus = new DataBus(); proto_ref = new ProtocolReference(protocol_reference); so_rules = new SoRules; @@ -214,6 +224,7 @@ SnortConfig::~SnortConfig() { if ( cloned ) { + delete global_dbus; policy_map->set_cloned(true); delete policy_map; return; @@ -247,7 +258,7 @@ SnortConfig::~SnortConfig() snort_free(eth_dst); if ( fast_pattern_config && - (!snort_conf || this == snort_conf || + (!thread_snort_config.snort_conf || this == thread_snort_config.snort_conf || (fast_pattern_config->get_search_api() != get_conf()->fast_pattern_config->get_search_api())) ) { @@ -264,6 +275,7 @@ SnortConfig::~SnortConfig() delete trace_config; delete overlay_trace_config; delete ha_config; + delete global_dbus; delete profiler; delete latency; @@ -331,6 +343,7 @@ void SnortConfig::post_setup() void SnortConfig::clone(const SnortConfig* const conf) { *this = *conf; + global_dbus = new DataBus(); if (conf->homenet.get_family() != 0) memcpy(&homenet, &conf->homenet, sizeof(homenet)); @@ -467,6 +480,9 @@ bool SnortConfig::verify() const bool config_ok = false; const SnortConfig* sc = get_conf(); + if (!policy_map->setup_network_policies()) + ReloadError("Network policy user ids must be unique\n"); + if ( sc->asn1_mem != asn1_mem ) ReloadError("Changing detection.asn1_mem requires a restart.\n"); @@ -978,14 +994,20 @@ void SnortConfig::release_scratch(int id) } SnortConfig* SnortConfig::get_main_conf() -{ return const_cast(snort_conf); } +{ return const_cast(thread_snort_config.snort_conf); } const SnortConfig* SnortConfig::get_conf() -{ return snort_conf; } +{ return thread_snort_config.snort_conf; } + +unsigned SnortConfig::get_thread_reload_id() +{ return thread_snort_config.reload_id; } + +void SnortConfig::update_thread_reload_id() +{ thread_snort_config.reload_id = thread_snort_config.snort_conf->reload_id; } void SnortConfig::set_conf(const SnortConfig* sc) { - snort_conf = sc; + thread_snort_config.snort_conf = sc; if ( sc ) { @@ -995,7 +1017,7 @@ void SnortConfig::set_conf(const SnortConfig* sc) } } -void SnortConfig::register_reload_resource_tuner(ReloadResourceTuner* rrt) +void SnortConfig::register_reload_handler(ReloadResourceTuner* rrt) { if (Snort::is_reloading()) reload_tuners.push_back(rrt); diff --git a/src/main/snort_config.h b/src/main/snort_config.h index f1eb85ecc..f2da7cab8 100644 --- a/src/main/snort_config.h +++ b/src/main/snort_config.h @@ -163,33 +163,11 @@ namespace snort { class GHash; class ProtocolReference; +class ReloadResourceTuner; class ThreadConfig; class XHash; struct ProfilerConfig; -class ReloadResourceTuner -{ -public: - static const unsigned RELOAD_MAX_WORK_PER_PACKET = 3; - // be aggressive when idle as analyzer gets chance once in every second only due to daq timeout - static const unsigned RELOAD_MAX_WORK_WHEN_IDLE = 32767; - - virtual ~ReloadResourceTuner() = default; - - // returns true if resource tuning required, false otherwise - virtual bool tinit() = 0; - - // each of these returns true if resource tuning is complete, false otherwise - virtual bool tune_packet_context() = 0; - virtual bool tune_idle_context() = 0; - -protected: - ReloadResourceTuner() = default; - - unsigned max_work = RELOAD_MAX_WORK_PER_PACKET; - unsigned max_work_idle = RELOAD_MAX_WORK_WHEN_IDLE; -}; - struct SnortConfig { private: @@ -392,6 +370,8 @@ public: PolicyMap* policy_map = nullptr; std::string tweaks; + DataBus* global_dbus = nullptr; + uint16_t tunnel_mask = 0; int16_t max_aux_ip = 16; @@ -428,11 +408,11 @@ public: bool cloned = false; Plugins* plugins = nullptr; SoRules* so_rules = nullptr; - unsigned reload_id = 0; DumpConfigType dump_config_type = DUMP_CONFIG_NONE; private: std::list reload_tuners; + unsigned reload_id = 0; public: //------------------------------------------------------ @@ -697,13 +677,16 @@ public: // runtime access to const config - especially for packet threads // prefer access via packet->context->conf SO_PUBLIC static const SnortConfig* get_conf(); + // Thread local copy of the reload_id needed for commands that cause reevaluation + SO_PUBLIC static unsigned get_thread_reload_id(); + SO_PUBLIC static void update_thread_reload_id(); // runtime access to mutable config - main thread only, and only special cases SO_PUBLIC static SnortConfig* get_main_conf(); static void set_conf(const SnortConfig*); - SO_PUBLIC void register_reload_resource_tuner(ReloadResourceTuner*); + SO_PUBLIC void register_reload_handler(ReloadResourceTuner*); static void cleanup_fatal_error(); diff --git a/src/main/test/CMakeLists.txt b/src/main/test/CMakeLists.txt index e26ac4664..7338bc281 100644 --- a/src/main/test/CMakeLists.txt +++ b/src/main/test/CMakeLists.txt @@ -1,7 +1,7 @@ if ( ENABLE_SHELL ) add_cpputest(distill_verdict_test SOURCES - stubs.h + distill_verdict_stubs.h ../analyzer.cc ../../packet_io/active.cc ) diff --git a/src/main/test/stubs.h b/src/main/test/distill_verdict_stubs.h similarity index 96% rename from src/main/test/stubs.h rename to src/main/test/distill_verdict_stubs.h index 41416d947..45c8c297e 100644 --- a/src/main/test/stubs.h +++ b/src/main/test/distill_verdict_stubs.h @@ -116,6 +116,9 @@ eth_t* eth_close(eth_t*) { return nullptr; } ssize_t eth_send(eth_t*, const void*, size_t) { return -1; } void HostAttributesManager::initialize() { } +void select_default_policy(const _daq_pkt_hdr&, const snort::SnortConfig*) { } +void select_default_policy(const _daq_flow_stats&, const snort::SnortConfig*) { } + namespace snort { static struct timeval s_packet_time = { 0, 0 }; @@ -167,6 +170,7 @@ void DetectionEngine::clear_replacement() { } void DetectionEngine::disable_all(Packet*) { } unsigned get_instance_id() { return 0; } const SnortConfig* SnortConfig::get_conf() { return nullptr; } +void SnortConfig::update_thread_reload_id() { } void PacketTracer::thread_init() { } void PacketTracer::thread_term() { } void PacketTracer::log(const char*, ...) { } @@ -203,6 +207,7 @@ void InspectorManager::thread_stop(const SnortConfig*) { } void InspectorManager::thread_reinit(const SnortConfig*) { } void InspectorManager::thread_stop_removed(const SnortConfig*) { } void ModuleManager::accumulate() { } +void ModuleManager::accumulate_module(const char*) { } void Stream::handle_timeouts(bool) { } void Stream::purge_flows() { } bool Stream::set_packet_action_to_hold(Packet*) { return false; } @@ -216,3 +221,8 @@ Flow::Flow() = default; Flow::~Flow() = default; void ThreadConfig::implement_thread_affinity(SThreadType, unsigned) { } } + +namespace memory +{ +void MemoryCap::free_space() { } +} diff --git a/src/main/test/distill_verdict_test.cc b/src/main/test/distill_verdict_test.cc index 9c426cbb5..115b5efb7 100644 --- a/src/main/test/distill_verdict_test.cc +++ b/src/main/test/distill_verdict_test.cc @@ -23,7 +23,7 @@ #include -#include "stubs.h" +#include "distill_verdict_stubs.h" #include "main/analyzer.h" #include "memory/memory_cap.h" diff --git a/src/managers/CMakeLists.txt b/src/managers/CMakeLists.txt index 3d502019e..175257456 100644 --- a/src/managers/CMakeLists.txt +++ b/src/managers/CMakeLists.txt @@ -13,6 +13,8 @@ set( MANAGERS_INCLUDES inspector_manager.h ) +add_subdirectory(test) + add_library( managers OBJECT ${LUA_INCLUDES} ${MANAGERS_INCLUDES} diff --git a/src/managers/inspector_manager.cc b/src/managers/inspector_manager.cc index 7d167dfa0..9c0307242 100644 --- a/src/managers/inspector_manager.cc +++ b/src/managers/inspector_manager.cc @@ -39,8 +39,8 @@ #include "main/snort_debug.h" #include "main/snort_module.h" #include "main/thread_config.h" -#include "search_engines/search_tool.h" #include "protocols/packet.h" +#include "search_engines/search_tool.h" #include "target_based/snort_protocols.h" #include "time/clock_defs.h" #include "time/stopwatch.h" @@ -67,18 +67,18 @@ using namespace std; // this distinction should be more precise when policy foo is ripped out of // the instances. -struct PHGlobal +struct PHObject { const InspectApi& api; bool initialized = false; // In the context of the main thread, this means that api.pinit() // has been called. In the packet thread, it means that // api.tinit() has been called. - bool instance_initialized = false; // In the packet thread, at least one instance has had + bool instance_initialized = false; //In the packet thread, at least one instance has had // tinit() called. - PHGlobal(const InspectApi& p) : api(p) { } + PHObject(const InspectApi& p) : api(p) { } - static bool comp(const PHGlobal* a, const PHGlobal* b) + static bool comp(const PHObject* a, const PHObject* b) { return ( a->api.type < b->api.type ); } }; @@ -106,6 +106,17 @@ enum ReloadType RELOAD_TYPE_MAX }; +typedef vector PHObjectList; +typedef vector PHTSObjectLists; +struct ThreadSpecificHandlers +{ + explicit ThreadSpecificHandlers(unsigned max) + { olists.resize(max); } + ~ThreadSpecificHandlers() = default; + PHTSObjectLists olists; + unsigned ref_count = 1; +}; + struct PHInstance { PHClass& pp_class; @@ -138,8 +149,8 @@ struct PHInstance ReloadType get_reload_type() { return reload_type; } - void tinit(); - void tterm(); + void tinit(PHObjectList* handlers); + void tterm(PHObjectList* handlers); }; PHInstance::PHInstance(PHClass& p, SnortConfig* sc, Module* mod) : pp_class(p) @@ -163,9 +174,18 @@ PHInstance::~PHInstance() handler->rem_global_ref(); } -typedef vector PHGlobalList; +typedef vector PHGlobalList; typedef vector PHClassList; typedef vector PHInstanceList; +struct PHRemovedInstance +{ + PHRemovedInstance(PHInstance* i, PHTSObjectLists& handlers) + : instance(i), handlers(handlers) + { } + PHInstance* instance; + PHTSObjectLists& handlers; +}; +typedef vector PHRemovedInstanceList; typedef list PHList; static PHGlobalList s_handlers; @@ -173,7 +193,13 @@ static PHList s_trash; static PHList s_trash2; static bool s_sorted = false; -static THREAD_LOCAL vector* s_tl_handlers = nullptr; +static PHTSObjectLists s_tl_handlers; + +void InspectorManager::global_init() +{ + if (s_tl_handlers.size() != ThreadConfig::get_instance_max()) + s_tl_handlers.resize(ThreadConfig::get_instance_max(), nullptr); +} struct FrameworkConfig { @@ -182,11 +208,11 @@ struct FrameworkConfig struct PHVector { - PHInstance** vec; - unsigned num; + PHInstance** vec = nullptr; + unsigned num = 0; + unsigned total_num = 0; - PHVector() - { vec = nullptr; num = 0; } + PHVector() = default; ~PHVector() { if ( vec ) delete[] vec; } @@ -195,7 +221,10 @@ struct PHVector { vec = new PHInstance*[max]; } void add(PHInstance* p) - { vec[num++] = p; } + { + vec[num++] = p; + total_num = num; + } void add_control(PHInstance*); }; @@ -210,7 +239,6 @@ void PHVector::add_control(PHInstance* p) if ( strcmp(name, app_id) or !num ) add(p); - else { add(vec[0]); @@ -220,75 +248,176 @@ void PHVector::add_control(PHInstance* p) struct InspectorList { - virtual ~InspectorList() = default; + virtual ~InspectorList(); PHInstanceList ilist; // List of inspector module instances - PHInstanceList removed_ilist; // List of removed inspector module instances + PHRemovedInstanceList removed_ilist; // List of removed inspector module instances - virtual void handle_new_reenabled(SnortConfig*, bool, bool) = 0; + virtual void handle_new_reenabled(SnortConfig*, bool, bool) + { } virtual void vectorize(SnortConfig*) = 0; - void tinit(); - void tterm(); + void tinit(PHObjectList* handlers); + void tterm(PHObjectList* handlers); void tterm_removed(); - void populate_removed(SnortConfig*, InspectorList*); + void populate_removed(SnortConfig*, InspectorList* new_il, PHTSObjectLists& handlers); + void populate_removed(SnortConfig*, InspectorList* new_il, InspectorList* def_il, + PHTSObjectLists& handlers); + void populate_all_removed(SnortConfig* sc, InspectorList* def_il, + PHTSObjectLists& handlers); void clear_removed(); + void reconcile_inspectors(SnortConfig*, InspectorList* old_list, bool cloned); + void allocate_thread_storage(); }; -void InspectorList::tinit() +InspectorList::~InspectorList() +{ clear_removed(); } + +void InspectorList::tinit(PHObjectList* handlers) { for ( auto* p : ilist ) - p->tinit(); + p->tinit(handlers); } -void InspectorList::tterm() +void InspectorList::tterm(PHObjectList* handlers) { for ( auto* p : ilist ) - p->tterm(); + p->tterm(handlers); } void InspectorList::tterm_removed() { - for ( auto* p : removed_ilist ) - p->tterm(); + for ( auto& ri : removed_ilist ) + ri.instance->tterm(ri.handlers[Inspector::slot]); } static PHInstance* get_instance(InspectorList* il, const char* keyword); -void InspectorList::populate_removed(SnortConfig* sc, InspectorList* old) +void InspectorList::populate_removed(SnortConfig* sc, InspectorList* new_il, + PHTSObjectLists& handlers) { - for (auto it = old->ilist.begin(); it != old->ilist.end(); ++it) + assert(new_il); + for (auto* p : ilist) { - PHInstance* instance = get_instance(this, (*it)->name.c_str()); + PHInstance* instance = get_instance(new_il, p->name.c_str()); if (!instance) { - removed_ilist.emplace_back(*it); - (*it)->handler->add_global_ref(); - (*it)->handler->tear_down(sc); + new_il->removed_ilist.emplace_back(p, handlers); + p->handler->add_global_ref(); + p->handler->tear_down(sc); } } } +void InspectorList::populate_removed(SnortConfig* sc, InspectorList* new_il, + InspectorList* def_il, PHTSObjectLists& handlers) +{ + assert(def_il); + for (auto* p : ilist) + { + PHInstance* instance = new_il ? get_instance(new_il, p->name.c_str()) : nullptr; + if (!instance) + { + def_il->removed_ilist.emplace_back(p, handlers); + p->handler->add_global_ref(); + p->handler->tear_down(sc); + } + } +} + +void InspectorList::populate_all_removed(SnortConfig* sc, InspectorList* def_il, + PHTSObjectLists& handlers) +{ + assert(def_il); + for (auto* p : ilist) + { + def_il->removed_ilist.emplace_back(p, handlers); + p->handler->add_global_ref(); + p->handler->tear_down(sc); + } +} + void InspectorList::clear_removed() { - for ( auto* p : removed_ilist ) - p->handler->rem_global_ref(); + for ( auto& ri : removed_ilist ) + ri.instance->handler->rem_global_ref(); removed_ilist.clear(); } +void InspectorList::reconcile_inspectors(SnortConfig* sc, InspectorList* old_list, bool cloned) +{ + if (old_list) + { + for (auto* p : ilist) + { + for (auto* old_p : old_list->ilist) + { + if (old_p->name == p->name) + { + ReloadType reload_type = p->get_reload_type(); + if (!cloned || RELOAD_TYPE_NEW == reload_type + || RELOAD_TYPE_REENABLED == reload_type) + { + p->handler->copy_thread_storage(old_p->handler); + p->handler->install_reload_handler(sc); + } + break; + } + } + } + } +} + +void InspectorList::allocate_thread_storage() +{ + for (auto* p : ilist) + p->handler->allocate_thread_storage(); +} + +static PHInstance* get_instance_from_vector(const char* key, PHInstance** vec, unsigned num) +{ + for (unsigned i = 0; i < num; ++i) + { + PHInstance* ph = vec[i]; + if (ph->name == key) + return ph; + } + return nullptr; +} + struct TrafficPolicy : public InspectorList { + TrafficPolicy() = default; + ~TrafficPolicy() override; PHVector passive; PHVector packet; PHVector first; PHVector control; - void handle_new_reenabled(SnortConfig*, bool, bool) override - { } + ThreadSpecificHandlers* ts_handlers = nullptr; + void vectorize(SnortConfig*) override; + PHInstance* get_instance_by_type(const char* key, InspectorType); + + PHObjectList* get_specific_handlers(); }; +TrafficPolicy::~TrafficPolicy() +{ + if (ts_handlers) + { + assert(ts_handlers->ref_count); + --ts_handlers->ref_count; + if (!ts_handlers->ref_count) + { + for (auto* h : ts_handlers->olists) + delete h; + delete ts_handlers; + } + } +} + void TrafficPolicy::vectorize(SnortConfig*) { passive.alloc(ilist.size()); @@ -313,7 +442,7 @@ void TrafficPolicy::vectorize(SnortConfig*) break; case IT_CONTROL: - control.add_control(p); + control.add(p); break; default: @@ -325,6 +454,41 @@ void TrafficPolicy::vectorize(SnortConfig*) } } +PHObjectList* TrafficPolicy::get_specific_handlers() +{ + assert(ts_handlers); + PHObjectList* handlers = ts_handlers->olists[Inspector::slot]; + if (!handlers) + { + handlers = new PHObjectList; + ts_handlers->olists[Inspector::slot] = handlers; + } + return handlers; +} + +PHInstance* TrafficPolicy::get_instance_by_type(const char* key, InspectorType type) +{ + switch (type) + { + case IT_PASSIVE: + return get_instance_from_vector(key, passive.vec, passive.total_num); + + case IT_PACKET: + return get_instance_from_vector(key, packet.vec, packet.total_num); + + case IT_FIRST: + return get_instance_from_vector(key, first.vec, first.total_num); + + case IT_CONTROL: + return get_instance_from_vector(key, control.vec, control.total_num); + + default: + assert(false); + break; + } + return nullptr; +} + class SingleInstanceInspectorPolicy { public: @@ -332,13 +496,16 @@ public: ~SingleInstanceInspectorPolicy(); bool get_new(SnortConfig*, Module*, PHClass&, PHInstance*&); - void populate_removed(SnortConfig*, SingleInstanceInspectorPolicy*); + void populate_removed(SnortConfig*, SingleInstanceInspectorPolicy* new_instance); void clear_removed(); void configure(SnortConfig*); - void tinit(); - void tterm(); + void reconcile_inspector(SnortConfig*, SingleInstanceInspectorPolicy* old_instance, + bool cloned); + void tinit(PHObjectList* handlers); + void tterm(PHObjectList* handlers); void tterm_removed(); void print_config(SnortConfig*, const char* title); + void allocate_thread_storage(); PHInstance* instance = nullptr; PHInstance* removed_instance = nullptr; @@ -355,7 +522,7 @@ SingleInstanceInspectorPolicy::~SingleInstanceInspectorPolicy() delete instance; } - assert(nullptr == removed_instance); + clear_removed(); } bool SingleInstanceInspectorPolicy::get_new(SnortConfig*sc, Module* mod, PHClass& pc, @@ -375,13 +542,13 @@ bool SingleInstanceInspectorPolicy::get_new(SnortConfig*sc, Module* mod, PHClass } void SingleInstanceInspectorPolicy::populate_removed(SnortConfig* sc, - SingleInstanceInspectorPolicy* old) + SingleInstanceInspectorPolicy* new_instance) { - if (old->instance && !instance) + if (instance && !new_instance->instance) { - removed_instance = old->instance; - removed_instance->handler->add_global_ref(); - removed_instance->handler->tear_down(sc); + new_instance->removed_instance = instance; + instance->handler->add_global_ref(); + instance->handler->tear_down(sc); } } @@ -400,22 +567,37 @@ void SingleInstanceInspectorPolicy::configure(SnortConfig* sc) instance->handler->configure(sc); } -void SingleInstanceInspectorPolicy::tinit() +void SingleInstanceInspectorPolicy::reconcile_inspector(SnortConfig* sc, + SingleInstanceInspectorPolicy* old_instance, bool cloned) +{ + if (instance && old_instance && old_instance->instance) + { + ReloadType reload_type = instance->get_reload_type(); + if (!cloned || RELOAD_TYPE_NEW == reload_type + || RELOAD_TYPE_REENABLED == reload_type) + { + instance->handler->copy_thread_storage(old_instance->instance->handler); + instance->handler->install_reload_handler(sc); + } + } +} + +void SingleInstanceInspectorPolicy::tinit(PHObjectList* handlers) { if (instance) - instance->tinit(); + instance->tinit(handlers); } -void SingleInstanceInspectorPolicy::tterm() +void SingleInstanceInspectorPolicy::tterm(PHObjectList* handlers) { if (instance) - instance->tterm(); + instance->tterm(handlers); } void SingleInstanceInspectorPolicy::tterm_removed() { if (removed_instance) - removed_instance->tterm(); + removed_instance->tterm(s_tl_handlers[Inspector::slot]); } void SingleInstanceInspectorPolicy::print_config(SnortConfig* sc, const char* title) @@ -429,20 +611,27 @@ void SingleInstanceInspectorPolicy::print_config(SnortConfig* sc, const char* ti } } +void SingleInstanceInspectorPolicy::allocate_thread_storage() +{ + if (instance) + instance->handler->allocate_thread_storage(); +} + struct GlobalInspectorPolicy : public InspectorList { PHVector passive; PHVector probe; + PHVector control; - void handle_new_reenabled(SnortConfig*, bool, bool) override - { } void vectorize(SnortConfig*) override; + PHInstance* get_instance_by_type(const char* key, InspectorType); }; void GlobalInspectorPolicy::vectorize(SnortConfig*) { passive.alloc(ilist.size()); probe.alloc(ilist.size()); + control.alloc(ilist.size()); for ( auto* p : ilist ) { switch ( p->pp_class.api.type ) @@ -455,6 +644,10 @@ void GlobalInspectorPolicy::vectorize(SnortConfig*) probe.add(p); break; + case IT_CONTROL: + control.add_control(p); + break; + default: ParseError( "Global inspector policy (global usage) does not handle inspector %s with type %s\n", @@ -464,6 +657,26 @@ void GlobalInspectorPolicy::vectorize(SnortConfig*) } } +PHInstance* GlobalInspectorPolicy::get_instance_by_type(const char* key, InspectorType type) +{ + switch (type) + { + case IT_PASSIVE: + return get_instance_from_vector(key, passive.vec, passive.total_num); + + case IT_PROBE: + return get_instance_from_vector(key, probe.vec, probe.total_num); + + case IT_CONTROL: + return get_instance_from_vector(key, control.vec, control.total_num); + + default: + assert(false); + break; + } + return nullptr; +} + struct FrameworkPolicy : public InspectorList { PHVector passive; @@ -482,8 +695,8 @@ struct FrameworkPolicy : public InspectorList void handle_new_reenabled(SnortConfig*, bool, bool) override; void vectorize(SnortConfig*) override; void add_inspector_to_cache(PHInstance*, SnortConfig*); - void remove_inspector_from_cache(Inspector*); bool delete_inspector(SnortConfig*, const char* iname); + PHInstance* get_instance_by_type(const char* key, InspectorType); }; void FrameworkPolicy::add_inspector_to_cache(PHInstance* p, SnortConfig* sc) @@ -497,49 +710,27 @@ void FrameworkPolicy::add_inspector_to_cache(PHInstance* p, SnortConfig* sc) } } -void FrameworkPolicy::remove_inspector_from_cache(Inspector* ins) -{ - if (!ins) - return; - - for(auto i = inspector_cache_by_id.begin(); i != inspector_cache_by_id.end(); i++) - { - if (ins == i->second) - { - inspector_cache_by_id.erase(i); - break; - } - } - for(auto i = inspector_cache_by_service.begin(); i != inspector_cache_by_service.end(); i++) - { - if (ins == i->second) - { - inspector_cache_by_service.erase(i); - break; - } - } -} - static bool get_instance(InspectorList*, const char*, std::vector::iterator&); void FrameworkPolicy::handle_new_reenabled(SnortConfig* sc, bool new_ins, bool reenabled_ins) { - std::vector::iterator old_binder; - if ( get_instance(this, bind_id, old_binder) ) + if ( new_ins or reenabled_ins ) { - if ( new_ins and default_binder ) + std::vector::iterator old_binder; + if ( get_instance(this, bind_id, old_binder) ) { - if ( !((*old_binder)->is_reloaded()) ) + if ( new_ins and default_binder ) { - (*old_binder)->set_reloaded(RELOAD_TYPE_REENABLED); - ilist.erase(old_binder); + if ( !((*old_binder)->is_reloaded()) ) + { + (*old_binder)->set_reloaded(RELOAD_TYPE_REENABLED); + ilist.erase(old_binder); + } + default_binder = false; } - default_binder = false; - } - else if ( reenabled_ins and !((*old_binder)->is_reloaded()) ) - { - (*old_binder)->handler->configure(sc); + else if ( reenabled_ins and !((*old_binder)->is_reloaded()) ) + (*old_binder)->handler->configure(sc); } } } @@ -606,7 +797,6 @@ bool FrameworkPolicy::delete_inspector(SnortConfig* sc, const char* iname) if ( get_instance(this, iname, old_it) ) { (*old_it)->set_reloaded(RELOAD_TYPE_DELETED); - remove_inspector_from_cache((*old_it)->handler); ilist.erase(old_it); std::vector::iterator bind_it; if ( get_instance(this, bind_id, bind_it) ) @@ -616,6 +806,35 @@ bool FrameworkPolicy::delete_inspector(SnortConfig* sc, const char* iname) return false; } +PHInstance* FrameworkPolicy::get_instance_by_type(const char* key, InspectorType type) +{ + switch (type) + { + case IT_PASSIVE: + return get_instance_from_vector(key, passive.vec, passive.total_num); + + case IT_PACKET: + return get_instance_from_vector(key, packet.vec, packet.total_num); + + case IT_NETWORK: + return get_instance_from_vector(key, network.vec, network.total_num); + + case IT_STREAM: + return get_instance(this, key); + + case IT_SERVICE: + return get_instance_from_vector(key, service.vec, service.total_num); + + case IT_WIZARD: + return get_instance(this, key); + + default: + assert(false); + break; + } + return nullptr; +} + //------------------------------------------------------------------------- // global stuff //------------------------------------------------------------------------- @@ -641,7 +860,7 @@ const char* InspectorManager::get_inspector_type(const char* name) void InspectorManager::add_plugin(const InspectApi* api) { - PHGlobal* g = new PHGlobal(*api); + PHObject* g = new PHObject(*api); s_handlers.emplace_back(g); } @@ -834,7 +1053,6 @@ void InspectorManager::new_policy(InspectionPolicy* pi, InspectionPolicy* other_ void InspectorManager::new_policy(NetworkPolicy* pi, NetworkPolicy* other_pi) { pi->traffic_policy = new TrafficPolicy; - if ( other_pi ) pi->traffic_policy->ilist = other_pi->traffic_policy->ilist; } @@ -843,7 +1061,7 @@ void InspectorManager::delete_policy(InspectionPolicy* pi, bool cloned) { for ( auto* p : pi->framework_policy->ilist ) { - if ( cloned and !(p->is_reloaded()) ) + if ( cloned and !p->is_reloaded() ) continue; if ( p->handler->get_api()->type == IT_PASSIVE ) @@ -880,21 +1098,25 @@ void InspectorManager::update_policy(SnortConfig* sc) GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy(); for ( auto* p : pp->ilist ) p->set_reloaded(RELOAD_TYPE_NONE); - InspectionPolicy* ip = sc->policy_map->get_inspection_policy(); - for ( auto* p : ip->framework_policy->ilist ) - p->set_reloaded(RELOAD_TYPE_NONE); - NetworkPolicy* np = sc->policy_map->get_network_policy(); - for ( auto* p : np->traffic_policy->ilist ) - p->set_reloaded(RELOAD_TYPE_NONE); + for (unsigned idx = 0; idx < sc->policy_map->network_policy_count(); ++idx) + { + NetworkPolicy* np = sc->policy_map->get_network_policy(idx); + for ( auto* p : np->traffic_policy->ilist ) + p->set_reloaded(RELOAD_TYPE_NONE); + InspectionPolicy* ip = np->get_inspection_policy(); + for ( auto* p : ip->framework_policy->ilist ) + p->set_reloaded(RELOAD_TYPE_NONE); + } } Binder* InspectorManager::get_binder() { InspectionPolicy* pi = get_inspection_policy(); - if ( !pi || !pi->framework_policy ) + if ( !pi ) return nullptr; + assert(pi->framework_policy); return (Binder*)pi->framework_policy->binder; } @@ -906,33 +1128,75 @@ void InspectorManager::clear_removed_inspectors(SnortConfig* sc) ft->clear_removed(); GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy(); pp->clear_removed(); - TrafficPolicy* tp = sc->policy_map->get_network_policy()->traffic_policy; - tp->clear_removed(); - FrameworkPolicy* fp = sc->policy_map->get_inspection_policy()->framework_policy; - fp->clear_removed(); + for (unsigned idx = 0; idx < sc->policy_map->network_policy_count(); ++idx) + { + NetworkPolicy* np = sc->policy_map->get_network_policy(idx); + np->traffic_policy->clear_removed(); + FrameworkPolicy* fp = np->get_inspection_policy()->framework_policy; + fp->clear_removed(); + } } -void InspectorManager::tear_down_removed_inspectors(const SnortConfig* old, SnortConfig* sc) +void InspectorManager::reconcile_inspectors(const SnortConfig* old, SnortConfig* sc, bool cloned) { SingleInstanceInspectorPolicy* old_fid = old->policy_map->get_file_id(); SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id(); - fid->populate_removed(sc, old_fid); + old_fid->populate_removed(sc, fid); + fid->reconcile_inspector(sc, old_fid, cloned); SingleInstanceInspectorPolicy* old_ft = old->policy_map->get_flow_tracking(); SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking(); - ft->populate_removed(sc, old_ft); + old_ft->populate_removed(sc, ft); + ft->reconcile_inspector(sc, old_ft, cloned); GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy(); GlobalInspectorPolicy* old_pp = old->policy_map->get_global_inspector_policy(); - pp->populate_removed(sc, old_pp); + old_pp->populate_removed(sc, pp, s_tl_handlers); + pp->reconcile_inspectors(sc, old_pp, cloned); - TrafficPolicy* tp = get_default_network_policy(sc)->traffic_policy; - TrafficPolicy* old_tp = get_default_network_policy(old)->traffic_policy; - tp->populate_removed(sc, old_tp); + // Put all removed instances in the default traffic policy + TrafficPolicy* default_tp = sc->policy_map->get_network_policy(0)->traffic_policy; + for (unsigned idx = 0; idx < old->policy_map->network_policy_count(); ++idx) + { + NetworkPolicy* old_np = old->policy_map->get_network_policy(idx); + NetworkPolicy* np = sc->policy_map->get_user_network(old_np->user_policy_id); + if (np) + { + TrafficPolicy* tp = np->traffic_policy; + TrafficPolicy* old_tp = old_np->traffic_policy; + tp->ts_handlers = old_tp->ts_handlers; + ++tp->ts_handlers->ref_count; + + PHTSObjectLists& handlers = tp->ts_handlers->olists; + old_tp->populate_removed(sc, tp, default_tp, handlers); + + FrameworkPolicy* old_fp = old_np->get_inspection_policy(0)->framework_policy; + FrameworkPolicy* fp = np->get_inspection_policy(0)->framework_policy; + old_fp->populate_removed(sc, fp, default_tp, handlers); + } + else + { + TrafficPolicy* old_tp = old_np->traffic_policy; + old_tp->populate_all_removed(sc, default_tp, old_tp->ts_handlers->olists); - FrameworkPolicy* fp = get_default_inspection_policy(sc)->framework_policy; - FrameworkPolicy* old_fp = get_default_inspection_policy(old)->framework_policy; - fp->populate_removed(sc, old_fp); + FrameworkPolicy* old_fp = old_np->get_inspection_policy(0)->framework_policy; + old_fp->populate_all_removed(sc, default_tp, old_tp->ts_handlers->olists); + } + } + + for (unsigned idx = 0; idx < sc->policy_map->network_policy_count(); ++idx) + { + NetworkPolicy* np = sc->policy_map->get_network_policy(idx); + NetworkPolicy* old_np = old->policy_map->get_user_network(np->user_policy_id); + if (old_np) + { + np->traffic_policy->reconcile_inspectors(sc, old_np->traffic_policy, cloned); + + FrameworkPolicy* fp = np->get_inspection_policy(0)->framework_policy; + FrameworkPolicy* old_fp = old_np->get_inspection_policy(0)->framework_policy; + fp->reconcile_inspectors(sc, old_fp, cloned); + } + } } Inspector* InspectorManager::get_file_inspector(const SnortConfig* sc) @@ -953,8 +1217,8 @@ Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, cons sc = SnortConfig::get_conf(); if ( dflt_only ) { - pi = get_default_inspection_policy(sc); ni = get_default_network_policy(sc); + pi = ni->get_inspection_policy(0); } else { @@ -962,7 +1226,7 @@ Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, cons ni = get_network_policy(); } - if ( pi && pi->framework_policy ) + if ( pi ) { PHInstance* p = get_instance(pi->framework_policy, key); if ( p ) @@ -992,13 +1256,58 @@ Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, cons return nullptr; } +Inspector* InspectorManager::get_inspector(const char* key, Module::Usage usage, + InspectorType type, const SnortConfig* sc) +{ + if ( !sc ) + sc = SnortConfig::get_conf(); + + if (Module::GLOBAL == usage && IT_FILE == type) + { + SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id(); + assert(fid); + return (fid->instance && fid->instance->name == key) ? fid->instance->handler : nullptr; + } + else if (Module::GLOBAL == usage && IT_STREAM == type) + { + SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking(); + assert(ft); + return (ft->instance && ft->instance->name == key) ? ft->instance->handler : nullptr; + } + else + { + if (Module::GLOBAL == usage && IT_SERVICE != type) + { + GlobalInspectorPolicy* il = sc->policy_map->get_global_inspector_policy(); + assert(il); + PHInstance* p = il->get_instance_by_type(key, type); + return p ? p->handler : nullptr; + } + else if (Module::CONTEXT == usage) + { + TrafficPolicy* il = get_network_policy()->traffic_policy; + assert(il); + PHInstance* p = il->get_instance_by_type(key, type); + return p ? p->handler : nullptr; + } + else + { + FrameworkPolicy* il = get_inspection_policy()->framework_policy; + assert(il); + PHInstance* p = il->get_instance_by_type(key, type); + return p ? p->handler : nullptr; + } + } +} + Inspector* InspectorManager::get_service_inspector_by_service(const char* key) { InspectionPolicy* pi = get_inspection_policy(); - if ( !pi || !pi->framework_policy ) + if ( !pi ) return nullptr; + assert(pi->framework_policy); auto g = pi->framework_policy->inspector_cache_by_service.find(key); return (g != pi->framework_policy->inspector_cache_by_service.end()) ? g->second : nullptr; } @@ -1007,16 +1316,18 @@ Inspector* InspectorManager::get_service_inspector_by_id(const SnortProtocolId p { InspectionPolicy* pi = get_inspection_policy(); - if ( !pi || !pi->framework_policy ) + if ( !pi ) return nullptr; + assert(pi->framework_policy); auto g = pi->framework_policy->inspector_cache_by_id.find(protocol_id); return (g != pi->framework_policy->inspector_cache_by_id.end()) ? g->second : nullptr; } bool InspectorManager::delete_inspector(SnortConfig* sc, const char* iname) { - FrameworkPolicy* fp = sc->policy_map->get_inspection_policy()->framework_policy; + FrameworkPolicy* fp = + sc->policy_map->get_network_policy(0)->get_inspection_policy()->framework_policy; return fp->delete_inspector(sc, iname); } @@ -1078,32 +1389,33 @@ static PHClass* get_class(const char* keyword, FrameworkConfig* fc) return nullptr; } -static PHGlobal& get_thread_local_plugin(const InspectApi& api) +static PHObject& get_thread_local_plugin(const InspectApi& api, PHObjectList* handlers) { - assert(s_tl_handlers != nullptr); + assert(handlers); - for ( PHGlobal& phg : *s_tl_handlers ) + for ( PHObject& phg : *handlers ) { if ( &phg.api == &api ) return phg; } - s_tl_handlers->emplace_back(api); - return s_tl_handlers->back(); + handlers->emplace_back(api); + return handlers->back(); } -void PHInstance::tinit() +void PHInstance::tinit(PHObjectList* handlers) { - PHGlobal& phg = get_thread_local_plugin(pp_class.api); + PHObject& phg = get_thread_local_plugin(pp_class.api, handlers); if ( !phg.instance_initialized ) { - handler->tinit(); phg.instance_initialized = true; + handler->tinit(); } } -void PHInstance::tterm() +void PHInstance::tterm(PHObjectList* handlers) { - PHGlobal& phg = get_thread_local_plugin(pp_class.api); + assert(handlers); + PHObject& phg = get_thread_local_plugin(pp_class.api, handlers); if ( phg.instance_initialized ) { handler->tterm(); @@ -1113,73 +1425,90 @@ void PHInstance::tterm() void InspectorManager::thread_init(const SnortConfig* sc) { + SnortConfig::update_thread_reload_id(); Inspector::slot = get_instance_id(); // Initial build out of this thread's configured plugin registry - s_tl_handlers = new vector; + PHObjectList* g_handlers = new PHObjectList; + s_tl_handlers[Inspector::slot] = g_handlers; for ( auto* p : sc->framework_config->clist ) { - PHGlobal& phg = get_thread_local_plugin(p->api); + PHObject& phg = get_thread_local_plugin(p->api, g_handlers); if (phg.api.tinit) phg.api.tinit(); phg.initialized = true; } - // pin->tinit() only called for default policy - set_default_policy(sc); SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id(); - fid->tinit(); + fid->tinit(g_handlers); SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking(); - ft->tinit(); + ft->tinit(g_handlers); GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy(); - pp->tinit(); - - InspectionPolicy* pi = get_inspection_policy(); - if ( pi && pi->framework_policy ) - pi->framework_policy->tinit(); + pp->tinit(g_handlers); for ( unsigned i = 0; i < sc->policy_map->network_policy_count(); i++) { NetworkPolicy* npi = sc->policy_map->get_network_policy(i); + PHObjectList* handlers = npi->traffic_policy->get_specific_handlers(); set_network_policy(npi); - npi->traffic_policy->tinit(); + npi->traffic_policy->tinit(handlers); + + InspectionPolicy* pi = npi->get_inspection_policy(0); + if ( pi ) + { + set_inspection_policy(pi); + assert(pi->framework_policy); + pi->framework_policy->tinit(handlers); + } } } void InspectorManager::thread_reinit(const SnortConfig* sc) { - // Update this thread's configured plugin registry with any newly configured inspectors - for ( auto* p : sc->framework_config->clist ) + SnortConfig::update_thread_reload_id(); + unsigned instance_id = get_instance_id(); + if (!sc->policy_map->get_inspector_tinit_complete(instance_id)) { - PHGlobal& phg = get_thread_local_plugin(p->api); - if (!phg.initialized) + sc->policy_map->set_inspector_tinit_complete(instance_id, true); + + // Update this thread's configured plugin registry with any newly configured inspectors + PHObjectList* g_handlers = s_tl_handlers[Inspector::slot]; + for ( auto* p : sc->framework_config->clist ) { - if (phg.api.tinit) - phg.api.tinit(); - phg.initialized = true; + PHObject& phg = get_thread_local_plugin(p->api, g_handlers); + if (!phg.initialized) + { + if (phg.api.tinit) + phg.api.tinit(); + phg.initialized = true; + } } - } - - set_default_policy(sc); - SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id(); - fid->tinit(); - SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking(); - ft->tinit(); - GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy(); - pp->tinit(); + SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id(); + fid->tinit(g_handlers); + SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking(); + ft->tinit(g_handlers); - for ( unsigned i = 0; i < sc->policy_map->network_policy_count(); i++) - { - NetworkPolicy* npi = sc->policy_map->get_network_policy(i); - set_network_policy(npi); - npi->traffic_policy->tinit(); + GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy(); + pp->tinit(g_handlers); - // pin->tinit() only called for default policy - InspectionPolicy* pi = get_default_inspection_policy(sc); - if ( pi && pi->framework_policy ) - pi->framework_policy->tinit(); + for ( unsigned i = 0; i < sc->policy_map->network_policy_count(); i++) + { + NetworkPolicy* npi = sc->policy_map->get_network_policy(i); + PHObjectList* handlers = npi->traffic_policy->get_specific_handlers(); + set_network_policy(npi); + npi->traffic_policy->tinit(handlers); + + // pin->tinit() only called for default policy + InspectionPolicy* pi = npi->get_inspection_policy(0); + if ( pi ) + { + set_inspection_policy(pi); + assert(pi->framework_policy); + pi->framework_policy->tinit(handlers); + } + } } } @@ -1194,65 +1523,70 @@ void InspectorManager::thread_stop_removed(const SnortConfig* sc) GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy(); pp->tterm_removed(); - // pin->tinit() only called for default policy NetworkPolicy* npi = get_default_network_policy(sc); if ( npi && npi->traffic_policy ) { // Call pin->tterm() for anything that has been initialized and removed npi->traffic_policy->tterm_removed(); - } - // pin->tinit() only called for default policy - InspectionPolicy* pi = get_default_inspection_policy(sc); - - if ( pi && pi->framework_policy ) - { - // Call pin->tterm() for anything that has been initialized and removed - pi->framework_policy->tterm_removed(); + // pin->tinit() only called for default policy + InspectionPolicy* pi = npi->get_inspection_policy(0); + if ( pi ) + { + assert(pi->framework_policy); + // Call pin->tterm() for anything that has been initialized and removed + pi->framework_policy->tterm_removed(); + } } } void InspectorManager::thread_stop(const SnortConfig* sc) { // If thread_init() was never called, we have nothing to do. - if ( !s_tl_handlers ) + PHObjectList* g_handlers = s_tl_handlers[Inspector::slot]; + if ( !g_handlers ) return; set_default_policy(sc); SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id(); - fid->tterm(); + fid->tterm(g_handlers); SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking(); - ft->tterm(); + ft->tterm(g_handlers); GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy(); - pp->tterm(); - - InspectionPolicy* pi = get_inspection_policy(); - if ( pi && pi->framework_policy ) - pi->framework_policy->tterm(); + pp->tterm(g_handlers); for ( unsigned i = 0; i < sc->policy_map->network_policy_count(); i++) { NetworkPolicy* npi = sc->policy_map->get_network_policy(i); + PHObjectList* handlers = npi->traffic_policy->get_specific_handlers(); set_network_policy(npi); - npi->traffic_policy->tterm(); + npi->traffic_policy->tterm(handlers); + + InspectionPolicy* pi = npi->get_inspection_policy(0); + if ( pi ) + { + assert(pi->framework_policy); + pi->framework_policy->tterm(handlers); + } } } void InspectorManager::thread_term() { // If thread_init() was never called, we have nothing to do. - if ( !s_tl_handlers ) + PHObjectList* handlers = s_tl_handlers[Inspector::slot]; + if ( !handlers ) return; // Call tterm for every inspector plugin ever configured during the lifetime of this thread - for ( PHGlobal& phg : *s_tl_handlers ) + for ( PHObject& phg : *handlers ) { if ( phg.api.tterm && phg.initialized ) phg.api.tterm(); } - delete s_tl_handlers; - s_tl_handlers = nullptr; + delete handlers; + s_tl_handlers[Inspector::slot] = nullptr; } //------------------------------------------------------------------------- @@ -1308,13 +1642,17 @@ void InspectorManager::instantiate( } else if (Module::CONTEXT == mod->get_usage()) { - TrafficPolicy* il = get_network_policy()->traffic_policy; + NetworkPolicy* np = get_network_policy(); + assert(np); + TrafficPolicy* il = np->traffic_policy; assert(il); ppi = get_new(ppc, il, keyword, mod, sc); } else { - FrameworkPolicy* il = get_inspection_policy()->framework_policy; + InspectionPolicy* ip = get_inspection_policy(); + assert(ip); + FrameworkPolicy* il = ip->framework_policy; assert(il); ppi = get_new(ppc, il, keyword, mod, sc); } @@ -1338,7 +1676,9 @@ Inspector* InspectorManager::instantiate( if ( !ppc ) return nullptr; - auto fp = get_inspection_policy()->framework_policy; + InspectionPolicy* ip = get_inspection_policy(); + assert(ip); + auto fp = ip->framework_policy; auto ppi = get_new(ppc, fp, name, mod, sc); if ( !ppi ) @@ -1381,7 +1721,9 @@ static void instantiate_default_binder(SnortConfig* sc, FrameworkPolicy* fp) const InspectApi* api = get_plugin(bind_id); InspectorManager::instantiate(api, m, sc); - fp->binder = get_instance(fp, bind_id)->handler; + PHInstance* instance = get_instance(fp, bind_id); + assert(instance); + fp->binder = instance->handler; fp->binder->configure(sc); fp->default_binder = true; } @@ -1393,10 +1735,9 @@ static bool configure(SnortConfig* sc, InspectorList* il, bool cloned, bool& new for ( auto* p : il->ilist ) { - ReloadType reload_type = p->get_reload_type(); - if ( cloned ) { + ReloadType reload_type = p->get_reload_type(); if ( reload_type == RELOAD_TYPE_NEW ) new_ins = true; else if ( reload_type == RELOAD_TYPE_REENABLED ) @@ -1406,9 +1747,7 @@ static bool configure(SnortConfig* sc, InspectorList* il, bool cloned, bool& new } ok = p->handler->configure(sc) && ok; } - - if ( new_ins or reenabled_ins ) - il->handle_new_reenabled(sc, new_ins, reenabled_ins); + il->handle_new_reenabled(sc, new_ins, reenabled_ins); sort(il->ilist.begin(), il->ilist.end(), PHInstance::comp); il->vectorize(sc); @@ -1428,18 +1767,6 @@ Inspector* InspectorManager::acquire_file_inspector() return pi; } -Inspector* InspectorManager::acquire(const char* key, bool dflt_only) -{ - Inspector* pi = get_inspector(key, dflt_only); - - if ( !pi ) - FatalError("unconfigured inspector: '%s'.\n", key); - else - pi->add_global_ref(); - - return pi; -} - void InspectorManager::release(Inspector* pi) { assert(pi); @@ -1450,66 +1777,106 @@ bool InspectorManager::configure(SnortConfig* sc, bool cloned) { if ( !s_sorted ) { - sort(s_handlers.begin(), s_handlers.end(), PHGlobal::comp); + sort(s_handlers.begin(), s_handlers.end(), PHObject::comp); s_sorted = true; } bool ok = true; SearchTool::set_conf(sc); - bool new_ins = false; - bool reenabled_ins = false; - for ( unsigned idx = 0; idx < sc->policy_map->network_policy_count(); ++idx ) - { - if ( cloned and idx ) - break; - - ::set_network_policy(sc, idx); - NetworkPolicy* p = sc->policy_map->get_network_policy(idx); - ok = ::configure(sc, p->traffic_policy, cloned, new_ins, reenabled_ins) && ok; - } - ::set_network_policy(sc); - SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id(); fid->configure(sc); SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking(); ft->configure(sc); + bool new_ins = false; + bool reenabled_ins = false; + GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy(); ok = ::configure(sc, pp, cloned, new_ins, reenabled_ins) && ok; - ::set_inspection_policy(sc); - for ( unsigned idx = 0; idx < sc->policy_map->inspection_policy_count(); ++idx ) + for ( unsigned nidx = 0; nidx < sc->policy_map->network_policy_count(); ++nidx ) { - if ( cloned and idx ) - break; + NetworkPolicy* np = sc->policy_map->get_network_policy(nidx); + assert(np); + set_network_policy(np); + ok = ::configure(sc, np->traffic_policy, cloned, new_ins, reenabled_ins) && ok; - ::set_inspection_policy(sc, idx); - InspectionPolicy* p = sc->policy_map->get_inspection_policy(idx); - p->configure(); - ok = ::configure(sc, p->framework_policy, cloned, new_ins, reenabled_ins) && ok; + for ( unsigned idx = 0; idx < np->inspection_policy_count(); ++idx ) + { + if ( cloned and idx ) + break; + + InspectionPolicy* p = np->get_inspection_policy(idx); + assert(p); + set_inspection_policy(p); + p->configure(); + ok = ::configure(sc, p->framework_policy, cloned, new_ins, reenabled_ins) && ok; + } } - ::set_inspection_policy(sc); + NetworkPolicy* np = sc->policy_map->get_network_policy(); + assert(np); + set_network_policy(np); + set_inspection_policy(np->get_inspection_policy()); SearchTool::set_conf(nullptr); return ok; } +void InspectorManager::prepare_inspectors(SnortConfig* sc) +{ + SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id(); + fid->allocate_thread_storage(); + + SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking(); + ft->allocate_thread_storage(); + + GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy(); + pp->allocate_thread_storage(); + + for (unsigned idx = 0; idx < sc->policy_map->network_policy_count(); ++idx) + { + NetworkPolicy* np = sc->policy_map->get_network_policy(idx); + TrafficPolicy* tp = np->traffic_policy; + if (!tp->ts_handlers) + tp->ts_handlers = new ThreadSpecificHandlers(ThreadConfig::get_instance_max()); + tp->allocate_thread_storage(); + } +} + // remove any disabled controls while retaining order void InspectorManager::prepare_controls(SnortConfig* sc) { + GlobalInspectorPolicy* gp = sc->policy_map->get_global_inspector_policy(); + unsigned g_c = 0; + std::vector g_disabled; + for ( unsigned i = 0; i < gp->control.num; ++i ) + { + if ( !gp->control.vec[i]->handler->disable(sc) ) + gp->control.vec[g_c++] = gp->control.vec[i]; + else + g_disabled.emplace_back(gp->control.vec[i]); + } + gp->control.num = g_c; + for (auto* ph : g_disabled) + gp->control.vec[g_c++] = ph; for ( unsigned idx = 0; idx < sc->policy_map->network_policy_count(); ++idx ) { TrafficPolicy* tp = sc->policy_map->get_network_policy(idx)->traffic_policy; unsigned c = 0; + std::vector disabled; for ( unsigned i = 0; i < tp->control.num; ++i ) { if ( !tp->control.vec[i]->handler->disable(sc) ) tp->control.vec[c++] = tp->control.vec[i]; + else + disabled.emplace_back(tp->control.vec[i]); } tp->control.num = c; + for (auto* ph : disabled) + tp->control.vec[c++] = ph; } } @@ -1577,8 +1944,9 @@ void InspectorManager::print_config(SnortConfig* sc) } const auto inspection = policies->inspection; - if ( inspection and inspection->framework_policy ) + if ( inspection ) { + assert(inspection->framework_policy); const std::string label = "Inspection Policy : policy id " + std::to_string(inspection->user_policy_id) + " : " + shell->get_file(); @@ -1676,7 +2044,7 @@ void InspectorManager::full_inspection(Packet* p) { Flow* flow = p->flow; - if ( flow->service and flow->searching_for_service() + if ( flow->has_service() and flow->searching_for_service() and (!(p->is_cooked()) or p->is_defrag()) ) bumble(p); @@ -1765,6 +2133,9 @@ void InspectorManager::internal_execute(Packet* p) if ( p->disable_inspect ) return; + GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy(); + assert(pp); + if ( !p->flow ) { ::execute(p, tp->first.vec, tp->first.num); @@ -1777,6 +2148,7 @@ void InspectorManager::internal_execute(Packet* p) if ( p->disable_inspect ) return; + ::execute(p, pp->control.vec, pp->control.num); ::execute(p, tp->control.vec, tp->control.num); } else @@ -1784,16 +2156,17 @@ void InspectorManager::internal_execute(Packet* p) if ( !p->has_paf_payload() and p->flow->flow_state == Flow::FlowState::INSPECT ) p->flow->session->process(p); - if ( p->flow->reload_id != sc->reload_id ) + unsigned reload_id = SnortConfig::get_thread_reload_id(); + if ( p->flow->reload_id != reload_id ) { ::execute(p, tp->first.vec, tp->first.num); - p->flow->reload_id = sc->reload_id; + p->flow->reload_id = reload_id; if ( p->disable_inspect ) return; } - if ( !p->flow->service ) + if ( !p->flow->has_service() ) ::execute(p, fp->network.vec, fp->network.num); if ( p->disable_inspect ) @@ -1802,6 +2175,8 @@ void InspectorManager::internal_execute(Packet* p) if ( p->flow->full_inspection() ) full_inspection(p); + if ( !p->disable_inspect and !p->flow->is_inspection_disabled() ) + ::execute(p, pp->control.vec, pp->control.num); if ( !p->disable_inspect and !p->flow->is_inspection_disabled() ) ::execute(p, tp->control.vec, tp->control.num); } diff --git a/src/managers/inspector_manager.h b/src/managers/inspector_manager.h index 0d6a33cd1..a9b4a4913 100644 --- a/src/managers/inspector_manager.h +++ b/src/managers/inspector_manager.h @@ -26,6 +26,7 @@ #include #include "framework/inspector.h" +#include "framework/module.h" class Binder; class SingleInstanceInspectorPolicy; @@ -50,6 +51,8 @@ public: static void dump_buffers(); static void release_plugins(); + static void global_init(); + static std::vector get_apis(); static const char* get_inspector_type(const char* name); @@ -78,6 +81,8 @@ public: SO_PUBLIC static Inspector* get_file_inspector(const SnortConfig* = nullptr); SO_PUBLIC static Inspector* get_inspector( const char* key, bool dflt_only = false, const SnortConfig* = nullptr); + SO_PUBLIC static Inspector* get_inspector(const char* key, Module::Usage, InspectorType, + const SnortConfig* = nullptr); static Inspector* get_service_inspector_by_service(const char*); static Inspector* get_service_inspector_by_id(const SnortProtocolId); @@ -85,10 +90,10 @@ public: SO_PUBLIC static Binder* get_binder(); SO_PUBLIC static Inspector* acquire_file_inspector(); - SO_PUBLIC static Inspector* acquire(const char* key, bool dflt_only = false); SO_PUBLIC static void release(Inspector*); static bool configure(SnortConfig*, bool cloned = false); + static void prepare_inspectors(SnortConfig*); static void prepare_controls(SnortConfig*); static std::string generate_inspector_label(const PHInstance*); static void print_config(SnortConfig*); @@ -105,7 +110,7 @@ public: static void clear(Packet*); static void empty_trash(); - static void tear_down_removed_inspectors(const SnortConfig*, SnortConfig*); + static void reconcile_inspectors(const SnortConfig*, SnortConfig*, bool cloned = false); static void clear_removed_inspectors(SnortConfig*); #ifdef PIGLET diff --git a/src/managers/module_manager.cc b/src/managers/module_manager.cc index 888507266..357c95f24 100644 --- a/src/managers/module_manager.cc +++ b/src/managers/module_manager.cc @@ -725,7 +725,8 @@ SO_PUBLIC bool open_table(const char* s, int idx) // FIXIT-M only basic modules, inspectors and ips actions can be reloaded at present if ( ( Snort::is_reloading() ) and h->api - and h->api->type != PT_INSPECTOR and h->api->type != PT_IPS_ACTION ) + and h->api->type != PT_INSPECTOR and h->api->type != PT_IPS_ACTION + and h->api->type != PT_POLICY_SELECTOR ) { return false; } diff --git a/src/managers/test/CMakeLists.txt b/src/managers/test/CMakeLists.txt new file mode 100644 index 000000000..cecd450f1 --- /dev/null +++ b/src/managers/test/CMakeLists.txt @@ -0,0 +1,5 @@ +add_cpputest(get_inspector_test + SOURCES + get_inspector_stubs.h + ../inspector_manager.cc +) diff --git a/src/managers/test/get_inspector_stubs.h b/src/managers/test/get_inspector_stubs.h new file mode 100644 index 000000000..1c7f44a34 --- /dev/null +++ b/src/managers/test/get_inspector_stubs.h @@ -0,0 +1,87 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2020-2022 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// stubs.h author Ron Dempster + +#include "detection/detection_engine.h" +#include "main/policy.h" +#include "main/snort.h" +#include "main/snort_config.h" +#include "main/thread_config.h" +#include "managers/inspector_manager.h" +#include "managers/module_manager.h" +#include "network_inspectors/binder/bind_module.h" +#include "search_engines/search_tool.h" +#include "trace/trace.h" +#include "trace/trace_api.h" + +THREAD_LOCAL const snort::Trace* snort_trace = nullptr; + +std::shared_ptr PolicyMap::get_policies(Shell*) { return nullptr; } +NetworkPolicy* PolicyMap::get_user_network(unsigned) { return nullptr; } +void InspectionPolicy::configure() { } +void BinderModule::add(const char*, const char*) { } +void BinderModule::add(unsigned, const char*) { } + +void set_default_policy(const snort::SnortConfig*) { } + +namespace snort +{ +unsigned THREAD_LOCAL Inspector::slot = 0; +const SnortConfig* SearchTool::conf = nullptr; +[[noreturn]] void FatalError(const char*,...) { exit(-1); } +void LogMessage(const char*, ...) { } +void LogLabel(const char*, FILE*) { } +void ParseError(const char*, ...) { } +void WarningMessage(const char*, ...) { } +void DataBus::publish(const char*, Packet*, Flow*) { } +void DetectionEngine::disable_content(Packet*) { } +unsigned SnortConfig::get_thread_reload_id() { return 1; } +void SnortConfig::update_thread_reload_id() { } +Inspector::Inspector() { } +Inspector::~Inspector() { } +bool Inspector::likes(Packet*) { return false; } +bool Inspector::get_buf(const char*, Packet*, InspectionBuffer&) { return false; } +StreamSplitter* Inspector::get_splitter(bool) { return nullptr; } +void Inspector::add_global_ref() { } +void Inspector::rem_ref() { } +void Inspector::rem_global_ref() { } +void Inspector::allocate_thread_storage() { } +void Inspector::copy_thread_storage(snort::Inspector*) { } +const char* InspectApi::get_type(InspectorType) { return ""; } +unsigned ThreadConfig::get_instance_max() { return 1; } +bool Snort::is_reloading() { return false; } +SnortProtocolId ProtocolReference::find(const char*) const { return UNKNOWN_PROTOCOL_ID; } +SnortProtocolId ProtocolReference::add(const char*) { return UNKNOWN_PROTOCOL_ID; } +uint8_t TraceApi::get_constraints_generation() { return 0; } +void TraceApi::filter(const Packet&) { } +PegCount Module::get_global_count(const char*) const { return 0; } +void Module::sum_stats(bool) { } +void Module::show_interval_stats(IndexVec&, FILE*) { } +void Module::show_stats() { } +void Module::reset_stats() { } +DataBus::DataBus() { } +DataBus::~DataBus() { } +Module* ModuleManager::get_module(const char*) { return nullptr; } + +NetworkPolicy* get_default_network_policy(const SnortConfig*) { return nullptr; } +void set_network_policy(NetworkPolicy*) { } +void set_inspection_policy(InspectionPolicy*) { } +void set_ips_policy(IpsPolicy*) { } +unsigned get_instance_id() { return 0; } +void trace_vprintf(const char*, TraceLevel, const char*, const Packet*, const char*, va_list) { } +} diff --git a/src/managers/test/get_inspector_test.cc b/src/managers/test/get_inspector_test.cc new file mode 100644 index 000000000..11470af98 --- /dev/null +++ b/src/managers/test/get_inspector_test.cc @@ -0,0 +1,327 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2020-2022 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// get_inspector_test.cc author Ron Dempster + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include + +#include "get_inspector_stubs.h" + +#include +#include +#include + +using namespace snort; + +bool Inspector::is_inactive() { return true; } + +NetworkPolicy* snort::get_network_policy() +{ return (NetworkPolicy*)mock().getData("network_policy").getObjectPointer(); } +InspectionPolicy* snort::get_inspection_policy() +{ return (InspectionPolicy*)mock().getData("inspection_policy").getObjectPointer(); } + +InspectionPolicy::InspectionPolicy(PolicyId) +{ InspectorManager::new_policy(this, nullptr); } +InspectionPolicy::~InspectionPolicy() +{ InspectorManager::delete_policy(this, false); } +NetworkPolicy::NetworkPolicy(PolicyId, PolicyId) +{ InspectorManager::new_policy(this, nullptr); } +NetworkPolicy::~NetworkPolicy() +{ + for ( auto p : inspection_policy ) + delete p; + + InspectorManager::delete_policy(this, false); + inspection_policy.clear(); +} +PolicyMap::PolicyMap(PolicyMap*, const char*) +{ + file_id = InspectorManager::create_single_instance_inspector_policy(); + flow_tracking = InspectorManager::create_single_instance_inspector_policy(); + global_inspector_policy = InspectorManager::create_global_inspector_policy(); + NetworkPolicy* np = new NetworkPolicy(network_policy.size(), 0); + network_policy.push_back(np); + InspectionPolicy* ip = new InspectionPolicy(); + np->inspection_policy.push_back(ip); +} +PolicyMap::~PolicyMap() +{ + InspectorManager::destroy_single_instance_inspector(file_id); + InspectorManager::destroy_single_instance_inspector(flow_tracking); + InspectorManager::destroy_global_inspector_policy(global_inspector_policy, false); + for ( auto p : network_policy ) + delete p; +} +SnortConfig::SnortConfig(const SnortConfig* const, const char*) +{ + policy_map = new PolicyMap(); + InspectorManager::new_config(this); +} +SnortConfig::~SnortConfig() +{ + InspectorManager::delete_config(this); + delete policy_map; +} +const SnortConfig* SnortConfig::get_conf() +{ return (const SnortConfig*)mock().getData("snort_config").getObjectPointer(); } + +Module::Module(const char* name, const char*) : name(name) +{ } + +class TestInspector : public Inspector +{ +public: + TestInspector() = default; + ~TestInspector() override = default; + void eval(Packet*) override { } +}; + +class TestModule : public Module +{ +public: + TestModule(const char* name, Module::Usage usage) : Module(name, ""), usage(usage) + { } + ~TestModule() override = default; + Usage get_usage() const override + { return usage; } + +protected: + Usage usage; +}; + + +static Inspector* test_ctor(Module* mod) +{ + std::unordered_map* mod_to_ins = + (std::unordered_map*)mock().getData("mod_to_ins").getObjectPointer(); + auto it = mod_to_ins->find(mod); + return it == mod_to_ins->end() ? nullptr : it->second; +} + +static void test_dtor(Inspector*) +{ } + +#define DECLARE_ENTRY(NAME, USAGE) \ + static TestModule NAME##_mod(#NAME, USAGE); \ + static InspectApi NAME##_api; \ + static TestInspector NAME##_ins + +DECLARE_ENTRY(binder, Module::Usage::INSPECT); + +DECLARE_ENTRY(file, Module::Usage::GLOBAL); +DECLARE_ENTRY(stream, Module::Usage::GLOBAL); + +DECLARE_ENTRY(global_passive, Module::Usage::GLOBAL); +DECLARE_ENTRY(global_probe, Module::Usage::GLOBAL); +DECLARE_ENTRY(global_control, Module::Usage::GLOBAL); + +DECLARE_ENTRY(context_passive, Module::Usage::CONTEXT); +DECLARE_ENTRY(context_packet, Module::Usage::CONTEXT); +DECLARE_ENTRY(context_first, Module::Usage::CONTEXT); +DECLARE_ENTRY(context_control, Module::Usage::CONTEXT); + +DECLARE_ENTRY(inspect_passive, Module::Usage::INSPECT); +DECLARE_ENTRY(inspect_packet, Module::Usage::INSPECT); +DECLARE_ENTRY(inspect_network, Module::Usage::INSPECT); +DECLARE_ENTRY(inspect_service, Module::Usage::INSPECT); +DECLARE_ENTRY(inspect_stream, Module::Usage::INSPECT); +DECLARE_ENTRY(inspect_wizard, Module::Usage::INSPECT); + +#define ADD_ENTRY(NAME, TYPE) \ + do { \ + NAME##_api = {}; \ + NAME##_api.base.name = NAME##_mod.get_name(); \ + NAME##_api.type = TYPE; \ + NAME##_api.ctor = test_ctor; \ + NAME##_api.dtor = test_dtor; \ + NAME##_ins.set_api(&NAME##_api); \ + InspectorManager::add_plugin(&NAME##_api); \ + } while (0) + +#define INSTANTIATE(NAME) \ + do { \ + mod_to_ins[&NAME##_mod] = &NAME##_ins; \ + InspectorManager::instantiate(&NAME##_api, &NAME##_mod, sc, NAME##_mod.get_name()); \ + } while (0) + +void setup_test_globals() +{ + ADD_ENTRY(binder, IT_PASSIVE); + + ADD_ENTRY(file, IT_FILE); + ADD_ENTRY(stream, IT_STREAM); + + ADD_ENTRY(global_passive, IT_PASSIVE); + ADD_ENTRY(global_probe, IT_PROBE); + ADD_ENTRY(global_control, IT_CONTROL); + + ADD_ENTRY(context_passive, IT_PASSIVE); + ADD_ENTRY(context_packet, IT_PACKET); + ADD_ENTRY(context_first, IT_FIRST); + ADD_ENTRY(context_control, IT_CONTROL); + + ADD_ENTRY(inspect_passive, IT_PASSIVE); + ADD_ENTRY(inspect_packet, IT_PACKET); + ADD_ENTRY(inspect_network, IT_NETWORK); + ADD_ENTRY(inspect_service, IT_SERVICE); + ADD_ENTRY(inspect_stream, IT_STREAM); + ADD_ENTRY(inspect_wizard, IT_WIZARD); +} + +TEST_GROUP(get_inspector_tests) +{ + SnortConfig* sc; + std::unordered_map mod_to_ins; + + void setup() override + { + sc = new SnortConfig; + mock().setDataObject("snort_config", "const SnortConfig", sc); + mock().setDataObject("mod_to_ins", "std::unordered_map", &mod_to_ins); + NetworkPolicy* np = sc->policy_map->get_network_policy(); + mock().setDataObject("network_policy", "NetworkPolicy", np); + InspectionPolicy* ip = np->get_inspection_policy(); + mock().setDataObject("inspection_policy", "InspectionPolicy", ip); + + INSTANTIATE(binder); + + INSTANTIATE(file); + INSTANTIATE(stream); + + INSTANTIATE(global_passive); + INSTANTIATE(global_probe); + INSTANTIATE(global_control); + + INSTANTIATE(context_passive); + INSTANTIATE(context_packet); + INSTANTIATE(context_first); + INSTANTIATE(context_control); + + INSTANTIATE(inspect_passive); + INSTANTIATE(inspect_packet); + INSTANTIATE(inspect_network); + INSTANTIATE(inspect_service); + INSTANTIATE(inspect_stream); + INSTANTIATE(inspect_wizard); + + InspectorManager::configure(sc, false); + } + + void teardown() override + { + delete sc; + InspectorManager::empty_trash(); + mod_to_ins.clear(); + mock().clear(); + } +}; + +#define THE_TEST(NAME, USAGE, TYPE) \ + do { \ + Inspector* ins = InspectorManager::get_inspector(NAME##_mod.get_name(), USAGE, TYPE); \ + CHECK_TEXT(&NAME##_ins == ins, "Did not find the " #NAME " inspector"); \ + STRCMP_EQUAL_TEXT(ins->get_name(), NAME##_mod.get_name(), "Inspector name is not " #NAME); \ + ins = InspectorManager::get_inspector("not_" #NAME, USAGE, TYPE); \ + CHECK_TEXT(nullptr == ins, "Found the not_" #NAME " inspector"); \ + } while (0) + +TEST(get_inspector_tests, file) +{ + THE_TEST(file, Module::Usage::GLOBAL, IT_FILE); +} + +TEST(get_inspector_tests, stream) +{ + THE_TEST(stream, Module::Usage::GLOBAL, IT_STREAM); +} + +TEST(get_inspector_tests, global_passive) +{ + THE_TEST(global_passive, Module::Usage::GLOBAL, IT_PASSIVE); +} + +TEST(get_inspector_tests, global_probe) +{ + THE_TEST(global_probe, Module::Usage::GLOBAL, IT_PROBE); +} + +TEST(get_inspector_tests, global_control) +{ + THE_TEST(global_control, Module::Usage::GLOBAL, IT_CONTROL); +} + +TEST(get_inspector_tests, context_passive) +{ + THE_TEST(context_passive, Module::Usage::CONTEXT, IT_PASSIVE); +} + +TEST(get_inspector_tests, context_packet) +{ + THE_TEST(context_packet, Module::Usage::CONTEXT, IT_PACKET); +} + +TEST(get_inspector_tests, context_first) +{ + THE_TEST(context_first, Module::Usage::CONTEXT, IT_FIRST); +} + +TEST(get_inspector_tests, context_control) +{ + THE_TEST(context_control, Module::Usage::CONTEXT, IT_CONTROL); +} + +TEST(get_inspector_tests, inspect_passive) +{ + THE_TEST(inspect_passive, Module::Usage::INSPECT, IT_PASSIVE); +} + +TEST(get_inspector_tests, inspect_packet) +{ + THE_TEST(inspect_packet, Module::Usage::INSPECT, IT_PACKET); +} + +TEST(get_inspector_tests, inspect_network) +{ + THE_TEST(inspect_network, Module::Usage::INSPECT, IT_NETWORK); +} + +TEST(get_inspector_tests, inspect_service) +{ + THE_TEST(inspect_service, Module::Usage::INSPECT, IT_SERVICE); +} + +TEST(get_inspector_tests, inspect_stream) +{ + THE_TEST(inspect_stream, Module::Usage::INSPECT, IT_STREAM); +} + +TEST(get_inspector_tests, inspect_wizard) +{ + THE_TEST(inspect_wizard, Module::Usage::INSPECT, IT_WIZARD); +} + +int main(int argc, char** argv) +{ + setup_test_globals(); + int r = CommandLineTestRunner::RunAllTests(argc, argv); + InspectorManager::release_plugins(); + return r; +} diff --git a/src/network_inspectors/appid/appid_config.cc b/src/network_inspectors/appid/appid_config.cc index 10eda764d..7fef03787 100644 --- a/src/network_inspectors/appid/appid_config.cc +++ b/src/network_inspectors/appid/appid_config.cc @@ -115,7 +115,7 @@ bool AppIdContext::init_appid(SnortConfig* sc, AppIdInspector& inspector) { odp_ctxt->get_client_disco_mgr().initialize(inspector); odp_ctxt->get_service_disco_mgr().initialize(inspector); - odp_thread_local_ctxt->initialize(*this, true); + odp_thread_local_ctxt->initialize(sc, *this, true); odp_ctxt->initialize(inspector); // do not reload third party on reload_config() @@ -224,12 +224,13 @@ AppId OdpContext::get_protocol_service_id(IpProtocol proto) return ip_protocol[(uint16_t)proto]; } -void OdpThreadContext::initialize(AppIdContext& ctxt, bool is_control, bool reload_odp) +void OdpThreadContext::initialize(const SnortConfig* sc, AppIdContext& ctxt, bool is_control, + bool reload_odp) { if (!is_control and reload_odp) - LuaDetectorManager::init_thread_manager(ctxt); + LuaDetectorManager::init_thread_manager(sc, ctxt); else - LuaDetectorManager::initialize(ctxt, is_control, reload_odp); + LuaDetectorManager::initialize(sc, ctxt, is_control, reload_odp); } OdpThreadContext::~OdpThreadContext() diff --git a/src/network_inspectors/appid/appid_config.h b/src/network_inspectors/appid/appid_config.h index adae41630..4a2383947 100644 --- a/src/network_inspectors/appid/appid_config.h +++ b/src/network_inspectors/appid/appid_config.h @@ -153,10 +153,10 @@ public: return host_port_cache.find(ip, port, proto, *this); } - bool host_port_cache_add(const snort::SfIp* ip, uint16_t port, IpProtocol proto, unsigned type, - AppId appid) + bool host_port_cache_add(const snort::SnortConfig* sc, const snort::SfIp* ip, uint16_t port, + IpProtocol proto, unsigned type, AppId appid) { - return host_port_cache.add(ip, port, proto, type, appid); + return host_port_cache.add(sc, ip, port, proto, type, appid); } AppId length_cache_find(const LengthKey& key) @@ -241,7 +241,8 @@ class OdpThreadContext { public: ~OdpThreadContext(); - void initialize(AppIdContext& ctxt, bool is_control=false, bool reload_odp=false); + void initialize(const snort::SnortConfig*, AppIdContext& ctxt, bool is_control=false, + bool reload_odp=false); void set_lua_detector_mgr(LuaDetectorManager& mgr) { diff --git a/src/network_inspectors/appid/appid_ha.cc b/src/network_inspectors/appid/appid_ha.cc index edcee9b10..0b1a8ddfa 100644 --- a/src/network_inspectors/appid/appid_ha.cc +++ b/src/network_inspectors/appid/appid_ha.cc @@ -43,13 +43,13 @@ THREAD_LOCAL AppIdHAAppsClient* AppIdHAManager::ha_apps_client = nullptr; THREAD_LOCAL AppIdHAHttpClient* AppIdHAManager::ha_http_client = nullptr; THREAD_LOCAL AppIdHATlsHostClient* AppIdHAManager::ha_tls_host_client = nullptr; -static AppIdSession* create_appid_session(Flow& flow, const FlowKey* key) +static AppIdSession* create_appid_session(Flow& flow, const FlowKey* key, + AppIdInspector& inspector) { - AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME, true); AppIdSession* asd = new AppIdSession(static_cast(key->ip_protocol), flow.flags.client_initiated ? &flow.client_ip : &flow.server_ip, - flow.flags.client_initiated ? flow.client_port : flow.server_port, *inspector, - inspector->get_ctxt().get_odp_ctxt(), key->addressSpaceId); + flow.flags.client_initiated ? flow.client_port : flow.server_port, inspector, + inspector.get_ctxt().get_odp_ctxt(), key->addressSpaceId); if (appidDebug->is_active()) LogMessage("AppIdDbg %s high-avail - New AppId session created in consume\n", appidDebug->get_debug_session()); @@ -67,7 +67,9 @@ bool AppIdHAAppsClient::consume(Flow*& flow, const FlowKey* key, HAMessage& msg, if (size != sizeof(AppIdSessionHAApps)) return false; - AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME, true); + AppIdInspector* inspector = + static_cast( + InspectorManager::get_inspector(MOD_NAME, MOD_USAGE, appid_inspector_api.type)); if (!inspector) return false; @@ -77,9 +79,9 @@ bool AppIdHAAppsClient::consume(Flow*& flow, const FlowKey* key, HAMessage& msg, if (appidDebug->is_enabled()) { appidDebug->activate(flow, asd, inspector->get_ctxt().config.log_all_sessions); - LogMessage("AppIdDbg %s high-avail - Consuming app data - flags 0x%x, service %d, client %d, " - "payload %d, misc %d, referred %d, client_inferred_service %d, port_service %d, " - "tp_app %d, tp_payload %d\n", + LogMessage("AppIdDbg %s high-avail - Consuming app data - flags 0x%x, service %d, " + "client %d, payload %d, misc %d, referred %d, client_inferred_service %d, " + "port_service %d, tp_app %d, tp_payload %d\n", appidDebug->get_debug_session(), appHA->flags, appHA->appId[APPID_HA_APP_SERVICE], appHA->appId[APPID_HA_APP_CLIENT], appHA->appId[APPID_HA_APP_PAYLOAD], appHA->appId[APPID_HA_APP_MISC], appHA->appId[APPID_HA_APP_REFERRED], @@ -90,7 +92,7 @@ bool AppIdHAAppsClient::consume(Flow*& flow, const FlowKey* key, HAMessage& msg, if (!asd) { - asd = create_appid_session(*flow, key); + asd = create_appid_session(*flow, key, *inspector); asd->set_service_id(appHA->appId[APPID_HA_APP_SERVICE], asd->get_odp_ctxt()); if (asd->get_service_id() == APP_ID_FTP_CONTROL) { @@ -227,7 +229,9 @@ bool AppIdHAHttpClient::consume(Flow*& flow, const FlowKey* key, HAMessage& msg, if (size != sizeof(AppIdSessionHAHttp)) return false; - AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME, true); + AppIdInspector* inspector = + static_cast( + InspectorManager::get_inspector(MOD_NAME, MOD_USAGE, appid_inspector_api.type)); if (!inspector) return false; @@ -241,7 +245,7 @@ bool AppIdHAHttpClient::consume(Flow*& flow, const FlowKey* key, HAMessage& msg, } if (!asd) - asd = create_appid_session(*flow, key); + asd = create_appid_session(*flow, key, *inspector); AppidChangeBits change_bits; AppIdHttpSession* hsession = asd->get_http_session(); @@ -315,7 +319,9 @@ bool AppIdHATlsHostClient::consume(Flow*& flow, const FlowKey* key, HAMessage& m if (size != sizeof(AppIdSessionHATlsHost)) return false; - AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME, true); + AppIdInspector* inspector = + static_cast( + InspectorManager::get_inspector(MOD_NAME, MOD_USAGE, appid_inspector_api.type)); if (!inspector) return false; @@ -329,7 +335,7 @@ bool AppIdHATlsHostClient::consume(Flow*& flow, const FlowKey* key, HAMessage& m } if (!asd) - asd = create_appid_session(*flow, key); + asd = create_appid_session(*flow, key, *inspector); asd->set_tls_host(appHA->tls_host); diff --git a/src/network_inspectors/appid/appid_inspector.cc b/src/network_inspectors/appid/appid_inspector.cc index 435c5af85..ac6521849 100644 --- a/src/network_inspectors/appid/appid_inspector.cc +++ b/src/network_inspectors/appid/appid_inspector.cc @@ -122,27 +122,27 @@ bool AppIdInspector::configure(SnortConfig* sc) ctxt->init_appid(sc, *this); - DataBus::subscribe_network(SIP_EVENT_TYPE_SIP_DIALOG_KEY, new SipEventHandler(*this)); + DataBus::subscribe_global(SIP_EVENT_TYPE_SIP_DIALOG_KEY, new SipEventHandler(*this), *sc); - DataBus::subscribe_network(HTTP_REQUEST_HEADER_EVENT_KEY, new HttpEventHandler( - HttpEventHandler::REQUEST_EVENT, *this)); + DataBus::subscribe_global(HTTP_REQUEST_HEADER_EVENT_KEY, new HttpEventHandler( + HttpEventHandler::REQUEST_EVENT, *this), *sc); - DataBus::subscribe_network(HTTP_RESPONSE_HEADER_EVENT_KEY, new HttpEventHandler( - HttpEventHandler::RESPONSE_EVENT, *this)); + DataBus::subscribe_global(HTTP_RESPONSE_HEADER_EVENT_KEY, new HttpEventHandler( + HttpEventHandler::RESPONSE_EVENT, *this), *sc); - DataBus::subscribe_network(HTTP2_REQUEST_BODY_EVENT_KEY, new AppIdHttp2ReqBodyEventHandler()); + DataBus::subscribe_global(HTTP2_REQUEST_BODY_EVENT_KEY, new AppIdHttp2ReqBodyEventHandler(), *sc); - DataBus::subscribe_network(DATA_DECRYPT_EVENT, new DataDecryptEventHandler()); + DataBus::subscribe_global(DATA_DECRYPT_EVENT, new DataDecryptEventHandler(), *sc); - DataBus::subscribe_network(DCERPC_EXP_SESSION_EVENT_KEY, new DceExpSsnEventHandler()); + DataBus::subscribe_global(DCERPC_EXP_SESSION_EVENT_KEY, new DceExpSsnEventHandler(), *sc); - DataBus::subscribe_network(OPPORTUNISTIC_TLS_EVENT, new AppIdOpportunisticTlsEventHandler()); + DataBus::subscribe_global(OPPORTUNISTIC_TLS_EVENT, new AppIdOpportunisticTlsEventHandler(), *sc); - DataBus::subscribe_network(EVE_PROCESS_EVENT, new AppIdEveProcessEventHandler()); + DataBus::subscribe_global(EVE_PROCESS_EVENT, new AppIdEveProcessEventHandler(), *sc); - DataBus::subscribe_network(SSH_EVENT, new SshEventHandler()); + DataBus::subscribe_global(SSH_EVENT, new SshEventHandler(), *sc); - DataBus::subscribe_network(FLOW_NO_SERVICE_EVENT, new AppIdServiceEventHandler(*this)); + DataBus::subscribe_global(FLOW_NO_SERVICE_EVENT, new AppIdServiceEventHandler(*this), *sc); return true; } @@ -163,7 +163,7 @@ void AppIdInspector::tinit() assert(!odp_thread_local_ctxt); odp_thread_local_ctxt = new OdpThreadContext(); - odp_thread_local_ctxt->initialize(*ctxt); + odp_thread_local_ctxt->initialize(SnortConfig::get_conf(), *ctxt); AppIdServiceState::initialize(config->memcap); assert(!pkt_thread_tp_appid_ctxt); @@ -250,6 +250,7 @@ static void appid_inspector_tinit() static void appid_inspector_tterm() { TPLibHandler::tfini(); + AppIdPegCounts::sum_stats(); AppIdPegCounts::cleanup_pegs(); AppIdServiceState::clean(); delete appidDebug; diff --git a/src/network_inspectors/appid/appid_inspector.h b/src/network_inspectors/appid/appid_inspector.h index a182e38b1..d27b4a469 100644 --- a/src/network_inspectors/appid/appid_inspector.h +++ b/src/network_inspectors/appid/appid_inspector.h @@ -54,6 +54,8 @@ private: AppIdContext* ctxt = nullptr; }; +extern const snort::InspectApi appid_inspector_api; + extern THREAD_LOCAL OdpThreadContext* odp_thread_local_ctxt; extern THREAD_LOCAL OdpContext* pkt_thread_odp_ctxt; extern THREAD_LOCAL ThirdPartyAppIdContext* pkt_thread_tp_appid_ctxt; diff --git a/src/network_inspectors/appid/appid_module.cc b/src/network_inspectors/appid/appid_module.cc index abe438ada..79c93716a 100644 --- a/src/network_inspectors/appid/appid_module.cc +++ b/src/network_inspectors/appid/appid_module.cc @@ -151,7 +151,7 @@ class ACThirdPartyAppIdContextSwap : public AnalyzerCommand public: bool execute(Analyzer&, void**) override; ACThirdPartyAppIdContextSwap(const AppIdInspector& inspector, ControlConn* conn) - : inspector(inspector), tracker_ref(conn) + : AnalyzerCommand(conn), inspector(inspector) { LogMessage("== swapping third-party configuration\n"); } @@ -160,7 +160,6 @@ public: const char* stringify() override { return "THIRD-PARTY_CONTEXT_SWAP"; } private: const AppIdInspector& inspector; - ControlConn* tracker_ref; }; bool ACThirdPartyAppIdContextSwap::execute(Analyzer&, void**) @@ -179,7 +178,7 @@ ACThirdPartyAppIdContextSwap::~ACThirdPartyAppIdContextSwap() std::string file_path = ctxt.get_tp_appid_ctxt()->get_user_config(); ctxt.get_odp_ctxt().get_app_info_mgr().dump_appid_configurations(file_path); LogMessage("== third-party configuration swap complete\n"); - ReloadTracker::end(tracker_ref); + ReloadTracker::end(ctrlcon); } class ACThirdPartyAppIdContextUnload : public AnalyzerCommand @@ -187,13 +186,13 @@ class ACThirdPartyAppIdContextUnload : public AnalyzerCommand public: bool execute(Analyzer&, void**) override; ACThirdPartyAppIdContextUnload(const AppIdInspector& inspector, ThirdPartyAppIdContext* tp_ctxt, - ControlConn* ctrlcon): inspector(inspector), tp_ctxt(tp_ctxt), ctrlcon(ctrlcon) { } + ControlConn* conn): AnalyzerCommand(conn), inspector(inspector), tp_ctxt(tp_ctxt) + { } ~ACThirdPartyAppIdContextUnload() override; const char* stringify() override { return "THIRD-PARTY_CONTEXT_UNLOAD"; } private: const AppIdInspector& inspector; ThirdPartyAppIdContext* tp_ctxt = nullptr; - ControlConn* ctrlcon; }; bool ACThirdPartyAppIdContextUnload::execute(Analyzer& ac, void**) @@ -218,9 +217,7 @@ ACThirdPartyAppIdContextUnload::~ACThirdPartyAppIdContextUnload() AppIdContext& ctxt = inspector.get_ctxt(); ctxt.create_tp_appid_ctxt(); main_broadcast_command(new ACThirdPartyAppIdContextSwap(inspector, ctrlcon)); - LogMessage("== reload third-party complete\n"); - if (ctrlcon && !ctrlcon->is_local()) - ctrlcon->respond("== reload third-party complete\n"); + log_message("== reload third-party complete\n"); ReloadTracker::update(ctrlcon, "unload old third-party complete, start swapping to new configuration."); } @@ -228,14 +225,14 @@ class ACOdpContextSwap : public AnalyzerCommand { public: bool execute(Analyzer&, void**) override; - ACOdpContextSwap(const AppIdInspector& inspector, OdpContext& odp_ctxt, ControlConn* ctrlcon) : - inspector(inspector), odp_ctxt(odp_ctxt), ctrlcon(ctrlcon) { } + ACOdpContextSwap(const AppIdInspector& inspector, OdpContext& odp_ctxt, ControlConn* conn) : + AnalyzerCommand(conn), inspector(inspector), odp_ctxt(odp_ctxt) + { } ~ACOdpContextSwap() override; const char* stringify() override { return "ODP_CONTEXT_SWAP"; } private: const AppIdInspector& inspector; OdpContext& odp_ctxt; - ControlConn* ctrlcon; }; bool ACOdpContextSwap::execute(Analyzer&, void**) @@ -254,7 +251,7 @@ bool ACOdpContextSwap::execute(Analyzer&, void**) assert(odp_thread_local_ctxt); delete odp_thread_local_ctxt; odp_thread_local_ctxt = new OdpThreadContext; - odp_thread_local_ctxt->initialize(ctxt, false, true); + odp_thread_local_ctxt->initialize(SnortConfig::get_conf(), ctxt, false, true); return true; } @@ -270,9 +267,8 @@ ACOdpContextSwap::~ACOdpContextSwap() file_path = std::string(ctxt.config.app_detector_dir) + "/../userappid.conf"; ctxt.get_odp_ctxt().get_app_info_mgr().dump_appid_configurations(file_path); } - LogMessage("== reload detectors complete\n"); ReloadTracker::end(ctrlcon); - ctrlcon->respond("== reload detectors complete\n"); + log_message("== reload detectors complete\n"); } static int enable_debug(lua_State* L) @@ -394,7 +390,7 @@ static int reload_detectors(lua_State* L) OdpContext& odp_ctxt = ctxt.get_odp_ctxt(); odp_ctxt.get_client_disco_mgr().initialize(*inspector); odp_ctxt.get_service_disco_mgr().initialize(*inspector); - odp_thread_local_ctxt->initialize(ctxt, true, true); + odp_thread_local_ctxt->initialize(SnortConfig::get_conf(), ctxt, true, true); odp_ctxt.initialize(*inspector); ctrlcon->respond("== swapping detectors configuration\n"); @@ -524,7 +520,7 @@ bool AppIdModule::end(const char* fqn, int, SnortConfig* sc) assert(config); if ( Snort::is_reloading() && strcmp(fqn, "appid") == 0 ) - sc->register_reload_resource_tuner(new AppIdReloadTuner(config->memcap)); + sc->register_reload_handler(new AppIdReloadTuner(config->memcap)); if ( !config->app_detector_dir ) { diff --git a/src/network_inspectors/appid/appid_module.h b/src/network_inspectors/appid/appid_module.h index ce59a8c10..5779dd3e5 100644 --- a/src/network_inspectors/appid/appid_module.h +++ b/src/network_inspectors/appid/appid_module.h @@ -28,13 +28,14 @@ #include "framework/module.h" #include "main/analyzer.h" #include "main/analyzer_command.h" -#include "main/snort_config.h" +#include "main/reload_tuner.h" #include "appid_config.h" #include "appid_peg_counts.h" namespace snort { +struct SnortConfig; class Trace; } @@ -43,6 +44,7 @@ extern THREAD_LOCAL const snort::Trace* appid_trace; #define MOD_NAME "appid" #define MOD_HELP "application and service identification" +#define MOD_USAGE snort::Module::GLOBAL class AppIdReloadTuner : public snort::ReloadResourceTuner @@ -91,7 +93,7 @@ public: void reset_stats() override; Usage get_usage() const override - { return CONTEXT; } + { return MOD_USAGE; } void sum_stats(bool) override; void show_dynamic_stats() override; diff --git a/src/network_inspectors/appid/appid_peg_counts.cc b/src/network_inspectors/appid/appid_peg_counts.cc index 2c87d2ff4..8f52b4ada 100644 --- a/src/network_inspectors/appid/appid_peg_counts.cc +++ b/src/network_inspectors/appid/appid_peg_counts.cc @@ -28,6 +28,8 @@ #include #include +#include "framework/inspector.h" +#include "main/thread_config.h" #include "utils/stats.h" using namespace snort; @@ -36,10 +38,12 @@ std::unordered_map AppIdPegCounts::appid_detector_pegs_idx; std::vector AppIdPegCounts::appid_detectors_info; THREAD_LOCAL std::vector* AppIdPegCounts::appid_peg_counts; AppIdPegCounts::AppIdDynamicPeg AppIdPegCounts::appid_dynamic_sum[SF_APPID_MAX + 1]; +AppIdPegCounts::AppIdDynamicPeg AppIdPegCounts::zeroed_peg; +PegCount AppIdPegCounts::all_zeroed_peg[DetectorPegs::NUM_APPID_DETECTOR_PEGS] = {}; void AppIdPegCounts::init_pegs() { - AppIdPegCounts::AppIdDynamicPeg zeroed_peg = AppIdPegCounts::AppIdDynamicPeg(); + assert(!appid_peg_counts); appid_peg_counts = new std::vector( appid_detectors_info.size() + 1, zeroed_peg); } @@ -47,6 +51,7 @@ void AppIdPegCounts::init_pegs() void AppIdPegCounts::cleanup_pegs() { delete appid_peg_counts; + appid_peg_counts = nullptr; } void AppIdPegCounts::cleanup_peg_info() @@ -57,22 +62,21 @@ void AppIdPegCounts::cleanup_peg_info() void AppIdPegCounts::cleanup_dynamic_sum() { - if ( !appid_peg_counts ) - return; - for ( unsigned app_num = 0; app_num < AppIdPegCounts::appid_detectors_info.size(); app_num++ ) { memset(appid_dynamic_sum[app_num].stats, 0, sizeof(PegCount) * DetectorPegs::NUM_APPID_DETECTOR_PEGS); - memset((*appid_peg_counts)[app_num].stats, 0, sizeof(PegCount) * - DetectorPegs::NUM_APPID_DETECTOR_PEGS); + if ( appid_peg_counts ) + memset((*appid_peg_counts)[app_num].stats, 0, sizeof(PegCount) * + DetectorPegs::NUM_APPID_DETECTOR_PEGS); } // reset unknown_app stats memset(appid_dynamic_sum[SF_APPID_MAX].stats, 0, sizeof(PegCount) * DetectorPegs::NUM_APPID_DETECTOR_PEGS); - memset((*appid_peg_counts)[appid_peg_counts->size() - 1].stats, 0, sizeof(PegCount) * - DetectorPegs::NUM_APPID_DETECTOR_PEGS); + if ( appid_peg_counts ) + memset((*appid_peg_counts)[appid_peg_counts->size() - 1].stats, 0, sizeof(PegCount) * + DetectorPegs::NUM_APPID_DETECTOR_PEGS); } void AppIdPegCounts::add_app_peg_info(std::string app_name, AppId app_id) diff --git a/src/network_inspectors/appid/appid_peg_counts.h b/src/network_inspectors/appid/appid_peg_counts.h index 4d80e0174..9fb148284 100644 --- a/src/network_inspectors/appid/appid_peg_counts.h +++ b/src/network_inspectors/appid/appid_peg_counts.h @@ -76,14 +76,13 @@ public: bool all_zeros() { - PegCount zeroed_peg[DetectorPegs::NUM_APPID_DETECTOR_PEGS] = { }; - return !memcmp(stats, &zeroed_peg, sizeof(stats)); + return !memcmp(stats, &all_zeroed_peg, sizeof(stats)); } void print(const char* app, char* buf, int buf_size) { snprintf(buf, buf_size, "%25.25s: " FMTu64("-10") " " FMTu64("-10") " " FMTu64("-10") - " " FMTu64("-10") " " FMTu64("-10") " " FMTu64("-10"), app, + " " FMTu64("-10") " " FMTu64("-10") " " FMTu64("-10"), app, stats[SERVICE_DETECTS], stats[CLIENT_DETECTS], stats[USER_DETECTS], stats[PAYLOAD_DETECTS], stats[MISC_DETECTS], stats[REFERRED_DETECTS]); } @@ -111,6 +110,8 @@ private: static AppIdDynamicPeg appid_dynamic_sum[SF_APPID_MAX+1]; static THREAD_LOCAL std::vector* appid_peg_counts; static uint32_t get_stats_index(AppId id); + static AppIdDynamicPeg zeroed_peg; + static PegCount all_zeroed_peg[DetectorPegs::NUM_APPID_DETECTOR_PEGS]; }; #endif diff --git a/src/network_inspectors/appid/host_port_app_cache.cc b/src/network_inspectors/appid/host_port_app_cache.cc index 616f979e2..6c69946b0 100644 --- a/src/network_inspectors/appid/host_port_app_cache.cc +++ b/src/network_inspectors/appid/host_port_app_cache.cc @@ -51,14 +51,15 @@ HostPortVal* HostPortCache::find(const SfIp* ip, uint16_t port, IpProtocol proto return nullptr; } -bool HostPortCache::add(const SfIp* ip, uint16_t port, IpProtocol proto, unsigned type, AppId - appId) +bool HostPortCache::add(const SnortConfig* sc, const SfIp* ip, uint16_t port, IpProtocol proto, + unsigned type, AppId appId) { HostPortKey hk; HostPortVal hv; hk.ip = *ip; - AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME); + AppIdInspector* inspector = + (AppIdInspector*)InspectorManager::get_inspector(MOD_NAME, false, sc); assert(inspector); const AppIdContext& ctxt = inspector->get_ctxt(); hk.port = (ctxt.get_odp_ctxt().allow_port_wildcard_host_cache)? 0 : port; diff --git a/src/network_inspectors/appid/host_port_app_cache.h b/src/network_inspectors/appid/host_port_app_cache.h index 62ae8b578..4a469372d 100644 --- a/src/network_inspectors/appid/host_port_app_cache.h +++ b/src/network_inspectors/appid/host_port_app_cache.h @@ -29,6 +29,11 @@ #include "sfip/sf_ip.h" #include "utils/cpp_macros.h" +namespace snort +{ + struct SnortConfig; +} + class OdpContext; PADDING_GUARD_BEGIN @@ -64,7 +69,8 @@ class HostPortCache { public: HostPortVal* find(const snort::SfIp*, uint16_t port, IpProtocol, const OdpContext&); - bool add(const snort::SfIp*, uint16_t port, IpProtocol, unsigned type, AppId); + bool add(const snort::SnortConfig*, const snort::SfIp*, uint16_t port, IpProtocol, + unsigned type, AppId); void dump(); ~HostPortCache() diff --git a/src/network_inspectors/appid/lua_detector_api.cc b/src/network_inspectors/appid/lua_detector_api.cc index c546b4d2d..6cbc2593e 100644 --- a/src/network_inspectors/appid/lua_detector_api.cc +++ b/src/network_inspectors/appid/lua_detector_api.cc @@ -1196,7 +1196,11 @@ static int detector_add_host_port_application(lua_State* L) if (toipprotocol(L, ++index, proto)) return 0; - if (!ud->get_odp_ctxt().host_port_cache_add(&ip_address, (uint16_t)port, proto, type, app_id)) + lua_getglobal(L, LUA_STATE_GLOBAL_SC_ID); + const SnortConfig* sc = *static_cast(lua_touserdata(L, -1)); + lua_pop(L, 1); + if (!ud->get_odp_ctxt().host_port_cache_add( + sc, &ip_address, (uint16_t)port, proto, type, app_id)) ErrorMessage("%s:Failed to backend call\n",__func__); return 0; diff --git a/src/network_inspectors/appid/lua_detector_api.h b/src/network_inspectors/appid/lua_detector_api.h index 8d849689c..8aa83efb8 100644 --- a/src/network_inspectors/appid/lua_detector_api.h +++ b/src/network_inspectors/appid/lua_detector_api.h @@ -42,6 +42,8 @@ class AppInfoTableEntry; #define DETECTOR "Detector" #define DETECTORFLOW "DetectorFlow" +#define LUA_STATE_GLOBAL_SC_ID "snort_config" + struct DetectorPackageInfo { std::string initFunctionName; diff --git a/src/network_inspectors/appid/lua_detector_module.cc b/src/network_inspectors/appid/lua_detector_module.cc index d9414813e..ebdca9532 100644 --- a/src/network_inspectors/appid/lua_detector_module.cc +++ b/src/network_inspectors/appid/lua_detector_module.cc @@ -202,7 +202,8 @@ LuaDetectorManager::~LuaDetectorManager() cb_detectors.clear(); // do not free Lua objects in cb_detectors } -void LuaDetectorManager::initialize(AppIdContext& ctxt, bool is_control, bool reload) +void LuaDetectorManager::initialize(const SnortConfig* sc, AppIdContext& ctxt, bool is_control, + bool reload) { LuaDetectorManager* lua_detector_mgr = new LuaDetectorManager(ctxt, is_control); odp_thread_local_ctxt->set_lua_detector_mgr(*lua_detector_mgr); @@ -226,17 +227,17 @@ void LuaDetectorManager::initialize(AppIdContext& ctxt, bool is_control, bool re } lua_detector_mgr->initialize_lua_detectors(is_control, reload); - lua_detector_mgr->activate_lua_detectors(); + lua_detector_mgr->activate_lua_detectors(sc); if (ctxt.config.list_odp_detectors) lua_detector_mgr->list_lua_detectors(); } -void LuaDetectorManager::init_thread_manager(const AppIdContext& ctxt) +void LuaDetectorManager::init_thread_manager(const SnortConfig* sc, const AppIdContext& ctxt) { LuaDetectorManager* lua_detector_mgr = lua_detector_mgr_list[get_instance_id()]; odp_thread_local_ctxt->set_lua_detector_mgr(*lua_detector_mgr); - lua_detector_mgr->activate_lua_detectors(); + lua_detector_mgr->activate_lua_detectors(sc); if (ctxt.config.list_odp_detectors) lua_detector_mgr->list_lua_detectors(); } @@ -582,7 +583,7 @@ void LuaDetectorManager::initialize_lua_detectors(bool is_control, bool reload) load_lua_detectors(path, true, is_control, reload); } -void LuaDetectorManager::activate_lua_detectors() +void LuaDetectorManager::activate_lua_detectors(const SnortConfig* sc) { uint32_t lua_tracker_size = compute_lua_tracker_size(MAX_MEMORY_FOR_LUA_DETECTORS, allocated_objects.size()); @@ -616,6 +617,9 @@ void LuaDetectorManager::activate_lua_detectors() /*second parameter is a table containing configuration stuff. */ lua_newtable(L); + const SnortConfig** sc_ud = static_cast(lua_newuserdata(L, sizeof(const SnortConfig*))); + *(sc_ud) = sc; + lua_setglobal(L, LUA_STATE_GLOBAL_SC_ID); if (lua_pcall(L, 2, 1, 0)) { if (init(L)) @@ -627,6 +631,7 @@ void LuaDetectorManager::activate_lua_detectors() lo = allocated_objects.erase(lo); continue; } + *(sc_ud) = nullptr; lua_getfield(L, LUA_REGISTRYINDEX, lsd->package_info.name.c_str()); set_lua_tracker_size(L, lua_tracker_size); diff --git a/src/network_inspectors/appid/lua_detector_module.h b/src/network_inspectors/appid/lua_detector_module.h index d380f8bee..a8d57faf5 100644 --- a/src/network_inspectors/appid/lua_detector_module.h +++ b/src/network_inspectors/appid/lua_detector_module.h @@ -36,6 +36,11 @@ #include "application_ids.h" +namespace snort +{ + struct SnortConfig; +} + class AppIdContext; class AppIdDetector; struct DetectorFlow; @@ -50,8 +55,9 @@ class LuaDetectorManager public: LuaDetectorManager(AppIdContext&, bool); ~LuaDetectorManager(); - static void initialize(AppIdContext&, bool is_control=false, bool reload=false); - static void init_thread_manager(const AppIdContext&); + static void initialize(const snort::SnortConfig*, AppIdContext&, bool is_control=false, + bool reload=false); + static void init_thread_manager(const snort::SnortConfig*, const AppIdContext&); static void clear_lua_detector_mgrs(); void set_detector_flow(DetectorFlow* df) @@ -70,7 +76,7 @@ public: private: void initialize_lua_detectors(bool is_control, bool reload = false); - void activate_lua_detectors(); + void activate_lua_detectors(const snort::SnortConfig*); void list_lua_detectors(); bool load_detector(char* detector_name, bool is_custom, bool is_control, bool reload, std::string& buf); void load_lua_detectors(const char* path, bool is_custom, bool is_control, bool reload = false); diff --git a/src/network_inspectors/binder/bind_module.cc b/src/network_inspectors/binder/bind_module.cc index 0a444b1ab..b7455990b 100644 --- a/src/network_inspectors/binder/bind_module.cc +++ b/src/network_inspectors/binder/bind_module.cc @@ -46,6 +46,7 @@ static const PegInfo bind_pegs[] = { { CountType::SUM, "raw_packets", "raw packets evaluated" }, { CountType::SUM, "new_flows", "new flows evaluated" }, + { CountType::SUM, "rebinds", "flows rebound" }, { CountType::SUM, "service_changes", "flow service changes evaluated" }, { CountType::SUM, "assistant_inspectors", "flow assistant inspector requests handled" }, { CountType::SUM, "new_standby_flows", "new HA flows evaluated" }, @@ -431,7 +432,7 @@ bool BinderModule::end(const char* fqn, int idx, SnortConfig* sc) if ( policy_type == FILE_KEY ) { Shell* sh = new Shell(policy_filename.c_str()); - auto policies = sc->policy_map->add_shell(sh, false); + auto policies = sc->policy_map->add_shell(sh, get_network_parse_policy()); binding.use.inspection_index = policies->inspection->policy_id; binding.use.ips_index = policies->ips->policy_id; } diff --git a/src/network_inspectors/binder/bind_module.h b/src/network_inspectors/binder/bind_module.h index 3b561c72a..da15cb44d 100644 --- a/src/network_inspectors/binder/bind_module.h +++ b/src/network_inspectors/binder/bind_module.h @@ -33,6 +33,7 @@ struct BindStats { PegCount raw_packets; PegCount new_flows; + PegCount rebinds; PegCount service_changes; PegCount assistant_inspectors; PegCount new_standby_flows; diff --git a/src/network_inspectors/binder/binder.cc b/src/network_inspectors/binder/binder.cc index 1fa26e9db..f4ba3d3c9 100644 --- a/src/network_inspectors/binder/binder.cc +++ b/src/network_inspectors/binder/binder.cc @@ -420,15 +420,8 @@ void Stuff::apply_action(Flow& flow) void Stuff::apply_session(Flow& flow) { - if (client) - flow.set_client(client); - else if (flow.ssn_client) - flow.clear_client(); - - if (server) - flow.set_server(server); - else if (flow.ssn_server) - flow.clear_server(); + flow.set_client(client); + flow.set_server(server); } void Stuff::apply_service(Flow& flow) @@ -490,6 +483,7 @@ public: void handle_flow_setup(Flow&, bool standby = false); void handle_flow_service_change(Flow&); void handle_assistant_gadget(const char* service, Flow&); + void handle_flow_after_reload(Flow&); private: void get_policy_bindings(Flow&, const char* service); @@ -577,6 +571,19 @@ public: } }; +class RebindFlow : public DataHandler +{ +public: + RebindFlow() : DataHandler(BIND_NAME) { } + + void handle(DataEvent&, Flow* flow) override + { + Binder* binder = InspectorManager::get_binder(); + if (binder && flow) + binder->handle_flow_after_reload(*flow); + } +}; + Binder::Binder(std::vector& bv, std::vector& pbv) { bindings = std::move(bv); @@ -623,6 +630,7 @@ bool Binder::configure(SnortConfig* sc) DataBus::subscribe(FLOW_SERVICE_CHANGE_EVENT, new FlowServiceChangeHandler()); DataBus::subscribe(STREAM_HA_NEW_FLOW_EVENT, new StreamHANewFlowHandler()); DataBus::subscribe(FLOW_ASSISTANT_GADGET_EVENT, new AssistantGadgetHandler()); + DataBus::subscribe(FLOW_STATE_RELOADED_EVENT, new RebindFlow()); return true; } @@ -721,7 +729,7 @@ void Binder::handle_flow_setup(Flow& flow, bool standby) if (flow.ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID) { const SnortConfig* sc = SnortConfig::get_conf(); - flow.service = sc->proto_ref->get_name(flow.ssn_state.snort_protocol_id); + flow.service = sc->proto_ref->get_shared_name(flow.ssn_state.snort_protocol_id); } } @@ -742,7 +750,7 @@ void Binder::handle_flow_service_change(Flow& flow) Inspector* ins = nullptr; Inspector* data = nullptr; - if (flow.service) + if (flow.has_service()) { ins = find_gadget(flow, data); if (flow.gadget != ins) @@ -789,8 +797,8 @@ void Binder::handle_flow_service_change(Flow& flow) // If there is no inspector bound to this flow after the service change, see if there's at least // an associated protocol ID. - if (!ins && flow.service) - flow.ssn_state.snort_protocol_id = SnortConfig::get_conf()->proto_ref->find(flow.service); + if (!ins && flow.has_service()) + flow.ssn_state.snort_protocol_id = SnortConfig::get_conf()->proto_ref->find(flow.service->c_str()); if (flow.is_stream()) { @@ -820,9 +828,18 @@ void Binder::handle_assistant_gadget(const char* service, Flow& flow) bstats.assistant_inspectors++; } +void Binder::handle_flow_after_reload(Flow& flow) +{ + Stuff stuff; + get_bindings(flow, stuff); + stuff.apply_action(flow); + + bstats.rebinds++; + bstats.verdicts[stuff.action]++; +} + void Binder::get_policy_bindings(Flow& flow, const char* service) { - const SnortConfig* sc = SnortConfig::get_conf(); unsigned inspection_index = 0; unsigned ips_index = 0; @@ -847,13 +864,14 @@ void Binder::get_policy_bindings(Flow& flow, const char* service) if (inspection_index) { - set_inspection_policy(sc, inspection_index); + set_inspection_policy(inspection_index); if (!service) flow.inspection_policy_id = inspection_index; } if (ips_index) { + const SnortConfig* sc = SnortConfig::get_conf(); set_ips_policy(sc, ips_index); if (!service) flow.ips_policy_id = ips_index; @@ -862,7 +880,6 @@ void Binder::get_policy_bindings(Flow& flow, const char* service) void Binder::get_policy_bindings(Packet* p) { - const SnortConfig* sc = SnortConfig::get_conf(); unsigned inspection_index = 0; unsigned ips_index = 0; @@ -887,12 +904,13 @@ void Binder::get_policy_bindings(Packet* p) if (inspection_index) { - set_inspection_policy(sc, inspection_index); + set_inspection_policy(inspection_index); p->user_inspection_policy_id = get_inspection_policy()->user_policy_id; } if (ips_index) { + const SnortConfig* sc = SnortConfig::get_conf(); set_ips_policy(sc, ips_index); p->user_ips_policy_id = get_ips_policy()->user_policy_id; } @@ -965,7 +983,7 @@ void Binder::get_bindings(Packet* p, Stuff& stuff) Inspector* Binder::find_gadget(Flow& flow, Inspector*& data) { Stuff stuff; - get_bindings(flow, stuff, flow.service); + get_bindings(flow, stuff, flow.has_service() ? flow.service->c_str() : nullptr); data = stuff.data; return stuff.gadget; } diff --git a/src/network_inspectors/binder/binding.cc b/src/network_inspectors/binder/binding.cc index 282bde171..4b66064f1 100644 --- a/src/network_inspectors/binder/binding.cc +++ b/src/network_inspectors/binder/binding.cc @@ -579,10 +579,10 @@ inline bool Binding::check_service(const Flow& flow) const if (!when.has_criteria(BindWhen::Criteria::BWC_SVC)) return true; - if (!flow.service) + if (!flow.has_service()) return false; - return when.svc == flow.service; + return when.svc == flow.service->c_str(); } inline bool Binding::check_service(const char* service) const diff --git a/src/network_inspectors/perf_monitor/perf_module.cc b/src/network_inspectors/perf_monitor/perf_module.cc index 6bb3aa0b9..725668ba2 100644 --- a/src/network_inspectors/perf_monitor/perf_module.cc +++ b/src/network_inspectors/perf_monitor/perf_module.cc @@ -331,7 +331,7 @@ bool PerfMonModule::end(const char* fqn, int idx, SnortConfig* sc) { if ( Snort::is_reloading() && strcmp(fqn, "perf_monitor") == 0 ) - sc->register_reload_resource_tuner(new PerfMonReloadTuner(config->flowip_memcap)); + sc->register_reload_handler(new PerfMonReloadTuner(config->flowip_memcap)); if ( idx != 0 && strcmp(fqn, "perf_monitor.modules") == 0 ) return config->modules.back().confirm_parse(); diff --git a/src/network_inspectors/perf_monitor/perf_reload_tuner.h b/src/network_inspectors/perf_monitor/perf_reload_tuner.h index 6f3ba11aa..839330172 100644 --- a/src/network_inspectors/perf_monitor/perf_reload_tuner.h +++ b/src/network_inspectors/perf_monitor/perf_reload_tuner.h @@ -21,7 +21,7 @@ #ifndef PERF_RELOAD_TUNER_H #define PERF_RELOAD_TUNER_H -#include "main/snort_config.h" +#include "main/reload_tuner.h" class PerfMonReloadTuner : public snort::ReloadResourceTuner { diff --git a/src/network_inspectors/port_scan/ps_module.cc b/src/network_inspectors/port_scan/ps_module.cc index 63c29c459..bfb0e1077 100644 --- a/src/network_inspectors/port_scan/ps_module.cc +++ b/src/network_inspectors/port_scan/ps_module.cc @@ -25,6 +25,7 @@ #include "ps_module.h" #include "log/messages.h" #include "main/snort.h" +#include "main/snort_config.h" #include @@ -326,7 +327,7 @@ bool PortScanModule::set(const char* fqn, Value& v, SnortConfig*) bool PortScanModule::end(const char* fqn, int, SnortConfig* sc) { if ( Snort::is_reloading() && strcmp(fqn, "port_scan") == 0 ) - sc->register_reload_resource_tuner(new PortScanReloadTuner(config->memcap)); + sc->register_reload_handler(new PortScanReloadTuner(config->memcap)); return true; } diff --git a/src/network_inspectors/port_scan/ps_module.h b/src/network_inspectors/port_scan/ps_module.h index 798b436ee..4d17b3e30 100644 --- a/src/network_inspectors/port_scan/ps_module.h +++ b/src/network_inspectors/port_scan/ps_module.h @@ -22,10 +22,15 @@ #define PS_MODULE_H #include "framework/module.h" -#include "main/snort_config.h" +#include "main/reload_tuner.h" #include "ps_detect.h" #include "ps_pegs.h" +namespace snort +{ +struct SnortConfig; +} + #define PS_NAME "port_scan" #define PS_HELP "detect various ip, icmp, tcp, and udp port or protocol scans" diff --git a/src/network_inspectors/reputation/CMakeLists.txt b/src/network_inspectors/reputation/CMakeLists.txt index c4c225f15..e4d8680a8 100644 --- a/src/network_inspectors/reputation/CMakeLists.txt +++ b/src/network_inspectors/reputation/CMakeLists.txt @@ -3,6 +3,8 @@ set (REPUTATION_INCLUDES ) add_library( reputation OBJECT + reputation_commands.cc + reputation_commands.h reputation_config.h reputation_inspect.h reputation_inspect.cc diff --git a/src/network_inspectors/reputation/reputation_commands.cc b/src/network_inspectors/reputation/reputation_commands.cc new file mode 100644 index 000000000..201e17c33 --- /dev/null +++ b/src/network_inspectors/reputation/reputation_commands.cc @@ -0,0 +1,90 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2021-2022 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// reputation_commands.cc author Ron Dempster + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "reputation_commands.h" + +#include "control/control.h" +#include "log/messages.h" +#include "main/analyzer_command.h" +#include "managers/inspector_manager.h" + +#include "reputation_common.h" +#include "reputation_inspect.h" + +using namespace snort; + +class ReputationReload : public AnalyzerCommand +{ +public: + ReputationReload(ControlConn*, Reputation&); + ~ReputationReload() override; + + bool execute(Analyzer&, void**) override; + bool need_update_reload_id() const override + { return true; } + + const char* stringify() override + { return "REPUTATION_RELOAD"; } + +protected: + Reputation& ins; + ReputationData* data; +}; + +ReputationReload::ReputationReload(ControlConn* conn, Reputation& ins) + : AnalyzerCommand(conn), ins(ins) +{ + ins.add_global_ref(); + log_message(".. reputation reloading\n"); + data = ins.load_data(); +} + +ReputationReload::~ReputationReload() +{ + ins.swap_data(data); + log_message("== Reputation reload complete\n"); + ins.rem_global_ref(); +} + +bool ReputationReload::execute(Analyzer&, void**) +{ + ins.swap_thread_data(data); + return true; +} + +static int reload(lua_State* L) +{ + ControlConn* ctrlcon = ControlConn::query_from_lua(L); + Reputation* ins = static_cast(InspectorManager::get_inspector(REPUTATION_NAME)); + if (ins) + main_broadcast_command(new ReputationReload(ctrlcon, *ins), ctrlcon); + else + AnalyzerCommand::log_message(ctrlcon, "No reputation instance configured to reload\n"); + return 0; +} + +const Command reputation_cmds[] = +{ + {"reload", reload, nullptr, "reload reputation data"}, + {nullptr, nullptr, nullptr, nullptr} +}; diff --git a/src/network_inspectors/reputation/reputation_commands.h b/src/network_inspectors/reputation/reputation_commands.h new file mode 100644 index 000000000..15a1b4b84 --- /dev/null +++ b/src/network_inspectors/reputation/reputation_commands.h @@ -0,0 +1,28 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2021-2022 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// reputation_commands.h author Ron Dempster + +#ifndef REPUTATION_COMMANDS_H +#define REPUTATION_COMMANDS_H + +#include "framework/module.h" + +extern const snort::Command reputation_cmds[]; + +#endif + diff --git a/src/network_inspectors/reputation/reputation_config.h b/src/network_inspectors/reputation/reputation_config.h index e92addbf5..99a8fcfc7 100644 --- a/src/network_inspectors/reputation/reputation_config.h +++ b/src/network_inspectors/reputation/reputation_config.h @@ -80,20 +80,13 @@ typedef std::vector ListFiles; struct ReputationConfig { uint32_t memcap = 500; - int num_entries = 0; bool scanlocal = false; IPdecision priority = TRUSTED; NestedIP nested_ip = INNER; AllowAction allow_action = DO_NOT_BLOCK; std::string blocklist_path; std::string allowlist_path; - bool memcap_reached = false; - uint8_t* reputation_segment = nullptr; - table_flat_t* ip_list = nullptr; - ListFiles list_files; std::string list_dir; - - ~ReputationConfig(); }; struct IPrepInfo diff --git a/src/network_inspectors/reputation/reputation_inspect.cc b/src/network_inspectors/reputation/reputation_inspect.cc index f2f3f54e7..d41f28b6e 100644 --- a/src/network_inspectors/reputation/reputation_inspect.cc +++ b/src/network_inspectors/reputation/reputation_inspect.cc @@ -29,11 +29,15 @@ #include "detection/detection_engine.h" #include "events/event_queue.h" #include "log/messages.h" +#include "main/snort.h" +#include "main/snort_config.h" +#include "managers/inspector_manager.h" #include "network_inspectors/packet_tracer/packet_tracer.h" #include "packet_io/active.h" #include "profiler/profiler.h" #include "protocols/packet.h" #include "pub_sub/auxiliary_ip_event.h" +#include "utils/util.h" #include "reputation_parse.h" @@ -55,52 +59,28 @@ const PegInfo reputation_peg_names[] = { CountType::END, nullptr, nullptr } }; -const char* NestedIPKeyword[] = -{ - "inner", - "outer", - "all", - nullptr -}; +#define MANIFEST_FILENAME "interface.info" -const char* AllowActionOption[] = +static inline IPrepInfo* reputation_lookup(const ReputationConfig& config, + ReputationData& data, const SfIp* ip) { - "do_not_block", - "trust", - nullptr -}; - -/* - * Function prototype(s) - */ -static void snort_reputation(ReputationConfig* GlobalConf, Packet* p); -static void populate_trace_data(IPdecision& decision, Packet* p); - -static inline IPrepInfo* reputation_lookup(ReputationConfig* config, const SfIp* ip) -{ - IPrepInfo* result; - - if (!config->scanlocal) + if (!config.scanlocal) { if (ip->is_private() ) - { return nullptr; - } } - result = (IPrepInfo*)sfrt_flat_dir8x_lookup(ip, config->ip_list); - - return (result); + return (IPrepInfo*)sfrt_flat_dir8x_lookup(ip, data.ip_list); } -static inline IPdecision get_reputation(ReputationConfig* config, IPrepInfo* rep_info, - uint32_t* listid, uint32_t ingress_intf, uint32_t egress_intf) +static inline IPdecision get_reputation(const ReputationConfig& config, ReputationData& data, + IPrepInfo* rep_info, uint32_t* listid, uint32_t ingress_intf, uint32_t egress_intf) { IPdecision decision = DECISION_NULL; /*Walk through the IPrepInfo lists*/ - uint8_t* base = (uint8_t*)config->ip_list; - ListFiles& list_info = config->list_files; + uint8_t* base = (uint8_t*)data.ip_list; + ListFiles& list_info = data.list_files; while (rep_info) { @@ -117,7 +97,7 @@ static inline IPdecision get_reputation(ReputationConfig* config, IPrepInfo* rep { if (TRUSTED_DO_NOT_BLOCK == (IPdecision)list_info[list_index]->list_type) return DECISION_NULL; - if (config->priority == (IPdecision)list_info[list_index]->list_type ) + if (config.priority == (IPdecision)list_info[list_index]->list_type ) { *listid = list_info[list_index]->list_id; return ((IPdecision)list_info[list_index]->list_type); @@ -138,18 +118,16 @@ static inline IPdecision get_reputation(ReputationConfig* config, IPrepInfo* rep return decision; } -static bool decision_per_layer(ReputationConfig* config, Packet* p, - uint32_t ingress_intf, uint32_t egress_intf, const ip::IpApi& ip_api, IPdecision* decision_final) +static bool decision_per_layer(const ReputationConfig& config, ReputationData& data, + Packet* p, uint32_t ingress_intf, uint32_t egress_intf, const ip::IpApi& ip_api, + IPdecision* decision_final) { - const SfIp* ip; - IPdecision decision; - IPrepInfo* result; - - ip = ip_api.get_src(); - result = reputation_lookup(config, ip); + const SfIp* ip = ip_api.get_src(); + IPrepInfo* result = reputation_lookup(config, data, ip); if (result) { - decision = get_reputation(config, result, &p->iplist_id, ingress_intf, egress_intf); + IPdecision decision = get_reputation(config, data, result, &p->iplist_id, ingress_intf, + egress_intf); if (decision == BLOCKED) *decision_final = BLOCKED_SRC; @@ -160,15 +138,16 @@ static bool decision_per_layer(ReputationConfig* config, Packet* p, else *decision_final = decision; - if ( config->priority == decision) + if ( config.priority == decision) return true; } ip = ip_api.get_dst(); - result = reputation_lookup(config, ip); + result = reputation_lookup(config, data, ip); if (result) { - decision = get_reputation(config, result, &p->iplist_id, ingress_intf, egress_intf); + IPdecision decision = get_reputation(config, data, result, &p->iplist_id, ingress_intf, + egress_intf); if (decision == BLOCKED) *decision_final = BLOCKED_DST; @@ -179,14 +158,15 @@ static bool decision_per_layer(ReputationConfig* config, Packet* p, else *decision_final = decision; - if ( config->priority == decision) + if ( config.priority == decision) return true; } return false; } -static IPdecision reputation_decision(ReputationConfig* config, Packet* p) +static IPdecision reputation_decision(const ReputationConfig& config, ReputationData& data, + Packet* p) { IPdecision decision_final = DECISION_NULL; uint32_t ingress_intf = 0; @@ -201,9 +181,9 @@ static IPdecision reputation_decision(ReputationConfig* config, Packet* p) egress_intf = p->pkth->egress_index; } - if (config->nested_ip == INNER) + if (config.nested_ip == INNER) { - decision_per_layer(config, p, ingress_intf, egress_intf, p->ptrs.ip_api, &decision_final); + decision_per_layer(config, data, p, ingress_intf, egress_intf, p->ptrs.ip_api, &decision_final); return decision_final; } @@ -213,19 +193,19 @@ static IPdecision reputation_decision(ReputationConfig* config, Packet* p) int8_t num_layer = 0; IpProtocol tmp_next = p->get_ip_proto_next(); - if (config->nested_ip == OUTER) + if (config.nested_ip == OUTER) { layer::set_outer_ip_api(p, p->ptrs.ip_api, p->ip_proto_next, num_layer); - decision_per_layer(config, p, ingress_intf, egress_intf, p->ptrs.ip_api, &decision_final); + decision_per_layer(config, data, p, ingress_intf, egress_intf, p->ptrs.ip_api, &decision_final); } - else if (config->nested_ip == ALL) + else if (config.nested_ip == ALL) { bool done = false; IPdecision decision_current = DECISION_NULL; while (!done and layer::set_outer_ip_api(p, p->ptrs.ip_api, p->ip_proto_next, num_layer)) { - done = decision_per_layer(config, p, ingress_intf, egress_intf, p->ptrs.ip_api, + done = decision_per_layer(config, data, p, ingress_intf, egress_intf, p->ptrs.ip_api, &decision_current); if (decision_current != DECISION_NULL) { @@ -241,18 +221,19 @@ static IPdecision reputation_decision(ReputationConfig* config, Packet* p) if (decision_final != BLOCKED_SRC and decision_final != BLOCKED_DST) p->ptrs.ip_api = tmp_api; - else if (config->nested_ip == ALL and p->ptrs.ip_api != blocked_api) + else if (config.nested_ip == ALL and p->ptrs.ip_api != blocked_api) p->ptrs.ip_api = blocked_api; p->ip_proto_next = tmp_next; return decision_final; } -static IPdecision snort_reputation_aux_ip(ReputationConfig* config, Packet* p, const SfIp* ip) +static IPdecision snort_reputation_aux_ip(const ReputationConfig& config, ReputationData& data, + Packet* p, const SfIp* ip) { IPdecision decision = DECISION_NULL; - if (!config->ip_list) + if (!data.ip_list) return decision; uint32_t ingress_intf = 0; @@ -267,10 +248,10 @@ static IPdecision snort_reputation_aux_ip(ReputationConfig* config, Packet* p, c egress_intf = p->pkth->egress_index; } - IPrepInfo* result = reputation_lookup(config, ip); + IPrepInfo* result = reputation_lookup(config, data, ip); if (result) { - decision = get_reputation(config, result, &p->iplist_id, ingress_intf, + decision = get_reputation(config, data, result, &p->iplist_id, ingress_intf, egress_intf); if (decision == BLOCKED) @@ -279,7 +260,7 @@ static IPdecision snort_reputation_aux_ip(ReputationConfig* config, Packet* p, c p->flow->flags.reputation_blocklist = true; // Prior to IPRep logging, IPS policy must be set to the default policy, - set_ips_policy(SnortConfig::get_conf(), 0); + set_ips_policy(get_default_ips_policy(SnortConfig::get_conf())); DetectionEngine::queue_event(GID_REPUTATION, REPUTATION_EVENT_BLOCKLIST_DST); p->active->drop_packet(p, true); @@ -321,14 +302,67 @@ static IPdecision snort_reputation_aux_ip(ReputationConfig* config, Packet* p, c return decision; } -static void snort_reputation(ReputationConfig* config, Packet* p) +static const char* to_string(IPdecision ipd) +{ + switch (ipd) + { + case BLOCKED: + return "blocked"; + case TRUSTED: + return "trusted"; + case MONITORED: + return "monitored"; + case BLOCKED_SRC: + return "blocked_src"; + case BLOCKED_DST: + return "blocked_dst"; + case TRUSTED_SRC: + return "trusted_src"; + case TRUSTED_DST: + return "trusted_dst"; + case TRUSTED_DO_NOT_BLOCK: + return "trusted_do_not_block"; + case MONITORED_SRC: + return "monitored_src"; + case MONITORED_DST: + return "monitored_dst"; + case DECISION_NULL: + case DECISION_MAX: + default: + return ""; + } +} + +static void populate_trace_data(IPdecision& decision, Packet* p) +{ + char addr[INET6_ADDRSTRLEN]; + const SfIp* ip = nullptr; + + if (BLOCKED_SRC == decision or MONITORED_SRC == decision or TRUSTED_SRC == decision) + { + ip = p->ptrs.ip_api.get_src(); + } + else if (BLOCKED_DST == decision or MONITORED_DST == decision or TRUSTED_DST == decision) + { + ip = p->ptrs.ip_api.get_dst(); + } + + sfip_ntop(ip, addr, sizeof(addr)); + + PacketTracer::daq_log("SI-IP+%" PRId64"+%s list id %u+Matched ip %s, action %s$", + TO_NSECS(pt_timer->get()), + (TRUSTED_SRC == decision or TRUSTED_DST == decision)?"Do_not_block":"Block", + p->iplist_id, addr, to_string(decision)); +} + +static void snort_reputation(const ReputationConfig& config, ReputationData& data, Packet* p) { IPdecision decision; - if (!config->ip_list) + if (!data.ip_list) return; - decision = reputation_decision(config, p); + decision = reputation_decision(config, data, p); Active* act = p->active; if (BLOCKED_SRC == decision or BLOCKED_DST == decision) @@ -364,7 +398,7 @@ static void snort_reputation(ReputationConfig* config, Packet* p) const auto& aux_ip_list = p->flow->stash->get_aux_ip_list(); for ( const auto& ip : aux_ip_list ) { - if ( BLOCKED == snort_reputation_aux_ip(config, p, &ip) ) + if ( BLOCKED == snort_reputation_aux_ip(config, data, p, &ip) ) return; } } @@ -437,73 +471,23 @@ static const char* to_string(AllowAction aa) return ""; } -static const char* to_string(IPdecision ipd) -{ - switch (ipd) - { - case BLOCKED: - return "blocked"; - case TRUSTED: - return "trusted"; - case MONITORED: - return "monitored"; - case BLOCKED_SRC: - return "blocked_src"; - case BLOCKED_DST: - return "blocked_dst"; - case TRUSTED_SRC: - return "trusted_src"; - case TRUSTED_DST: - return "trusted_dst"; - case TRUSTED_DO_NOT_BLOCK: - return "trusted_do_not_block"; - case MONITORED_SRC: - return "monitored_src"; - case MONITORED_DST: - return "monitored_dst"; - case DECISION_NULL: - case DECISION_MAX: - default: - return ""; - } -} - -static void populate_trace_data(IPdecision& decision, Packet* p) -{ - char addr[INET6_ADDRSTRLEN]; - const SfIp* ip = nullptr; - - if (BLOCKED_SRC == decision or MONITORED_SRC == decision or TRUSTED_SRC == decision) - { - ip = p->ptrs.ip_api.get_src(); - } - else if (BLOCKED_DST == decision or MONITORED_DST == decision or TRUSTED_DST == decision) - { - ip = p->ptrs.ip_api.get_dst(); - } - - sfip_ntop(ip, addr, sizeof(addr)); - - PacketTracer::daq_log("SI-IP+%" PRId64"+%s list id %u+Matched ip %s, action %s$", - TO_NSECS(pt_timer->get()), - (TRUSTED_SRC == decision or TRUSTED_DST == decision)?"Do_not_block":"Block", - p->iplist_id, addr, to_string(decision)); -} - class AuxiliaryIpRepHandler : public DataHandler { public: - AuxiliaryIpRepHandler(ReputationConfig& rc) : DataHandler(REPUTATION_NAME), conf(rc) { } + explicit AuxiliaryIpRepHandler(Reputation& inspector) + : DataHandler(REPUTATION_NAME), inspector(inspector) + { } void handle(DataEvent&, Flow*) override; private: - ReputationConfig& conf; + Reputation& inspector; }; void AuxiliaryIpRepHandler::handle(DataEvent& event, Flow*) { Profile profile(reputation_perf_stats); - snort_reputation_aux_ip(&conf, DetectionEngine::get_current_packet(), + snort_reputation_aux_ip(inspector.get_config(), inspector.get_data(), + DetectionEngine::get_current_packet(), static_cast(&event)->get_ip()); } @@ -511,26 +495,58 @@ void AuxiliaryIpRepHandler::handle(DataEvent& event, Flow*) // class stuff //------------------------------------------------------------------------- -Reputation::Reputation(ReputationConfig* pc) +ReputationData::~ReputationData() { - config = *pc; - ReputationConfig* conf = &config; + if (reputation_segment) + snort_free(reputation_segment); + + for (auto& file : list_files) + delete file; +} + +Reputation::Reputation(ReputationConfig* pc) : config(*pc) +{ rep_data = load_data(); } + +Reputation::~Reputation() +{ delete rep_data; } + +ReputationData* Reputation::load_data() +{ + ReputationData* data = new ReputationData(); if (!config.list_dir.empty()) - read_manifest(MANIFEST_FILENAME, conf); + read_manifest(MANIFEST_FILENAME, config, *data); - add_block_allow_List(conf); - estimate_num_entries(conf); - if (conf->num_entries <= 0) + add_block_allow_List(config, *data); + estimate_num_entries(*data); + if (0 >= data->num_entries) { ParseWarning(WARN_CONF, "reputation: can't find any allowlist/blocklist entries; disabled."); - return; } + else + { + ip_list_init(data->num_entries + 1, config, *data); + reputationstats.memory_allocated = sfrt_flat_usage(data->ip_list); + } + + return data; +} + +void Reputation::swap_thread_data(ReputationData* data) +{ set_thread_specific_data(data); } - ip_list_init(conf->num_entries + 1, conf); - reputationstats.memory_allocated = sfrt_flat_usage(conf->ip_list); +void Reputation::swap_data(ReputationData* data) +{ + delete rep_data; + rep_data = data; } +void Reputation::tinit() +{ set_thread_specific_data(rep_data); } + +void Reputation::tterm() +{ set_thread_specific_data(nullptr); } + void Reputation::show(const SnortConfig*) const { ConfigLogger::log_value("blocklist", config.blocklist_path.c_str()); @@ -556,16 +572,21 @@ void Reputation::eval(Packet* p) if (PacketTracer::is_daq_activated()) PacketTracer::pt_timer_start(); - snort_reputation(&config, p); + ReputationData* data = static_cast(get_thread_specific_data()); + assert(data); + snort_reputation(config, *data, p); ++reputationstats.packets; } bool Reputation::configure(SnortConfig*) { - DataBus::subscribe_network( AUXILIARY_IP_EVENT, new AuxiliaryIpRepHandler(config) ); + DataBus::subscribe_network( AUXILIARY_IP_EVENT, new AuxiliaryIpRepHandler(*this) ); return true; } +void Reputation::install_reload_handler(SnortConfig* sc) +{ sc->register_reload_handler(new ReputationReloadSwapper(*this)); } + //------------------------------------------------------------------------- // api stuff //------------------------------------------------------------------------- diff --git a/src/network_inspectors/reputation/reputation_inspect.h b/src/network_inspectors/reputation/reputation_inspect.h index 715523905..9ba573abc 100644 --- a/src/network_inspectors/reputation/reputation_inspect.h +++ b/src/network_inspectors/reputation/reputation_inspect.h @@ -19,21 +19,49 @@ #ifndef REPUTATION_INSPECT_H #define REPUTATION_INSPECT_H -#include "flow/flow.h" +#include "framework/inspector.h" #include "reputation_module.h" +class ReputationData +{ +public: + ReputationData() = default; + ~ReputationData(); + + ListFiles list_files; + uint8_t* reputation_segment = nullptr; + table_flat_t* ip_list = nullptr; + int num_entries = 0; + bool memcap_reached = false; +}; + class Reputation : public snort::Inspector { public: - Reputation(ReputationConfig*); + explicit Reputation(ReputationConfig*); + ~Reputation() override; + + void tinit() override; + void tterm() override; void show(const snort::SnortConfig*) const override; void eval(snort::Packet*) override; bool configure(snort::SnortConfig*) override; + void install_reload_handler(snort::SnortConfig*) override; + + ReputationData& get_data() + { return *rep_data; } + const ReputationConfig& get_config() + { return config; } + ReputationData* load_data(); + + void swap_thread_data(ReputationData*); + void swap_data(ReputationData*); private: ReputationConfig config; + ReputationData* rep_data; }; #endif diff --git a/src/network_inspectors/reputation/reputation_module.cc b/src/network_inspectors/reputation/reputation_module.cc index 5bcd6541c..e130cc3af 100644 --- a/src/network_inspectors/reputation/reputation_module.cc +++ b/src/network_inspectors/reputation/reputation_module.cc @@ -29,6 +29,8 @@ #include "log/messages.h" #include "utils/util.h" +#include "reputation_commands.h" +#include "reputation_inspect.h" #include "reputation_parse.h" using namespace snort; @@ -109,6 +111,9 @@ ReputationModule::~ReputationModule() const RuleMap* ReputationModule::get_rules() const { return reputation_rules; } +const Command* ReputationModule::get_commands() const +{ return reputation_cmds; } + const PegInfo* ReputationModule::get_pegs() const { return reputation_peg_names; } @@ -175,3 +180,7 @@ bool ReputationModule::end(const char*, int, SnortConfig*) return true; } + +void ReputationReloadSwapper::tswap() +{ inspector.set_thread_specific_data(&inspector.get_data()); } + diff --git a/src/network_inspectors/reputation/reputation_module.h b/src/network_inspectors/reputation/reputation_module.h index 5b221d3a8..6f8a53961 100644 --- a/src/network_inspectors/reputation/reputation_module.h +++ b/src/network_inspectors/reputation/reputation_module.h @@ -24,6 +24,8 @@ // Interface to the REPUTATION network inspector #include "framework/module.h" +#include "main/reload_tuner.h" + #include "reputation_config.h" #include "reputation_common.h" @@ -36,6 +38,21 @@ extern THREAD_LOCAL snort::ProfileStats reputation_perf_stats; extern unsigned long total_duplicates; extern unsigned long total_invalids; +class Reputation; + +class ReputationReloadSwapper : public snort::ReloadSwapper +{ +public: + explicit ReputationReloadSwapper(Reputation& ins) : inspector(ins) + { } + ~ReputationReloadSwapper() override = default; + + void tswap() override; + +private: + Reputation& inspector; +}; + class ReputationModule : public snort::Module { public: @@ -50,6 +67,7 @@ public: { return GID_REPUTATION; } const snort::RuleMap* get_rules() const override; + const snort::Command* get_commands() const override; const PegInfo* get_pegs() const override; PegCount* get_counts() const override; snort::ProfileStats* get_profile() const override; diff --git a/src/network_inspectors/reputation/reputation_parse.cc b/src/network_inspectors/reputation/reputation_parse.cc index d6e038942..d1df88681 100644 --- a/src/network_inspectors/reputation/reputation_parse.cc +++ b/src/network_inspectors/reputation/reputation_parse.cc @@ -36,6 +36,9 @@ #include "utils/util.h" #include "utils/util_cstring.h" +#include "reputation_config.h" +#include "reputation_inspect.h" + using namespace snort; using namespace std; @@ -76,19 +79,6 @@ unsigned long total_invalids; int totalNumEntries = 0; -static void load_list_file(ListFile*, ReputationConfig* config); - -ReputationConfig::~ReputationConfig() -{ - if (reputation_segment != nullptr) - snort_free(reputation_segment); - - for (auto& file : list_files) - { - delete file; - } -} - static uint32_t estimate_size(uint32_t num_entries, uint32_t memcap) { uint64_t size; @@ -114,49 +104,6 @@ static uint32_t estimate_size(uint32_t num_entries, uint32_t memcap) return (uint32_t)size; } -void ip_list_init(uint32_t max_entries, ReputationConfig* config) -{ - if ( !config->ip_list ) - { - uint32_t mem_size; - mem_size = estimate_size(max_entries, config->memcap); - config->reputation_segment = (uint8_t*)snort_alloc(mem_size); - - segment_meminit(config->reputation_segment, mem_size); - - /*DIR_16x7_4x4 for performance, but memory usage is high - *Use DIR_8x16 worst case IPV4 5K, IPV6 15K (bytes) - *Use DIR_16x7_4x4 worst case IPV4 500, IPV6 2.5M - */ - config->ip_list = sfrt_flat_new(DIR_8x16, IPv6, max_entries, config->memcap); - - if ( !config->ip_list ) - { - ErrorMessage("Failed to create IP list.\n"); - return; - } - - total_duplicates = 0; - for (size_t i = 0; i < config->list_files.size(); i++) - { - config->list_files[i]->list_index = (uint8_t)i + 1; - if (config->list_files[i]->file_type == ALLOW_LIST) - { - if (config->allow_action == DO_NOT_BLOCK) - config->list_files[i]->list_type = TRUSTED_DO_NOT_BLOCK; - else - config->list_files[i]->list_type = TRUSTED; - } - else if (config->list_files[i]->file_type == BLOCK_LIST) - config->list_files[i]->list_type = BLOCKED; - else if (config->list_files[i]->file_type == MONITOR_LIST) - config->list_files[i]->list_type = MONITORED; - - load_list_file(config->list_files[i], config); - } - } -} - static inline IPrepInfo* get_last_index(IPrepInfo* rep_info, uint8_t* base, int* last_index) { int i; @@ -311,51 +258,39 @@ static int64_t update_entry_info(INFO* current, INFO new_entry, SaveDest save_de return bytes_allocated; } -static int add_ip(SfCidr* ip_addr,INFO info_ptr, ReputationConfig* config) +static int add_ip(SfCidr* ip_addr,INFO info_ptr, const ReputationConfig& config, + ReputationData& data) { - int ret; - int final_ret = IP_INSERT_SUCCESS; /*This variable is used to check whether a more generic address * overrides specific address */ uint32_t usage_before; uint32_t usage_after; - usage_before = sfrt_flat_usage(config->ip_list); + usage_before = sfrt_flat_usage(data.ip_list); + int final_ret = IP_INSERT_SUCCESS; /*Check whether the same or more generic address is already in the table*/ - if (nullptr != sfrt_flat_lookup(ip_addr->get_addr(), config->ip_list)) - { + if (nullptr != sfrt_flat_lookup(ip_addr->get_addr(), data.ip_list)) final_ret = IP_INSERT_DUPLICATE; - } - ret = sfrt_flat_insert(ip_addr, (unsigned char)ip_addr->get_bits(), info_ptr, RT_FAVOR_ALL, - config->ip_list, &update_entry_info); + int ret = sfrt_flat_insert(ip_addr, (unsigned char)ip_addr->get_bits(), info_ptr, RT_FAVOR_ALL, + data.ip_list, &update_entry_info); if (RT_SUCCESS == ret) - { totalNumEntries++; - } else if (MEM_ALLOC_FAILURE == ret) - { final_ret = IP_MEM_ALLOC_FAILURE; - } else - { final_ret = IP_INSERT_FAILURE; - } - usage_after = sfrt_flat_usage(config->ip_list); + usage_after = sfrt_flat_usage(data.ip_list); /*Compare in the same scale*/ - if (usage_after > (config->memcap << 20)) - { + if (usage_after > (config.memcap << 20)) final_ret = IP_MEM_ALLOC_FAILURE; - } /*Check whether there a more specific address will be overridden*/ if (usage_before > usage_after ) - { final_ret = IP_INSERT_DUPLICATE; - } return final_ret; } @@ -499,7 +434,8 @@ static int snort_pton(char const* src, SfCidr* dest) return 1; } -static int process_line(char* line, INFO info, ReputationConfig* config) +static int process_line(char* line, INFO info, const ReputationConfig& config, + ReputationData& data) { SfCidr address; @@ -509,7 +445,7 @@ static int process_line(char* line, INFO info, ReputationConfig* config) if ( snort_pton(line, &address) < 1 ) return IP_INVALID; - return add_ip(&address, info, config); + return add_ip(&address, info, config, data); } static int update_path_to_file(char* full_filename, unsigned int max_size, const char* filename) @@ -571,7 +507,8 @@ static char* get_list_type_name(ListFile* list_info) } } -static void load_list_file(ListFile* list_info, ReputationConfig* config) +static void load_list_file(ListFile* list_info, const ReputationConfig& config, + ReputationData& data) { char linebuf[MAX_ADDR_LINE_LENGTH]; char full_path_filename[PATH_MAX+1]; @@ -589,7 +526,7 @@ static void load_list_file(ListFile* list_info, ReputationConfig* config) unsigned int fail_count = 0; /*number of invalid entries in this file*/ unsigned int num_loaded_before = 0; /*number of valid entries loaded */ - if (config->memcap_reached) + if (data.memcap_reached) return; update_path_to_file(full_path_filename, PATH_MAX, list_info->file_name.c_str()); @@ -602,10 +539,8 @@ static void load_list_file(ListFile* list_info, ReputationConfig* config) /*convert list info to ip entry info*/ ip_info_ptr = segment_snort_calloc(1,sizeof(IPrepInfo)); if (!(ip_info_ptr)) - { return; - } - base = (uint8_t*)config->ip_list; + base = (uint8_t*)data.ip_list; ip_info = ((IPrepInfo*)&base[ip_info_ptr]); ip_info->list_indexes[0] = list_info->list_index; @@ -618,7 +553,7 @@ static void load_list_file(ListFile* list_info, ReputationConfig* config) return; } - num_loaded_before = sfrt_flat_num_entries(config->ip_list); + num_loaded_before = sfrt_flat_num_entries(data.ip_list); while ( fgets(linebuf, MAX_ADDR_LINE_LENGTH, fp) ) { int ret; @@ -633,7 +568,7 @@ static void load_list_file(ListFile* list_info, ReputationConfig* config) *cmt = '\0'; /* process the line */ - ret = process_line(linebuf, ip_info_ptr, config); + ret = process_line(linebuf, ip_info_ptr, config, data); if (IP_INSERT_SUCCESS == ret) { @@ -655,9 +590,9 @@ static void load_list_file(ListFile* list_info, ReputationConfig* config) { ErrorMessage( "WARNING: %s(%d) => Memcap %u Mbytes reached when inserting IP Address: %s\n", - full_path_filename, addrline, config->memcap,linebuf); + full_path_filename, addrline, config.memcap,linebuf); - config->memcap_reached = true; + data.memcap_reached = true; break; } } @@ -673,12 +608,55 @@ static void load_list_file(ListFile* list_info, ReputationConfig* config) ErrorMessage(" Additional duplicate addresses were not listed.\n"); LogMessage(" Reputation entries loaded: %u, invalid: %u, re-defined: %u (from file %s)\n", - sfrt_flat_num_entries(config->ip_list) - num_loaded_before, + sfrt_flat_num_entries(data.ip_list) - num_loaded_before, invalid_count, duplicate_count, full_path_filename); fclose(fp); } +void ip_list_init(uint32_t max_entries, const ReputationConfig& config, ReputationData& data) +{ + if ( !data.ip_list ) + { + uint32_t mem_size; + mem_size = estimate_size(max_entries, config.memcap); + data.reputation_segment = (uint8_t*)snort_alloc(mem_size); + + segment_meminit(data.reputation_segment, mem_size); + + /*DIR_16x7_4x4 for performance, but memory usage is high + *Use DIR_8x16 worst case IPV4 5K, IPV6 15K (bytes) + *Use DIR_16x7_4x4 worst case IPV4 500, IPV6 2.5M + */ + data.ip_list = sfrt_flat_new(DIR_8x16, IPv6, max_entries, config.memcap); + + if ( !data.ip_list ) + { + ErrorMessage("Failed to create IP list.\n"); + return; + } + + total_duplicates = 0; + for (size_t i = 0; i < data.list_files.size(); i++) + { + data.list_files[i]->list_index = (uint8_t)i + 1; + if (data.list_files[i]->file_type == ALLOW_LIST) + { + if (config.allow_action == DO_NOT_BLOCK) + data.list_files[i]->list_type = TRUSTED_DO_NOT_BLOCK; + else + data.list_files[i]->list_type = TRUSTED; + } + else if (data.list_files[i]->file_type == BLOCK_LIST) + data.list_files[i]->list_type = BLOCKED; + else if (data.list_files[i]->file_type == MONITOR_LIST) + data.list_files[i]->list_type = MONITORED; + + load_list_file(data.list_files[i], config, data); + } + } +} + static int num_lines_in_file(char* fname) { FILE* fp; @@ -735,37 +713,33 @@ static int load_file(int total_lines, const char* path) return num_lines; } -void estimate_num_entries(ReputationConfig* config) +void estimate_num_entries(ReputationData& data) { - int total_lines = 0; + data.num_entries = 0; - for (auto& file : config->list_files) - { - total_lines += load_file(total_lines, file->file_name.c_str()); - } - - config->num_entries = total_lines; + for (auto& file : data.list_files) + data.num_entries += load_file(data.num_entries, file->file_name.c_str()); } -void add_block_allow_List(ReputationConfig* config) +void add_block_allow_List(const ReputationConfig& config, ReputationData& data) { - if (config->blocklist_path.size()) + if (config.blocklist_path.size()) { ListFile* listItem = new ListFile; listItem->all_intfs_enabled = true; - listItem->file_name = config->blocklist_path; + listItem->file_name = config.blocklist_path; listItem->file_type = BLOCK_LIST; listItem->list_id = 0; - config->list_files.emplace_back(listItem); + data.list_files.emplace_back(listItem); } - if (config->allowlist_path.size()) + if (config.allowlist_path.size()) { ListFile* listItem = new ListFile; listItem->all_intfs_enabled = true; - listItem->file_name = config->allowlist_path; + listItem->file_name = config.allowlist_path; listItem->file_type = ALLOW_LIST; listItem->list_id = 0; - config->list_files.emplace_back(listItem); + data.list_files.emplace_back(listItem); } } @@ -828,7 +802,7 @@ static int get_file_type(char* type_name) //If no interface information provided, this means all interfaces are applied. static bool process_line_in_manifest(ListFile* list_item, const char* manifest, const char* line, - int line_number, ReputationConfig* config) + int line_number, const ReputationConfig& config, ReputationData& data) { char* token; int token_index = 0; @@ -846,7 +820,7 @@ static bool process_line_in_manifest(ListFile* list_item, const char* manifest, switch (token_index) { case 0: // File name - list_item->file_name = config->list_dir + '/' + token; + list_item->file_name = config.list_dir + '/' + token; break; case 1: // List ID @@ -925,17 +899,14 @@ static bool process_line_in_manifest(ListFile* list_item, const char* manifest, list_item->all_intfs_enabled = true; } - config->list_files.emplace_back(list_item); + data.list_files.emplace_back(list_item); return true; } -int read_manifest(const char* manifest_file, ReputationConfig* config) +void read_manifest(const char* manifest_file, const ReputationConfig& config, ReputationData& data) { - int line_number = 0; - std::string line; char full_path_dir[PATH_MAX+1]; - - update_path_to_file(full_path_dir, PATH_MAX, config->list_dir.c_str()); + update_path_to_file(full_path_dir, PATH_MAX, config.list_dir.c_str()); std::string manifest_full_path = std::string(full_path_dir) + '/' + manifest_file; std::fstream fs; @@ -944,9 +915,11 @@ int read_manifest(const char* manifest_file, ReputationConfig* config) if (!fs.good()) { ErrorMessage("Can't open file: %s\n", manifest_full_path.c_str()); - return -1; + return; } + int line_number = 0; + std::string line; while (std::getline(fs, line)) { line_number++; @@ -958,12 +931,11 @@ int read_manifest(const char* manifest_file, ReputationConfig* config) //Processing the line ListFile* list_item = new ListFile; - if (!process_line_in_manifest(list_item, manifest_file, line.c_str(), line_number, config)) + if (!process_line_in_manifest( + list_item, manifest_file, line.c_str(), line_number, config, data)) delete list_item; } fs.close(); - - return 0; } diff --git a/src/network_inspectors/reputation/reputation_parse.h b/src/network_inspectors/reputation/reputation_parse.h index 6c354a416..33ba33443 100644 --- a/src/network_inspectors/reputation/reputation_parse.h +++ b/src/network_inspectors/reputation/reputation_parse.h @@ -20,13 +20,14 @@ #ifndef REPUTATION_PARSE_H #define REPUTATION_PARSE_H -#include "reputation_config.h" +#include -#define MANIFEST_FILENAME "interface.info" +struct ReputationConfig; +class ReputationData; -void ip_list_init(uint32_t,ReputationConfig *config); -void estimate_num_entries(ReputationConfig* config); -int read_manifest(const char* filename, ReputationConfig* config); -void add_block_allow_List(ReputationConfig* config); +void ip_list_init(uint32_t max_entries, const ReputationConfig&, ReputationData&); +void estimate_num_entries(ReputationData&); +void read_manifest(const char* filename, const ReputationConfig&, ReputationData&); +void add_block_allow_List(const ReputationConfig&, ReputationData&); #endif diff --git a/src/network_inspectors/rna/rna_inspector.cc b/src/network_inspectors/rna/rna_inspector.cc index ce8c154bd..08ea9977e 100644 --- a/src/network_inspectors/rna/rna_inspector.cc +++ b/src/network_inspectors/rna/rna_inspector.cc @@ -86,7 +86,7 @@ RnaInspector::~RnaInspector() } } -bool RnaInspector::configure(SnortConfig* sc) +bool RnaInspector::configure(SnortConfig*) { DataBus::subscribe_network( APPID_EVENT_ANY_CHANGE, new RnaAppidEventHandler(*pnd) ); DataBus::subscribe_network( DHCP_INFO_EVENT, new RnaDHCPInfoEventHandler(*pnd) ); @@ -110,13 +110,12 @@ bool RnaInspector::configure(SnortConfig* sc) if (rna_conf && rna_conf->log_when_idle) DataBus::subscribe_network( THREAD_IDLE_EVENT, new RnaIdleEventHandler(*pnd) ); - // tinit is not called during reload, so pass processor pointers to threads via reload tuner - if ( Snort::is_reloading() && InspectorManager::get_inspector(RNA_NAME, true) ) - sc->register_reload_resource_tuner(new FpProcReloadTuner(*mod_conf)); - return true; } +void RnaInspector::install_reload_handler(SnortConfig* sc) +{ sc->register_reload_handler(new FpProcReloadTuner(*mod_conf)); } + void RnaInspector::eval(Packet* p) { Profile profile(rna_perf_stats); diff --git a/src/network_inspectors/rna/rna_inspector.h b/src/network_inspectors/rna/rna_inspector.h index a3a8e45c3..c5774a796 100644 --- a/src/network_inspectors/rna/rna_inspector.h +++ b/src/network_inspectors/rna/rna_inspector.h @@ -46,6 +46,7 @@ public: ~RnaInspector() override; bool configure(snort::SnortConfig*) override; + void install_reload_handler(snort::SnortConfig*) override; void eval(snort::Packet*) override; void show(const snort::SnortConfig*) const override; void tinit() override; diff --git a/src/network_inspectors/rna/rna_module.h b/src/network_inspectors/rna/rna_module.h index 21d8bb2c3..d5c98d80e 100644 --- a/src/network_inspectors/rna/rna_module.h +++ b/src/network_inspectors/rna/rna_module.h @@ -22,7 +22,7 @@ #define RNA_MODULE_H #include "framework/module.h" -#include "main/snort_config.h" +#include "main/reload_tuner.h" #include "main/snort_debug.h" #include "profiler/profiler.h" @@ -31,6 +31,11 @@ #include "rna_mac_cache.h" #include "rna_name.h" +namespace snort +{ +struct SnortConfig; +} + struct RnaStats { PegCount appid_change; diff --git a/src/network_inspectors/rna/test/rna_module_stubs.h b/src/network_inspectors/rna/test/rna_module_stubs.h index a032bd2e7..f8dd9f8fa 100644 --- a/src/network_inspectors/rna/test/rna_module_stubs.h +++ b/src/network_inspectors/rna/test/rna_module_stubs.h @@ -38,6 +38,8 @@ void Module::show_interval_stats(std::vector policy_selections; std::unordered_map policy_map; }; @@ -113,18 +115,18 @@ void AddressSpaceSelector::show() const } } -bool AddressSpaceSelector::select_default_policies(const _daq_pkt_hdr* pkthdr, const SnortConfig* sc) +bool AddressSpaceSelector::select_default_policies(uint32_t key, const SnortConfig* sc) { Profile profile(address_space_selectPerfStats); address_space_select_stats.packets++; - auto i = policy_map.find(static_cast(pkthdr->address_space_id)); + auto i = policy_map.find(key); if (i != policy_map.end()) { auto use = (*i).second; - set_network_policy(sc, use->network_index); - set_inspection_policy(sc, use->inspection_index); + set_network_policy(use->network_index); + set_inspection_policy(use->inspection_index); set_ips_policy(sc, use->ips_index); return true; } @@ -132,6 +134,18 @@ bool AddressSpaceSelector::select_default_policies(const _daq_pkt_hdr* pkthdr, c return false; } +bool AddressSpaceSelector::select_default_policies(const _daq_pkt_hdr& pkthdr, + const SnortConfig* sc) +{ + return select_default_policies(static_cast(pkthdr.address_space_id), sc); +} + +bool AddressSpaceSelector::select_default_policies(const _daq_flow_stats& stats, + const SnortConfig* sc) +{ + return select_default_policies(static_cast(stats.address_space_id), sc); +} + //------------------------------------------------------------------------- // api stuff //------------------------------------------------------------------------- diff --git a/src/policy_selectors/address_space_selector/address_space_selector_module.cc b/src/policy_selectors/address_space_selector/address_space_selector_module.cc index 577f5ba0c..6ea960f7e 100644 --- a/src/policy_selectors/address_space_selector/address_space_selector_module.cc +++ b/src/policy_selectors/address_space_selector/address_space_selector_module.cc @@ -125,7 +125,7 @@ bool AddressSpaceSelectorModule::end(const char* fqn, int idx, SnortConfig* sc) } Shell* sh = new Shell(policy_filename.c_str()); - auto policies = sc->policy_map->add_shell(sh, true); + auto policies = sc->policy_map->add_shell(sh, nullptr); selection.use.network_index = policies->network->policy_id; selection.use.inspection_index = policies->inspection->policy_id; selection.use.ips_index = policies->ips->policy_id; diff --git a/src/policy_selectors/tenant_selector/tenant_selector.cc b/src/policy_selectors/tenant_selector/tenant_selector.cc index 2dba08c4b..ce3646a71 100644 --- a/src/policy_selectors/tenant_selector/tenant_selector.cc +++ b/src/policy_selectors/tenant_selector/tenant_selector.cc @@ -67,9 +67,11 @@ public: void show() const override; - bool select_default_policies(const _daq_pkt_hdr*, const SnortConfig*) override; + bool select_default_policies(const _daq_pkt_hdr&, const SnortConfig*) override; + bool select_default_policies(const _daq_flow_stats&, const SnortConfig*) override; protected: + bool select_default_policies(uint32_t key, const SnortConfig*); std::vector policy_selections; std::unordered_map policy_map; }; @@ -113,19 +115,18 @@ void TenantSelector::show() const } } -bool TenantSelector::select_default_policies(const _daq_pkt_hdr* pkthdr, const SnortConfig* sc) +bool TenantSelector::select_default_policies(uint32_t key, const SnortConfig* sc) { Profile profile(tenant_select_perf_stats); tenant_select_stats.packets++; - // FIXIT-H replace address_space_id with tenant_id when it is added to the pkthdr - auto i = policy_map.find(static_cast(pkthdr->address_space_id)); + auto i = policy_map.find(key); if (i != policy_map.end()) { auto use = (*i).second; - set_network_policy(sc, use->network_index); - set_inspection_policy(sc, use->inspection_index); + set_network_policy(use->network_index); + set_inspection_policy(use->inspection_index); set_ips_policy(sc, use->ips_index); return true; } @@ -133,6 +134,18 @@ bool TenantSelector::select_default_policies(const _daq_pkt_hdr* pkthdr, const S return false; } +bool TenantSelector::select_default_policies(const _daq_pkt_hdr& pkthdr, const SnortConfig* sc) +{ + // FIXIT-H replace address_space_id with tenant_id when it is added to the pkthdr + return select_default_policies(static_cast(pkthdr.address_space_id), sc); +} + +bool TenantSelector::select_default_policies(const _daq_flow_stats& stats, const SnortConfig* sc) +{ + // FIXIT-H replace address_space_id with tenant_id when it is added to the pkthdr + return select_default_policies(static_cast(stats.address_space_id), sc); +} + //------------------------------------------------------------------------- // api stuff //------------------------------------------------------------------------- diff --git a/src/policy_selectors/tenant_selector/tenant_selector_module.cc b/src/policy_selectors/tenant_selector/tenant_selector_module.cc index 2f6ea877c..bf73623fa 100644 --- a/src/policy_selectors/tenant_selector/tenant_selector_module.cc +++ b/src/policy_selectors/tenant_selector/tenant_selector_module.cc @@ -110,7 +110,7 @@ bool TenantSelectorModule::end(const char* fqn, int idx, SnortConfig* sc) } Shell* sh = new Shell(policy_filename.c_str()); - auto policies = sc->policy_map->add_shell(sh, true); + auto policies = sc->policy_map->add_shell(sh, nullptr); selection.use.network_index = policies->network->policy_id; selection.use.inspection_index = policies->inspection->policy_id; selection.use.ips_index = policies->ips->policy_id; diff --git a/src/pub_sub/opportunistic_tls_event.h b/src/pub_sub/opportunistic_tls_event.h index ecaa39eed..9aafeebd3 100644 --- a/src/pub_sub/opportunistic_tls_event.h +++ b/src/pub_sub/opportunistic_tls_event.h @@ -20,6 +20,9 @@ #ifndef OPPORTUNISTIC_TLS_EVENT_H #define OPPORTUNISTIC_TLS_EVENT_H +#include +#include + #include "framework/data_bus.h" // An opportunistic SSL/TLS session will start from next packet @@ -31,18 +34,18 @@ namespace snort class SO_PUBLIC OpportunisticTlsEvent : public snort::DataEvent { public: - OpportunisticTlsEvent(const snort::Packet* p, const char* service) : + OpportunisticTlsEvent(const snort::Packet* p, std::shared_ptr service) : pkt(p), next_service(service) { } const snort::Packet* get_packet() override { return pkt; } - const char* get_next_service() + std::shared_ptr get_next_service() { return next_service; } private: const snort::Packet* pkt; - const char* next_service = nullptr; + std::shared_ptr next_service; }; } diff --git a/src/search_engines/test/hyperscan_test.cc b/src/search_engines/test/hyperscan_test.cc index fdee055b2..6ef6332cf 100644 --- a/src/search_engines/test/hyperscan_test.cc +++ b/src/search_engines/test/hyperscan_test.cc @@ -101,6 +101,9 @@ THREAD_LOCAL SnortConfig* snort_conf = &s_conf; static std::vector s_state; static ScratchAllocator* scratcher = nullptr; +DataBus::DataBus() = default; +DataBus::~DataBus() = default; + SnortConfig::SnortConfig(const SnortConfig* const, const char*) { state = &s_state; diff --git a/src/search_engines/test/search_tool_test.cc b/src/search_engines/test/search_tool_test.cc index 2932653d1..c19bbece0 100644 --- a/src/search_engines/test/search_tool_test.cc +++ b/src/search_engines/test/search_tool_test.cc @@ -54,6 +54,9 @@ THREAD_LOCAL SnortConfig* snort_conf = &s_conf; static std::vector s_state; +DataBus::DataBus() = default; +DataBus::~DataBus() = default; + SnortConfig::SnortConfig(const SnortConfig* const, const char*) { state = &s_state; diff --git a/src/service_inspectors/dce_rpc/dce_common.cc b/src/service_inspectors/dce_rpc/dce_common.cc index cee6e3be0..1f520fab2 100644 --- a/src/service_inspectors/dce_rpc/dce_common.cc +++ b/src/service_inspectors/dce_rpc/dce_common.cc @@ -43,6 +43,9 @@ using namespace snort; THREAD_LOCAL int dce2_detected = 0; static THREAD_LOCAL bool using_rpkt = false; +std::shared_ptr dce_rpc_service_name = + std::make_shared(DCE_RPC_SERVICE_NAME); + static const char* dce2_get_policy_name(DCE2_Policy policy) { const char* policyStr = nullptr; diff --git a/src/service_inspectors/dce_rpc/dce_common.h b/src/service_inspectors/dce_rpc/dce_common.h index 84d608330..18396850c 100644 --- a/src/service_inspectors/dce_rpc/dce_common.h +++ b/src/service_inspectors/dce_rpc/dce_common.h @@ -22,6 +22,8 @@ #define DCE_COMMON_H #include +#include +#include #include "detection/detection_engine.h" #include "framework/counts.h" @@ -41,6 +43,7 @@ extern THREAD_LOCAL int dce2_detected; #define GID_DCE2 133 #define DCE_RPC_SERVICE_NAME "dcerpc" +extern std::shared_ptr dce_rpc_service_name; enum DCE2_Policy { diff --git a/src/service_inspectors/dce_rpc/dce_http_proxy.cc b/src/service_inspectors/dce_rpc/dce_http_proxy.cc index 2cdaad9e4..7da11138e 100644 --- a/src/service_inspectors/dce_rpc/dce_http_proxy.cc +++ b/src/service_inspectors/dce_rpc/dce_http_proxy.cc @@ -68,7 +68,7 @@ void DceHttpProxy::clear(Packet* p) if ( c2s_splitter->cutover_inspector() && s2c_splitter->cutover_inspector() ) { dce_http_proxy_stats.http_proxy_sessions++; - flow->set_service(p, DCE_RPC_SERVICE_NAME); + flow->set_service(p, dce_rpc_service_name); } else dce_http_proxy_stats.http_proxy_session_failures++; diff --git a/src/service_inspectors/dce_rpc/dce_http_server.cc b/src/service_inspectors/dce_rpc/dce_http_server.cc index acf0126e7..557b035da 100644 --- a/src/service_inspectors/dce_rpc/dce_http_server.cc +++ b/src/service_inspectors/dce_rpc/dce_http_server.cc @@ -64,7 +64,7 @@ void DceHttpServer::clear(Packet* p) if ( splitter->cutover_inspector()) { dce_http_server_stats.http_server_sessions++; - flow->set_service(p, DCE_RPC_SERVICE_NAME); + flow->set_service(p, dce_rpc_service_name); } else dce_http_server_stats.http_server_session_failures++; diff --git a/src/service_inspectors/ftp_telnet/ftp_data.cc b/src/service_inspectors/ftp_telnet/ftp_data.cc index 10f194d4a..435d79507 100644 --- a/src/service_inspectors/ftp_telnet/ftp_data.cc +++ b/src/service_inspectors/ftp_telnet/ftp_data.cc @@ -23,6 +23,9 @@ #include "ftp_data.h" +#include +#include + #include "detection/detection_engine.h" #include "file_api/file_flows.h" #include "file_api/file_service.h" @@ -46,6 +49,8 @@ using namespace snort; "FTP data channel handler" static const char* const fd_svc_name = "ftp-data"; +static std::shared_ptr shared_fd_svc_name = + std::make_shared(fd_svc_name); static THREAD_LOCAL ProfileStats ftpdataPerfStats; static THREAD_LOCAL SimpleStats fdstats; @@ -223,15 +228,15 @@ FtpDataFlowData::~FtpDataFlowData() void FtpDataFlowData::handle_expected(Packet* p) { - if (!p->flow->service) + if (!p->flow->has_service()) { - p->flow->set_service(p, fd_svc_name); + p->flow->set_service(p, shared_fd_svc_name); FtpDataFlowData* fd = (FtpDataFlowData*)p->flow->get_flow_data(FtpDataFlowData::inspector_id); if (fd and fd->in_tls) { - OpportunisticTlsEvent evt(p, fd_svc_name); + OpportunisticTlsEvent evt(p, shared_fd_svc_name); DataBus::publish(OPPORTUNISTIC_TLS_EVENT, evt, p->flow); } else diff --git a/src/service_inspectors/http_inspect/http_inspect.cc b/src/service_inspectors/http_inspect/http_inspect.cc index 51c78e091..b408e2719 100755 --- a/src/service_inspectors/http_inspect/http_inspect.cc +++ b/src/service_inspectors/http_inspect/http_inspect.cc @@ -688,7 +688,7 @@ void HttpInspect::clear(Packet* p) if (session_data->cutover_on_clear) { Flow* flow = p->flow; - flow->set_service(p, nullptr); + flow->clear_service(p); flow->free_flow_data(HttpFlowData::inspector_id); } } diff --git a/src/service_inspectors/pop/pop.cc b/src/service_inspectors/pop/pop.cc index efea0bf91..0f4ff54bb 100644 --- a/src/service_inspectors/pop/pop.cc +++ b/src/service_inspectors/pop/pop.cc @@ -725,7 +725,7 @@ bool Pop::get_buf(InspectionBuffer::Type ibt, Packet* p, InspectionBuffer& b) return false; const BufferData& vba_buf = pop_ssn->mime_ssn->get_vba_inspect_buf(); - + if (vba_buf.data_ptr() && vba_buf.length()) { b.data = vba_buf.data_ptr(); diff --git a/src/service_inspectors/ssl/ssl_inspector.cc b/src/service_inspectors/ssl/ssl_inspector.cc index 1772a8dd1..e4701db51 100644 --- a/src/service_inspectors/ssl/ssl_inspector.cc +++ b/src/service_inspectors/ssl/ssl_inspector.cc @@ -25,6 +25,9 @@ #include "ssl_inspector.h" +#include +#include + #include "detection/detect.h" #include "detection/detection_engine.h" #include "events/event_queue.h" @@ -400,6 +403,7 @@ static void snort_ssl(SSL_PROTO_CONF* config, Packet* p) // class stuff //------------------------------------------------------------------------- static const char* s_name = "ssl"; +static std::shared_ptr shared_s_name = std::make_shared(s_name); class Ssl : public Inspector { @@ -448,7 +452,7 @@ public: pkt->flow->flags.trigger_finalize_event = fd->finalize_info.orig_flag; fd->finalize_info.switch_in = false; pkt->flow->set_proxied(); - pkt->flow->set_service(const_cast(pkt), s_name); + pkt->flow->set_service(const_cast(pkt), shared_s_name); } } }; diff --git a/src/service_inspectors/wizard/curses.cc b/src/service_inspectors/wizard/curses.cc index c75696623..968934759 100644 --- a/src/service_inspectors/wizard/curses.cc +++ b/src/service_inspectors/wizard/curses.cc @@ -393,11 +393,11 @@ static bool ssl_v2_curse(const uint8_t* data, unsigned len, CurseTracker* tracke // map between service and curse details static vector curse_map { - // name service alg is_tcp - { "dce_udp", "dcerpc", dce_udp_curse, false }, - { "dce_tcp", "dcerpc", dce_tcp_curse, true }, - { "dce_smb", "netbios-ssn", dce_smb_curse, true }, - { "sslv2" , "ssl", ssl_v2_curse , true } + // name service alg is_tcp + { "dce_udp", make_shared("dcerpc") , dce_udp_curse, false }, + { "dce_tcp", make_shared("dcerpc") , dce_tcp_curse, true }, + { "dce_smb", make_shared("netbios-ssn"), dce_smb_curse, true }, + { "sslv2" , make_shared("ssl") , ssl_v2_curse , true } }; bool CurseBook::add_curse(const char* key) diff --git a/src/service_inspectors/wizard/curses.h b/src/service_inspectors/wizard/curses.h index 6a23903bd..8b2d9b682 100644 --- a/src/service_inspectors/wizard/curses.h +++ b/src/service_inspectors/wizard/curses.h @@ -21,6 +21,7 @@ #define CURSES_H #include +#include #include #include @@ -86,7 +87,7 @@ typedef bool (* curse_alg)(const uint8_t* data, unsigned len, CurseTracker*); struct CurseDetails { std::string name; - std::string service; + std::shared_ptr service; curse_alg alg; bool is_tcp; }; diff --git a/src/service_inspectors/wizard/hexes.cc b/src/service_inspectors/wizard/hexes.cc index 53e2ca1bf..9b2880fb7 100644 --- a/src/service_inspectors/wizard/hexes.cc +++ b/src/service_inspectors/wizard/hexes.cc @@ -91,7 +91,7 @@ void HexBook::add_spell( ++i; } p->key = key; - p->value = val; + p->value = make_shared(val); } bool HexBook::add_spell(const char* key, const char*& val) @@ -124,7 +124,7 @@ bool HexBook::add_spell(const char* key, const char*& val) } if ( p->key == key ) { - val = p->value.c_str(); + val = p->value->c_str(); return false; } @@ -158,7 +158,7 @@ const MagicPage* HexBook::find_spell( if ( const MagicPage* q = find_spell(s, n, p->any, i+1) ) return q; } - return p->value.empty() ? nullptr : p; + return p->value.use_count() ? p : nullptr; } return p; } diff --git a/src/service_inspectors/wizard/magic.cc b/src/service_inspectors/wizard/magic.cc index 9e6c4fd3f..a6f2a3f84 100644 --- a/src/service_inspectors/wizard/magic.cc +++ b/src/service_inspectors/wizard/magic.cc @@ -42,14 +42,14 @@ MagicPage::~MagicPage() delete any; } -const char* MagicBook::find_spell(const uint8_t* data, unsigned len, const MagicPage*& p) const +std::shared_ptr MagicBook::find_spell(const uint8_t* data, unsigned len, + const MagicPage*& p) const { assert(p); p = find_spell(data, len, p, 0); - - if ( p && !p->value.empty() ) - return p->value.c_str(); + if ( p && p->value.use_count() ) + return p->value; return nullptr; } diff --git a/src/service_inspectors/wizard/magic.h b/src/service_inspectors/wizard/magic.h index 14238cca5..f35380a13 100644 --- a/src/service_inspectors/wizard/magic.h +++ b/src/service_inspectors/wizard/magic.h @@ -20,6 +20,7 @@ #ifndef MAGIC_H #define MAGIC_H +#include #include #include @@ -28,7 +29,7 @@ class MagicBook; struct MagicPage { std::string key; - std::string value; + std::shared_ptr value; MagicPage* next[256]; MagicPage* any; @@ -52,7 +53,8 @@ public: MagicBook& operator=(const MagicBook&) = delete; virtual bool add_spell(const char* key, const char*& val) = 0; - virtual const char* find_spell(const uint8_t*, unsigned len, const MagicPage*&) const; + virtual std::shared_ptr find_spell(const uint8_t*, unsigned len, + const MagicPage*&) const; const MagicPage* page1() const { return root; } diff --git a/src/service_inspectors/wizard/spells.cc b/src/service_inspectors/wizard/spells.cc index b58a23e0b..ce6ba6f6d 100644 --- a/src/service_inspectors/wizard/spells.cc +++ b/src/service_inspectors/wizard/spells.cc @@ -84,7 +84,7 @@ void SpellBook::add_spell( ++i; } p->key = key; - p->value = val; + p->value = make_shared(val); } bool SpellBook::add_spell(const char* key, const char*& val) @@ -118,7 +118,7 @@ bool SpellBook::add_spell(const char* key, const char*& val) } if ( p->key == key ) { - val = p->value.c_str(); + val = p->value->c_str(); return false; } @@ -162,15 +162,15 @@ const MagicPage* SpellBook::find_spell( } // If no match but has glob, continue lookup from glob - if ( p->value.empty() && glob ) - { + if ( !p->value.use_count() && glob ) + { p = glob; glob = nullptr; - + return find_spell(s, n, p, i); } - return p->value.empty() ? nullptr : p; + return p->value.use_count() ? p : nullptr; } return p; } diff --git a/src/service_inspectors/wizard/wizard.cc b/src/service_inspectors/wizard/wizard.cc index c7f191799..7d20b9c9d 100644 --- a/src/service_inspectors/wizard/wizard.cc +++ b/src/service_inspectors/wizard/wizard.cc @@ -195,7 +195,7 @@ StreamSplitter::Status MagicSplitter::scan( if ( wizard->cast_spell(wand, pkt->flow, data, len, wizard_processed_bytes) ) { trace_logf(wizard_trace, pkt, "%s streaming search found service %s\n", - to_server() ? "c2s" : "s2c", pkt->flow->service); + to_server() ? "c2s" : "s2c", pkt->flow->service->c_str()); count_hit(pkt->flow); wizard_processed_bytes = 0; return STOP; @@ -225,7 +225,7 @@ StreamSplitter::Status MagicSplitter::scan( // delayed. Because AppId depends on wizard only for SSH detection and SSH inspector can be // attached very early, event is raised here after first scan. In the future, wizard should be // enhanced to abort sooner if it can't detect service. - if (!pkt->flow->service && !pkt->flow->flags.svc_event_generated) + if (!pkt->flow->has_service() && !pkt->flow->flags.svc_event_generated) { DataBus::publish(FLOW_NO_SERVICE_EVENT, pkt); pkt->flow->flags.svc_event_generated = true; @@ -309,7 +309,7 @@ void Wizard::eval(Packet* p) if ( cast_spell(wand, p->flow, p->data, p->dsize, udp_processed_bytes) ) { trace_logf(wizard_trace, p, "%s datagram search found service %s\n", - c2s ? "c2s" : "s2c", p->flow->service); + c2s ? "c2s" : "s2c", p->flow->service->c_str()); ++tstats.udp_hits; } else @@ -328,8 +328,12 @@ StreamSplitter* Wizard::get_splitter(bool c2s) bool Wizard::spellbind( const MagicPage*& m, Flow* f, const uint8_t* data, unsigned len) { - f->service = m->book.find_spell(data, len, m); - return ( f->service != nullptr ); + std::shared_ptr p_shared = m->book.find_spell(data, len, m); + if (p_shared.use_count()) + f->service = p_shared; + else + f->service.reset(); + return f->has_service(); } bool Wizard::cursebind(const vector& curse_tracker, Flow* f, @@ -339,8 +343,11 @@ bool Wizard::cursebind(const vector& curse_tracker, Flow* f { if (cst.curse->alg(data, len, cst.tracker)) { - f->service = cst.curse->service.c_str(); - if ( f->service != nullptr ) + if (cst.curse->service.use_count()) + f->service = cst.curse->service; + else + f->service.reset(); + if ( f->has_service() ) return true; } } @@ -366,9 +373,9 @@ bool Wizard::cast_spell( if (cursebind(w.curse_tracker, f, data, curse_len)) return true; - // If we reach max value of wizard_processed_bytes, + // If we reach max value of wizard_processed_bytes, // but not assign any inspector - raise tcp_miss and stop - if ( !f->service && wizard_processed_bytes >= max_search_depth ) + if ( !f->has_service() && wizard_processed_bytes >= max_search_depth ) { w.spell = nullptr; w.hex = nullptr; diff --git a/src/stream/base/stream_base.cc b/src/stream/base/stream_base.cc index b6e7f5018..963ac2ea3 100644 --- a/src/stream/base/stream_base.cc +++ b/src/stream/base/stream_base.cc @@ -180,7 +180,7 @@ StreamBase::StreamBase(const StreamModuleConfig* c) { config = *c; } void StreamBase::tear_down(SnortConfig* sc) -{ sc->register_reload_resource_tuner(new StreamUnloadReloadResourceManager); } +{ sc->register_reload_handler(new StreamUnloadReloadResourceManager); } void StreamBase::tinit() { diff --git a/src/stream/base/stream_module.cc b/src/stream/base/stream_module.cc index adb7c52ea..fd343616c 100644 --- a/src/stream/base/stream_module.cc +++ b/src/stream/base/stream_module.cc @@ -221,11 +221,11 @@ bool StreamModule::end(const char* fqn, int, SnortConfig* sc) { StreamReloadResourceManager* reload_resource_manager = new StreamReloadResourceManager; if (reload_resource_manager->initialize(config)) - sc->register_reload_resource_tuner(reload_resource_manager); + sc->register_reload_handler(reload_resource_manager); else delete reload_resource_manager; - sc->register_reload_resource_tuner(new HPQReloadTuner(config.held_packet_timeout)); + sc->register_reload_handler(new HPQReloadTuner(config.held_packet_timeout)); } return true; diff --git a/src/stream/base/stream_module.h b/src/stream/base/stream_module.h index 7b827e2e5..2023dd11f 100644 --- a/src/stream/base/stream_module.h +++ b/src/stream/base/stream_module.h @@ -21,11 +21,11 @@ #ifndef STREAM_MODULE_H #define STREAM_MODULE_H -#include "main/analyzer.h" -#include "main/snort_config.h" #include "flow/flow_config.h" #include "flow/flow_control.h" #include "framework/module.h" +#include "main/analyzer.h" +#include "main/reload_tuner.h" namespace snort { diff --git a/src/stream/tcp/stream_tcp.cc b/src/stream/tcp/stream_tcp.cc index e0bb5caba..651f89190 100644 --- a/src/stream/tcp/stream_tcp.cc +++ b/src/stream/tcp/stream_tcp.cc @@ -73,16 +73,10 @@ bool StreamTcp::configure(SnortConfig* sc) } void StreamTcp::tinit() -{ - TcpHAManager::tinit(); - TcpSession::sinit(); -} +{ TcpHAManager::tinit(); } void StreamTcp::tterm() -{ - TcpHAManager::tterm(); - TcpSession::sterm(); -} +{ TcpHAManager::tterm(); } NORETURN_ASSERT void StreamTcp::eval(Packet*) { @@ -129,6 +123,12 @@ static void stream_tcp_pterm() TcpNormalizerFactory::term(); } +static void stream_tcp_tinit() +{ TcpSession::sinit(); } + +static void stream_tcp_tterm() +{ TcpSession::sterm(); } + static Session* tcp_ssn(Flow* lws) { return new TcpSession(lws); } @@ -152,8 +152,8 @@ static const InspectApi tcp_api = nullptr, // service stream_tcp_pinit, // pinit stream_tcp_pterm, // pterm - nullptr, // tinit, - nullptr, // tterm, + stream_tcp_tinit, // tinit, + stream_tcp_tterm, // tterm, tcp_ctor, tcp_dtor, tcp_ssn, diff --git a/src/stream/tcp/tcp_reassembler.cc b/src/stream/tcp/tcp_reassembler.cc index 92723ee50..5edef6175 100644 --- a/src/stream/tcp/tcp_reassembler.cc +++ b/src/stream/tcp/tcp_reassembler.cc @@ -783,8 +783,8 @@ static Packet* get_packet(Flow* flow, uint32_t flags, bool c2s) p->ip_proto_next = (IpProtocol)flow->ip_proto; + set_inspection_policy(flow->inspection_policy_id); const SnortConfig* sc = SnortConfig::get_conf(); - set_inspection_policy(sc, flow->inspection_policy_id); set_ips_policy(sc, flow->ips_policy_id); return p; diff --git a/src/target_based/host_attributes.cc b/src/target_based/host_attributes.cc index 763da5430..4b30c69f6 100644 --- a/src/target_based/host_attributes.cc +++ b/src/target_based/host_attributes.cc @@ -26,6 +26,7 @@ #include "host_attributes.h" #include "hash/lru_cache_shared.h" +#include "main/reload_tuner.h" #include "main/shell.h" #include "main/snort.h" #include "main/snort_config.h" @@ -171,7 +172,7 @@ void HostAttributesManager::activate(SnortConfig* sc) next_cache = nullptr; if( active_cache != old_cache and Snort::is_reloading() ) - sc->register_reload_resource_tuner(new HostAttributesReloadTuner); + sc->register_reload_handler(new HostAttributesReloadTuner); } void HostAttributesManager::initialize() diff --git a/src/target_based/snort_protocols.cc b/src/target_based/snort_protocols.cc index e453d003f..b37d0a2fa 100644 --- a/src/target_based/snort_protocols.cc +++ b/src/target_based/snort_protocols.cc @@ -39,19 +39,25 @@ SnortProtocolId ProtocolReference::get_count() const { return protocol_number; } const char* ProtocolReference::get_name(SnortProtocolId id) const +{ + std::shared_ptr shared_name = get_shared_name(id); + return shared_name->c_str(); +} + +std::shared_ptr ProtocolReference::get_shared_name(SnortProtocolId id) const { if ( id >= id_map.size() ) id = 0; - return id_map[id].c_str(); + return id_map[id]; } struct Compare { bool operator()(SnortProtocolId a, SnortProtocolId b) - { return map[a] < map[b]; } + { return map[a]->c_str() < map[b]->c_str(); } - vector& map; + vector>& map; }; const char* ProtocolReference::get_name_sorted(SnortProtocolId id) @@ -67,7 +73,7 @@ const char* ProtocolReference::get_name_sorted(SnortProtocolId id) if ( id >= ind_map.size() ) return nullptr; - return id_map[ind_map[id]].c_str(); + return id_map[ind_map[id]]->c_str(); } SnortProtocolId ProtocolReference::add(const char* protocol) @@ -80,7 +86,7 @@ SnortProtocolId ProtocolReference::add(const char* protocol) return protocol_ref->second; SnortProtocolId snort_protocol_id = protocol_number++; - id_map.emplace_back(protocol); + id_map.emplace_back(make_shared(protocol)); ref_table[protocol] = snort_protocol_id; return snort_protocol_id; diff --git a/src/target_based/snort_protocols.h b/src/target_based/snort_protocols.h index 0fbca27d9..6281cd1b5 100644 --- a/src/target_based/snort_protocols.h +++ b/src/target_based/snort_protocols.h @@ -22,6 +22,7 @@ #ifndef SNORT_PROTOCOLS_H #define SNORT_PROTOCOLS_H +#include #include #include #include @@ -73,6 +74,7 @@ public: SnortProtocolId get_count() const; const char* get_name(SnortProtocolId id) const; + std::shared_ptr get_shared_name(SnortProtocolId id) const; const char* get_name_sorted(SnortProtocolId id); SnortProtocolId add(const char* protocol); @@ -81,7 +83,7 @@ public: bool operator()(SnortProtocolId a, SnortProtocolId b); private: - std::vector id_map; + std::vector> id_map; std::vector ind_map; std::unordered_map ref_table;