From: Ron Dempster (rdempste) Date: Sat, 12 Aug 2023 00:48:18 +0000 (+0000) Subject: Pull request #3956: managers: fix get_inspector to use the passed in snort config... X-Git-Tag: 3.1.69.0~13 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=30c4b6af333b65079716835f2bda6132e9e366a2;p=thirdparty%2Fsnort3.git Pull request #3956: managers: fix get_inspector to use the passed in snort config for context and inspection inspectors Merge in SNORT/snort3 from ~RDEMPSTE/snort3:fqdn to master Squashed commit of the following: commit 8394704aec2431ef1d070cbec8109075f2bed399 Author: Ron Dempster (rdempste) Date: Tue Jul 25 10:15:45 2023 -0400 managers: fix get_inspector to use the passed in snort config for context and inspection inspectors --- diff --git a/src/main/modules.cc b/src/main/modules.cc index 09d93e138..abb5633b2 100644 --- a/src/main/modules.cc +++ b/src/main/modules.cc @@ -871,7 +871,6 @@ class InspectionModule : public Module public: InspectionModule() : Module("inspection", inspection_help, inspection_params) { } bool set(const char*, Value&, SnortConfig*) override; - bool end(const char*, int, SnortConfig*) override; Usage get_usage() const override { return INSPECT; } @@ -916,15 +915,6 @@ bool InspectionModule::set(const char*, Value& v, SnortConfig* sc) return true; } -bool InspectionModule::end(const char*, int, SnortConfig*) -{ - InspectionPolicy* p = get_inspection_policy(); - NetworkPolicy* np = get_network_parse_policy(); - assert(np); - np->set_user_inspection(p); - return true; -} - //------------------------------------------------------------------------- // Ips policy module //------------------------------------------------------------------------- diff --git a/src/main/policy.cc b/src/main/policy.cc index 7e1ab9a27..db1b11c42 100644 --- a/src/main/policy.cc +++ b/src/main/policy.cc @@ -123,7 +123,13 @@ FilePolicy* NetworkPolicy::get_file_policy() const void NetworkPolicy::add_file_policy_rule(FileRule& file_rule) { file_policy->add_file_id(file_rule); } -InspectionPolicy* NetworkPolicy::get_user_inspection_policy(unsigned user_id) +void NetworkPolicy::setup_inspection_policies() +{ + std::for_each(inspection_policy.begin(), inspection_policy.end(), + [this](InspectionPolicy* ip){ set_user_inspection(ip); }); +} + +InspectionPolicy* NetworkPolicy::get_user_inspection_policy(uint64_t user_id) const { auto it = user_inspection.find(user_id); return it == user_inspection.end() ? nullptr : it->second; @@ -386,9 +392,10 @@ NetworkPolicy* PolicyMap::get_user_network(uint64_t user_id) const bool PolicyMap::set_user_network(NetworkPolicy* p) { NetworkPolicy* current_np = get_user_network(p->user_policy_id); - if (current_np && p != current_np) - return false; + if (current_np) + return p == current_np; user_network[p->user_policy_id] = p; + p->setup_inspection_policies(); return true; } @@ -428,7 +435,7 @@ void set_inspection_policy(InspectionPolicy* p) void set_ips_policy(IpsPolicy* p) { s_detection_policy = p; } -InspectionPolicy* get_user_inspection_policy(unsigned policy_id) +InspectionPolicy* get_user_inspection_policy(uint64_t policy_id) { NetworkPolicy* np = get_network_policy(); assert(np); diff --git a/src/main/policy.h b/src/main/policy.h index 5566c0a11..6a2e7bfb7 100644 --- a/src/main/policy.h +++ b/src/main/policy.h @@ -95,7 +95,7 @@ SO_PUBLIC void set_ips_policy(IpsPolicy*); SO_PUBLIC NetworkPolicy* get_default_network_policy(const snort::SnortConfig*); // Based on currently set network policy -SO_PUBLIC InspectionPolicy* get_user_inspection_policy(unsigned policy_id); +SO_PUBLIC InspectionPolicy* get_user_inspection_policy(uint64_t policy_id); SO_PUBLIC IpsPolicy* get_ips_policy(const snort::SnortConfig*, unsigned i = 0); // Based on currently set network policy @@ -183,7 +183,8 @@ public: { return i < inspection_policy.size() ? inspection_policy[i] : nullptr; } unsigned inspection_policy_count() { return inspection_policy.size(); } - InspectionPolicy* get_user_inspection_policy(unsigned user_id); + void setup_inspection_policies(); + InspectionPolicy* get_user_inspection_policy(uint64_t user_id) const; void set_user_inspection(InspectionPolicy* p) { user_inspection[p->user_policy_id] = p; } diff --git a/src/managers/inspector_manager.cc b/src/managers/inspector_manager.cc index eec5a09b6..9c27cac00 100644 --- a/src/managers/inspector_manager.cc +++ b/src/managers/inspector_manager.cc @@ -1220,13 +1220,15 @@ Inspector* InspectorManager::get_file_inspector(const SnortConfig* sc) } // FIXIT-P cache get_inspector() returns or provide indexed lookup -Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, const SnortConfig* sc) +Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, const SnortConfig* snort_config) { InspectionPolicy* pi; NetworkPolicy* ni; + const SnortConfig* sc = snort_config; if ( !sc ) sc = SnortConfig::get_conf(); + assert(sc); if ( dflt_only ) { ni = get_default_network_policy(sc); @@ -1235,7 +1237,31 @@ Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, cons else { pi = get_inspection_policy(); + // During reload, get_network_policy will return the network policy from the new snort config + // for a given tenant ni = get_network_policy(); + if (!snort_config) + { + // If no snort config is passed in, it means that this is either a normally running system with + // the correct network policy set or that get_inspector is being called from Inspector::configure + // and it is expecting the inspector from the running configuration and not the new snort config + if (ni) + { + PolicyMap* pm = sc->policy_map; + NetworkPolicy* np = pm->get_user_network(ni->user_policy_id); + if (np) + { + // If network policy is correct, then no need to change the inspection policy + if (np != ni && pi) + pi = np->get_user_inspection_policy(pi->user_policy_id); + ni = np; + } + else + pi = nullptr; + } + else + pi = nullptr; + } } if ( pi ) @@ -1268,15 +1294,11 @@ Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, cons return nullptr; } -Inspector* InspectorManager::get_inspector(const char* key, Module::Usage usage, - InspectorType type, const SnortConfig* sc) +Inspector* InspectorManager::get_inspector(const char* key, Module::Usage usage, InspectorType type) { - if ( !sc ) - { - sc = SnortConfig::get_conf(); - if (!sc) - return nullptr; - } + const SnortConfig* sc = SnortConfig::get_conf(); + if (!sc) + return nullptr; if (Module::GLOBAL == usage && IT_FILE == type) { @@ -1302,6 +1324,10 @@ Inspector* InspectorManager::get_inspector(const char* key, Module::Usage usage, else if (Module::CONTEXT == usage) { NetworkPolicy* np = get_network_policy(); + if (!np) + return nullptr; + PolicyMap* pm = sc->policy_map; + np = pm->get_user_network(np->user_policy_id); if (!np) return nullptr; TrafficPolicy* il = np->traffic_policy; @@ -1311,9 +1337,23 @@ Inspector* InspectorManager::get_inspector(const char* key, Module::Usage usage, } else { + NetworkPolicy* orig_np = get_network_policy(); + if (!orig_np) + return nullptr; + PolicyMap* pm = sc->policy_map; + NetworkPolicy* np = pm->get_user_network(orig_np->user_policy_id); + if (!np) + return nullptr; InspectionPolicy* ip = get_inspection_policy(); if (!ip) return nullptr; + // If network policy is correct, then no need to change the inspection policy + if (np != orig_np) + { + ip = np->get_user_inspection_policy(ip->user_policy_id); + if (!ip) + return nullptr; + } FrameworkPolicy* il = ip->framework_policy; assert(il); PHInstance* p = il->get_instance_by_type(key, type); diff --git a/src/managers/inspector_manager.h b/src/managers/inspector_manager.h index c99da4d9b..a13ab027a 100644 --- a/src/managers/inspector_manager.h +++ b/src/managers/inspector_manager.h @@ -79,10 +79,13 @@ public: static InspectSsnFunc get_session(uint16_t proto); SO_PUBLIC static Inspector* get_file_inspector(const SnortConfig* = nullptr); - SO_PUBLIC static Inspector* get_inspector( - const char* key, bool dflt_only = false, const SnortConfig* = nullptr); - SO_PUBLIC static Inspector* get_inspector(const char* key, Module::Usage, InspectorType, - const SnortConfig* = nullptr); + + // This assumes that, in a multi-tenant scenario, this is called with the correct network and inspection + // policies are set correctly + SO_PUBLIC static Inspector* get_inspector(const char* key, bool dflt_only = false, const SnortConfig* = nullptr); + + // This cannot be called in or before the inspector configure phase for a new snort config during reload + SO_PUBLIC static Inspector* get_inspector(const char* key, Module::Usage, InspectorType); static Inspector* get_service_inspector_by_service(const char*); static Inspector* get_service_inspector_by_id(const SnortProtocolId); diff --git a/src/managers/test/get_inspector_stubs.h b/src/managers/test/get_inspector_stubs.h index 86d07df6d..105c12927 100644 --- a/src/managers/test/get_inspector_stubs.h +++ b/src/managers/test/get_inspector_stubs.h @@ -33,7 +33,6 @@ THREAD_LOCAL const snort::Trace* snort_trace = nullptr; std::shared_ptr PolicyMap::get_policies(Shell*) { return nullptr; } -NetworkPolicy* PolicyMap::get_user_network(uint64_t) const { return nullptr; } void InspectionPolicy::configure() { } void BinderModule::add(const char*, const char*) { } void BinderModule::add(unsigned, const char*) { } diff --git a/src/managers/test/get_inspector_test.cc b/src/managers/test/get_inspector_test.cc index c161e7eee..f01152231 100644 --- a/src/managers/test/get_inspector_test.cc +++ b/src/managers/test/get_inspector_test.cc @@ -35,8 +35,12 @@ bool Inspector::is_inactive() { return true; } NetworkPolicy* snort::get_network_policy() { return (NetworkPolicy*)mock().getData("network_policy").getObjectPointer(); } +NetworkPolicy* PolicyMap::get_user_network(uint64_t) const +{ return (NetworkPolicy*)mock().getData("network_policy").getObjectPointer(); } InspectionPolicy* snort::get_inspection_policy() { return (InspectionPolicy*)mock().getData("inspection_policy").getObjectPointer(); } +InspectionPolicy* NetworkPolicy::get_user_inspection_policy(uint64_t) const +{ return (InspectionPolicy*)mock().getData("inspection_policy").getObjectPointer(); } InspectionPolicy::InspectionPolicy(PolicyId) { InspectorManager::new_policy(this, nullptr); } diff --git a/src/network_inspectors/binder/binder.cc b/src/network_inspectors/binder/binder.cc index ec9f82986..f09395a98 100644 --- a/src/network_inspectors/binder/binder.cc +++ b/src/network_inspectors/binder/binder.cc @@ -629,7 +629,7 @@ bool Binder::configure(SnortConfig* sc) default: name = nullptr; break; } if (name) - default_ssn_inspectors[proto] = InspectorManager::get_inspector(name); + default_ssn_inspectors[proto] = InspectorManager::get_inspector(name, false, sc); } DataBus::subscribe(intrinsic_pub_key, IntrinsicEventIds::PKT_WITHOUT_FLOW, new NonFlowPacketHandler()); diff --git a/src/service_inspectors/ftp_telnet/ft_main.cc b/src/service_inspectors/ftp_telnet/ft_main.cc index d553aa984..eade90a63 100644 --- a/src/service_inspectors/ftp_telnet/ft_main.cc +++ b/src/service_inspectors/ftp_telnet/ft_main.cc @@ -176,13 +176,13 @@ int FTPCheckConfigs(SnortConfig* sc, void* pData) return rval; // Verify that FTP client and FTP data inspectors are initialized. - if(!InspectorManager::get_inspector(FTP_CLIENT_NAME, false)) + if(!InspectorManager::get_inspector(FTP_CLIENT_NAME, false, sc)) { ParseError("ftp_server requires that %s also be configured.", FTP_CLIENT_NAME); return -1; } - if(!InspectorManager::get_inspector(FTP_DATA_NAME, false)) + if(!InspectorManager::get_inspector(FTP_DATA_NAME, false, sc)) { ParseError("ftp_server requires that %s also be configured.", FTP_DATA_NAME); return -1; diff --git a/src/service_inspectors/gtp/gtp_inspect.cc b/src/service_inspectors/gtp/gtp_inspect.cc index 8c50da407..b5f3aef9e 100644 --- a/src/service_inspectors/gtp/gtp_inspect.cc +++ b/src/service_inspectors/gtp/gtp_inspect.cc @@ -153,9 +153,9 @@ int GtpInspect::get_message_type(int version, const char* name) return -1; } -int get_message_type(int version, const char* name) +int get_message_type(int version, const char* name, snort::SnortConfig* sc) { - GtpInspect* ins = (GtpInspect*)InspectorManager::get_inspector(GTP_NAME); + GtpInspect* ins = (GtpInspect*)InspectorManager::get_inspector(GTP_NAME, false, sc); if ( !ins ) return -1; @@ -175,9 +175,9 @@ int GtpInspect::get_info_type(int version, const char* name) return -1; } -int get_info_type(int version, const char* name) +int get_info_type(int version, const char* name, SnortConfig* sc) { - GtpInspect* ins = (GtpInspect*)InspectorManager::get_inspector(GTP_NAME); + GtpInspect* ins = (GtpInspect*)InspectorManager::get_inspector(GTP_NAME, false, sc); if ( !ins ) return -1; diff --git a/src/service_inspectors/gtp/gtp_inspect.h b/src/service_inspectors/gtp/gtp_inspect.h index a0b9ca6f1..3463fc8a7 100644 --- a/src/service_inspectors/gtp/gtp_inspect.h +++ b/src/service_inspectors/gtp/gtp_inspect.h @@ -48,8 +48,13 @@ public: GTP_Roptions ropts; }; -int get_message_type(int version, const char* name); -int get_info_type(int version, const char* name); +namespace snort +{ +struct SnortConfig; +} + +int get_message_type(int version, const char* name, snort::SnortConfig*); +int get_info_type(int version, const char* name, snort::SnortConfig*); struct GTP_IEData* get_infos(); diff --git a/src/service_inspectors/gtp/ips_gtp_info.cc b/src/service_inspectors/gtp/ips_gtp_info.cc index 2264acbe9..aa38d5099 100644 --- a/src/service_inspectors/gtp/ips_gtp_info.cc +++ b/src/service_inspectors/gtp/ips_gtp_info.cc @@ -151,7 +151,7 @@ public: bool set(const char*, Value&, SnortConfig*) override; bool set_types(long); - bool set_types(const char*); + bool set_types(const char*, SnortConfig*); ProfileStats* get_profile() const override { return >p_info_prof; } @@ -174,13 +174,13 @@ bool GtpInfoModule::set_types(long t) return true; } -bool GtpInfoModule::set_types(const char* name) +bool GtpInfoModule::set_types(const char* name, SnortConfig* sc) { bool ok = false; for ( int v = 0; v <= MAX_GTP_VERSION_CODE; ++v ) { - int t = get_info_type(v, name); + int t = get_info_type(v, name, sc); if ( t < 0 ) continue; @@ -191,7 +191,7 @@ bool GtpInfoModule::set_types(const char* name) return ok; } -bool GtpInfoModule::set(const char*, Value& v, SnortConfig*) +bool GtpInfoModule::set(const char*, Value& v, SnortConfig* sc) { assert(v.is("~")); long n; @@ -199,7 +199,7 @@ bool GtpInfoModule::set(const char*, Value& v, SnortConfig*) if ( v.strtol(n) ) return set_types(n); - return set_types(v.get_string()); + return set_types(v.get_string(), sc); } //------------------------------------------------------------------------- diff --git a/src/service_inspectors/gtp/ips_gtp_type.cc b/src/service_inspectors/gtp/ips_gtp_type.cc index 82a53df4e..5ebea1468 100644 --- a/src/service_inspectors/gtp/ips_gtp_type.cc +++ b/src/service_inspectors/gtp/ips_gtp_type.cc @@ -138,7 +138,7 @@ public: bool set(const char*, Value&, SnortConfig*) override; bool set_types(long); - bool set_types(const char*); + bool set_types(const char*, SnortConfig*); ProfileStats* get_profile() const override { return >p_type_prof; } @@ -169,13 +169,13 @@ bool GtpTypeModule::set_types(long t) return true; } -bool GtpTypeModule::set_types(const char* name) +bool GtpTypeModule::set_types(const char* name, SnortConfig* sc) { bool ok = false; for ( int v = 0; v <= MAX_GTP_VERSION_CODE; ++v ) { - int t = get_message_type(v, name); + int t = get_message_type(v, name, sc); if ( t < 0 ) continue; @@ -186,7 +186,7 @@ bool GtpTypeModule::set_types(const char* name) return ok; } -bool GtpTypeModule::set(const char*, Value& v, SnortConfig*) +bool GtpTypeModule::set(const char*, Value& v, SnortConfig* sc) { assert(v.is("~")); v.set_first_token(); @@ -210,7 +210,7 @@ bool GtpTypeModule::set(const char*, Value& v, SnortConfig*) if ( !set_types(n) ) return false; } - else if ( !set_types(tok.c_str()) ) + else if ( !set_types(tok.c_str(), sc) ) return false; } return true;