]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #3956: managers: fix get_inspector to use the passed in snort config...
authorRon Dempster (rdempste) <rdempste@cisco.com>
Sat, 12 Aug 2023 00:48:18 +0000 (00:48 +0000)
committerSteve Chew (stechew) <stechew@cisco.com>
Sat, 12 Aug 2023 00:48:18 +0000 (00:48 +0000)
Merge in SNORT/snort3 from ~RDEMPSTE/snort3:fqdn to master

Squashed commit of the following:

commit 8394704aec2431ef1d070cbec8109075f2bed399
Author: Ron Dempster (rdempste) <rdempste@cisco.com>
Date:   Tue Jul 25 10:15:45 2023 -0400

    managers: fix get_inspector to use the passed in snort config for context and inspection inspectors

13 files changed:
src/main/modules.cc
src/main/policy.cc
src/main/policy.h
src/managers/inspector_manager.cc
src/managers/inspector_manager.h
src/managers/test/get_inspector_stubs.h
src/managers/test/get_inspector_test.cc
src/network_inspectors/binder/binder.cc
src/service_inspectors/ftp_telnet/ft_main.cc
src/service_inspectors/gtp/gtp_inspect.cc
src/service_inspectors/gtp/gtp_inspect.h
src/service_inspectors/gtp/ips_gtp_info.cc
src/service_inspectors/gtp/ips_gtp_type.cc

index 09d93e1387507086039a0197ad6542ee360e0383..abb5633b2557de866ee58fbdb05560ffc499fdb2 100644 (file)
@@ -871,7 +871,6 @@ class InspectionModule : public Module
 public:
     InspectionModule() : Module("inspection", inspection_help, inspection_params) { }
     bool set(const char*, Value&, SnortConfig*) override;
-    bool end(const char*, int, SnortConfig*) override;
 
     Usage get_usage() const override
     { return INSPECT; }
@@ -916,15 +915,6 @@ bool InspectionModule::set(const char*, Value& v, SnortConfig* sc)
     return true;
 }
 
-bool InspectionModule::end(const char*, int, SnortConfig*)
-{
-    InspectionPolicy* p = get_inspection_policy();
-    NetworkPolicy* np = get_network_parse_policy();
-    assert(np);
-    np->set_user_inspection(p);
-    return true;
-}
-
 //-------------------------------------------------------------------------
 // Ips policy module
 //-------------------------------------------------------------------------
index 7e1ab9a274927fe70e673c1ac77fa757d6f1c448..db1b11c423eb447765830b75f1e2293d2a40abc5 100644 (file)
@@ -123,7 +123,13 @@ FilePolicy* NetworkPolicy::get_file_policy() const
 void NetworkPolicy::add_file_policy_rule(FileRule& file_rule)
 { file_policy->add_file_id(file_rule); }
 
-InspectionPolicy* NetworkPolicy::get_user_inspection_policy(unsigned user_id)
+void NetworkPolicy::setup_inspection_policies()
+{
+    std::for_each(inspection_policy.begin(), inspection_policy.end(),
+        [this](InspectionPolicy* ip){ set_user_inspection(ip); });
+}
+
+InspectionPolicy* NetworkPolicy::get_user_inspection_policy(uint64_t user_id) const
 {
     auto it = user_inspection.find(user_id);
     return it == user_inspection.end() ? nullptr : it->second;
@@ -386,9 +392,10 @@ NetworkPolicy* PolicyMap::get_user_network(uint64_t user_id) const
 bool PolicyMap::set_user_network(NetworkPolicy* p)
 {
     NetworkPolicy* current_np = get_user_network(p->user_policy_id);
-    if (current_np && p != current_np)
-        return false;
+    if (current_np)
+        return p == current_np;
     user_network[p->user_policy_id] = p;
+    p->setup_inspection_policies();
     return true;
 }
 
@@ -428,7 +435,7 @@ void set_inspection_policy(InspectionPolicy* p)
 void set_ips_policy(IpsPolicy* p)
 { s_detection_policy = p; }
 
-InspectionPolicy* get_user_inspection_policy(unsigned policy_id)
+InspectionPolicy* get_user_inspection_policy(uint64_t policy_id)
 {
     NetworkPolicy* np = get_network_policy();
     assert(np);
index 5566c0a11fbcc87496513d30e3cefc5413d9e05d..6a2e7bfb77b2f07eccf41b2edb808a7e07f87e85 100644 (file)
@@ -95,7 +95,7 @@ SO_PUBLIC void set_ips_policy(IpsPolicy*);
 
 SO_PUBLIC NetworkPolicy* get_default_network_policy(const snort::SnortConfig*);
 // Based on currently set network policy
-SO_PUBLIC InspectionPolicy* get_user_inspection_policy(unsigned policy_id);
+SO_PUBLIC InspectionPolicy* get_user_inspection_policy(uint64_t policy_id);
 
 SO_PUBLIC IpsPolicy* get_ips_policy(const snort::SnortConfig*, unsigned i = 0);
 // Based on currently set network policy
@@ -183,7 +183,8 @@ public:
     { return i < inspection_policy.size() ? inspection_policy[i] : nullptr; }
     unsigned inspection_policy_count()
     { return inspection_policy.size(); }
-    InspectionPolicy* get_user_inspection_policy(unsigned user_id);
+    void setup_inspection_policies();
+    InspectionPolicy* get_user_inspection_policy(uint64_t user_id) const;
     void set_user_inspection(InspectionPolicy* p)
     { user_inspection[p->user_policy_id] = p; }
 
index eec5a09b6a51fbcac840e26c12efe699c5485c99..9c27cac0046d3688bb10b33ea3299cb468a12ba6 100644 (file)
@@ -1220,13 +1220,15 @@ Inspector* InspectorManager::get_file_inspector(const SnortConfig* sc)
 }
 
 // FIXIT-P cache get_inspector() returns or provide indexed lookup
-Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, const SnortConfig* sc)
+Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, const SnortConfig* snort_config)
 {
     InspectionPolicy* pi;
     NetworkPolicy* ni;
 
+    const SnortConfig* sc = snort_config;
     if ( !sc )
         sc = SnortConfig::get_conf();
+    assert(sc);
     if ( dflt_only )
     {
         ni = get_default_network_policy(sc);
@@ -1235,7 +1237,31 @@ Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, cons
     else
     {
         pi = get_inspection_policy();
+        // During reload, get_network_policy will return the network policy from the new snort config
+        // for a given tenant
         ni = get_network_policy();
+        if (!snort_config)
+        {
+            // If no snort config is passed in, it means that this is either a normally running system with
+            // the correct network policy set or that get_inspector is being called from Inspector::configure
+            // and it is expecting the inspector from the running configuration and not the new snort config
+            if (ni)
+            {
+                PolicyMap* pm = sc->policy_map;
+                NetworkPolicy* np = pm->get_user_network(ni->user_policy_id);
+                if (np)
+                {
+                    // If network policy is correct, then no need to change the inspection policy
+                    if (np != ni && pi)
+                        pi = np->get_user_inspection_policy(pi->user_policy_id);
+                    ni = np;
+                }
+                else
+                    pi = nullptr;
+            }
+            else
+                pi = nullptr;
+        }
     }
 
     if ( pi )
@@ -1268,15 +1294,11 @@ Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, cons
     return nullptr;
 }
 
-Inspector* InspectorManager::get_inspector(const char* key, Module::Usage usage,
-    InspectorType type, const SnortConfig* sc)
+Inspector* InspectorManager::get_inspector(const char* key, Module::Usage usage, InspectorType type)
 {
-    if ( !sc )
-    {
-        sc = SnortConfig::get_conf();
-        if (!sc)
-            return nullptr;
-    }
+    const SnortConfig* sc = SnortConfig::get_conf();
+    if (!sc)
+        return nullptr;
 
     if (Module::GLOBAL == usage && IT_FILE == type)
     {
@@ -1302,6 +1324,10 @@ Inspector* InspectorManager::get_inspector(const char* key, Module::Usage usage,
         else if (Module::CONTEXT == usage)
         {
             NetworkPolicy* np = get_network_policy();
+            if (!np)
+                return nullptr;
+            PolicyMap* pm = sc->policy_map;
+            np = pm->get_user_network(np->user_policy_id);
             if (!np)
                 return nullptr;
             TrafficPolicy* il = np->traffic_policy;
@@ -1311,9 +1337,23 @@ Inspector* InspectorManager::get_inspector(const char* key, Module::Usage usage,
         }
         else
         {
+            NetworkPolicy* orig_np = get_network_policy();
+            if (!orig_np)
+                return nullptr;
+            PolicyMap* pm = sc->policy_map;
+            NetworkPolicy* np = pm->get_user_network(orig_np->user_policy_id);
+            if (!np)
+                return nullptr;
             InspectionPolicy* ip = get_inspection_policy();
             if (!ip)
                 return nullptr;
+            // If network policy is correct, then no need to change the inspection policy
+            if (np != orig_np)
+            {
+                ip = np->get_user_inspection_policy(ip->user_policy_id);
+                if (!ip)
+                    return nullptr;
+            }
             FrameworkPolicy* il = ip->framework_policy;
             assert(il);
             PHInstance* p = il->get_instance_by_type(key, type);
index c99da4d9b9afd2c17c0e5403045c7cab3117cc41..a13ab027afcc0a21a5203865ab157d1bc3aaaa67 100644 (file)
@@ -79,10 +79,13 @@ public:
     static InspectSsnFunc get_session(uint16_t proto);
 
     SO_PUBLIC static Inspector* get_file_inspector(const SnortConfig* = nullptr);
-    SO_PUBLIC static Inspector* get_inspector(
-        const char* key, bool dflt_only = false, const SnortConfig* = nullptr);
-    SO_PUBLIC static Inspector* get_inspector(const char* key, Module::Usage, InspectorType,
-        const SnortConfig* = nullptr);
+
+    // This assumes that, in a multi-tenant scenario, this is called with the correct network and inspection
+    // policies are set correctly
+    SO_PUBLIC static Inspector* get_inspector(const char* key, bool dflt_only = false, const SnortConfig* = nullptr);
+
+    // This cannot be called in or before the inspector configure phase for a new snort config during reload
+    SO_PUBLIC static Inspector* get_inspector(const char* key, Module::Usage, InspectorType);
 
     static Inspector* get_service_inspector_by_service(const char*);
     static Inspector* get_service_inspector_by_id(const SnortProtocolId);
index 86d07df6d2aebb349f0b0a4d958535784170e10b..105c12927ca3e001b8dd485a3877212c8eb45d28 100644 (file)
@@ -33,7 +33,6 @@
 THREAD_LOCAL const snort::Trace* snort_trace = nullptr;
 
 std::shared_ptr<PolicyTuple> PolicyMap::get_policies(Shell*) { return nullptr; }
-NetworkPolicy* PolicyMap::get_user_network(uint64_t) const { return nullptr; }
 void InspectionPolicy::configure() { }
 void BinderModule::add(const char*, const char*) { }
 void BinderModule::add(unsigned, const char*) { }
index c161e7eee0dd23b6c644f8b1a9f6c103ee6338fc..f0115223110e81ddd45e4ec050defbfac887620a 100644 (file)
@@ -35,8 +35,12 @@ bool Inspector::is_inactive() { return true; }
 
 NetworkPolicy* snort::get_network_policy()
 { return (NetworkPolicy*)mock().getData("network_policy").getObjectPointer(); }
+NetworkPolicy* PolicyMap::get_user_network(uint64_t) const
+{ return (NetworkPolicy*)mock().getData("network_policy").getObjectPointer(); }
 InspectionPolicy* snort::get_inspection_policy()
 { return (InspectionPolicy*)mock().getData("inspection_policy").getObjectPointer(); }
+InspectionPolicy* NetworkPolicy::get_user_inspection_policy(uint64_t) const
+{ return (InspectionPolicy*)mock().getData("inspection_policy").getObjectPointer(); }
 
 InspectionPolicy::InspectionPolicy(PolicyId)
 { InspectorManager::new_policy(this, nullptr); }
index ec9f8298674c073ecf4ddcd64fd9297264cd7b7a..f09395a98bb6b319ea3cae1fed6731de6f59e2dd 100644 (file)
@@ -629,7 +629,7 @@ bool Binder::configure(SnortConfig* sc)
             default:            name = nullptr; break;
         }
         if (name)
-            default_ssn_inspectors[proto] = InspectorManager::get_inspector(name);
+            default_ssn_inspectors[proto] = InspectorManager::get_inspector(name, false, sc);
     }
 
     DataBus::subscribe(intrinsic_pub_key, IntrinsicEventIds::PKT_WITHOUT_FLOW, new NonFlowPacketHandler());
index d553aa9842eac8f96b6542dbc151f04d538c3281..eade90a63e4e7c76cef3fd59719e6e48fe0e1625 100644 (file)
@@ -176,13 +176,13 @@ int FTPCheckConfigs(SnortConfig* sc, void* pData)
         return rval;
 
     //  Verify that FTP client and FTP data inspectors are initialized.
-    if(!InspectorManager::get_inspector(FTP_CLIENT_NAME, false))
+    if(!InspectorManager::get_inspector(FTP_CLIENT_NAME, false, sc))
     {
         ParseError("ftp_server requires that %s also be configured.", FTP_CLIENT_NAME);
         return -1;
     }
 
-    if(!InspectorManager::get_inspector(FTP_DATA_NAME, false))
+    if(!InspectorManager::get_inspector(FTP_DATA_NAME, false, sc))
     {
         ParseError("ftp_server requires that %s also be configured.", FTP_DATA_NAME);
         return -1;
index 8c50da4074554ffbe85f1ff59c3f41be49dd6dc9..b5f3aef9e606dbc0ce2c2c5165ba0e47d166b380 100644 (file)
@@ -153,9 +153,9 @@ int GtpInspect::get_message_type(int version, const char* name)
     return -1;
 }
 
-int get_message_type(int version, const char* name)
+int get_message_type(int version, const char* name, snort::SnortConfig* sc)
 {
-    GtpInspect* ins = (GtpInspect*)InspectorManager::get_inspector(GTP_NAME);
+    GtpInspect* ins = (GtpInspect*)InspectorManager::get_inspector(GTP_NAME, false, sc);
 
     if ( !ins )
         return -1;
@@ -175,9 +175,9 @@ int GtpInspect::get_info_type(int version, const char* name)
     return -1;
 }
 
-int get_info_type(int version, const char* name)
+int get_info_type(int version, const char* name, SnortConfig* sc)
 {
-    GtpInspect* ins = (GtpInspect*)InspectorManager::get_inspector(GTP_NAME);
+    GtpInspect* ins = (GtpInspect*)InspectorManager::get_inspector(GTP_NAME, false, sc);
 
     if ( !ins )
         return -1;
index a0b9ca6f137608498353f35a4a4fe1d3f2ee73fc..3463fc8a71ca9b8fba48acab966ea8e18278e541 100644 (file)
@@ -48,8 +48,13 @@ public:
     GTP_Roptions ropts;
 };
 
-int get_message_type(int version, const char* name);
-int get_info_type(int version, const char* name);
+namespace snort
+{
+struct SnortConfig;
+}
+
+int get_message_type(int version, const char* name, snort::SnortConfig*);
+int get_info_type(int version, const char* name, snort::SnortConfig*);
 
 struct GTP_IEData* get_infos();
 
index 2264acbe9db3869c504357bae15c52566b34cf01..aa38d50996e9c910a91524e6453e3be2c208c883 100644 (file)
@@ -151,7 +151,7 @@ public:
     bool set(const char*, Value&, SnortConfig*) override;
 
     bool set_types(long);
-    bool set_types(const char*);
+    bool set_types(const char*, SnortConfig*);
 
     ProfileStats* get_profile() const override
     { return &gtp_info_prof; }
@@ -174,13 +174,13 @@ bool GtpInfoModule::set_types(long t)
     return true;
 }
 
-bool GtpInfoModule::set_types(const char* name)
+bool GtpInfoModule::set_types(const char* name, SnortConfig* sc)
 {
     bool ok = false;
 
     for ( int v = 0; v <= MAX_GTP_VERSION_CODE; ++v )
     {
-        int t = get_info_type(v, name);
+        int t = get_info_type(v, name, sc);
 
         if ( t < 0 )
             continue;
@@ -191,7 +191,7 @@ bool GtpInfoModule::set_types(const char* name)
     return ok;
 }
 
-bool GtpInfoModule::set(const char*, Value& v, SnortConfig*)
+bool GtpInfoModule::set(const char*, Value& v, SnortConfig* sc)
 {
     assert(v.is("~"));
     long n;
@@ -199,7 +199,7 @@ bool GtpInfoModule::set(const char*, Value& v, SnortConfig*)
     if ( v.strtol(n) )
         return set_types(n);
 
-    return set_types(v.get_string());
+    return set_types(v.get_string(), sc);
 }
 
 //-------------------------------------------------------------------------
index 82a53df4e96bcdf66c424c29c1531bd41ed12836..5ebea1468112e66242d3954309ce5e8d9373cf76 100644 (file)
@@ -138,7 +138,7 @@ public:
     bool set(const char*, Value&, SnortConfig*) override;
 
     bool set_types(long);
-    bool set_types(const char*);
+    bool set_types(const char*, SnortConfig*);
 
     ProfileStats* get_profile() const override
     { return &gtp_type_prof; }
@@ -169,13 +169,13 @@ bool GtpTypeModule::set_types(long t)
     return true;
 }
 
-bool GtpTypeModule::set_types(const char* name)
+bool GtpTypeModule::set_types(const char* name, SnortConfig* sc)
 {
     bool ok = false;
 
     for ( int v = 0; v <= MAX_GTP_VERSION_CODE; ++v )
     {
-        int t = get_message_type(v, name);
+        int t = get_message_type(v, name, sc);
 
         if ( t < 0 )
             continue;
@@ -186,7 +186,7 @@ bool GtpTypeModule::set_types(const char* name)
     return ok;
 }
 
-bool GtpTypeModule::set(const char*, Value& v, SnortConfig*)
+bool GtpTypeModule::set(const char*, Value& v, SnortConfig* sc)
 {
     assert(v.is("~"));
     v.set_first_token();
@@ -210,7 +210,7 @@ bool GtpTypeModule::set(const char*, Value& v, SnortConfig*)
             if ( !set_types(n) )
                 return false;
         }
-        else if ( !set_types(tok.c_str()) )
+        else if ( !set_types(tok.c_str(), sc) )
             return false;
     }
     return true;