public:
InspectionModule() : Module("inspection", inspection_help, inspection_params) { }
bool set(const char*, Value&, SnortConfig*) override;
- bool end(const char*, int, SnortConfig*) override;
Usage get_usage() const override
{ return INSPECT; }
return true;
}
-bool InspectionModule::end(const char*, int, SnortConfig*)
-{
- InspectionPolicy* p = get_inspection_policy();
- NetworkPolicy* np = get_network_parse_policy();
- assert(np);
- np->set_user_inspection(p);
- return true;
-}
-
//-------------------------------------------------------------------------
// Ips policy module
//-------------------------------------------------------------------------
void NetworkPolicy::add_file_policy_rule(FileRule& file_rule)
{ file_policy->add_file_id(file_rule); }
-InspectionPolicy* NetworkPolicy::get_user_inspection_policy(unsigned user_id)
+void NetworkPolicy::setup_inspection_policies()
+{
+ std::for_each(inspection_policy.begin(), inspection_policy.end(),
+ [this](InspectionPolicy* ip){ set_user_inspection(ip); });
+}
+
+InspectionPolicy* NetworkPolicy::get_user_inspection_policy(uint64_t user_id) const
{
auto it = user_inspection.find(user_id);
return it == user_inspection.end() ? nullptr : it->second;
bool PolicyMap::set_user_network(NetworkPolicy* p)
{
NetworkPolicy* current_np = get_user_network(p->user_policy_id);
- if (current_np && p != current_np)
- return false;
+ if (current_np)
+ return p == current_np;
user_network[p->user_policy_id] = p;
+ p->setup_inspection_policies();
return true;
}
void set_ips_policy(IpsPolicy* p)
{ s_detection_policy = p; }
-InspectionPolicy* get_user_inspection_policy(unsigned policy_id)
+InspectionPolicy* get_user_inspection_policy(uint64_t policy_id)
{
NetworkPolicy* np = get_network_policy();
assert(np);
SO_PUBLIC NetworkPolicy* get_default_network_policy(const snort::SnortConfig*);
// Based on currently set network policy
-SO_PUBLIC InspectionPolicy* get_user_inspection_policy(unsigned policy_id);
+SO_PUBLIC InspectionPolicy* get_user_inspection_policy(uint64_t policy_id);
SO_PUBLIC IpsPolicy* get_ips_policy(const snort::SnortConfig*, unsigned i = 0);
// Based on currently set network policy
{ return i < inspection_policy.size() ? inspection_policy[i] : nullptr; }
unsigned inspection_policy_count()
{ return inspection_policy.size(); }
- InspectionPolicy* get_user_inspection_policy(unsigned user_id);
+ void setup_inspection_policies();
+ InspectionPolicy* get_user_inspection_policy(uint64_t user_id) const;
void set_user_inspection(InspectionPolicy* p)
{ user_inspection[p->user_policy_id] = p; }
}
// FIXIT-P cache get_inspector() returns or provide indexed lookup
-Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, const SnortConfig* sc)
+Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, const SnortConfig* snort_config)
{
InspectionPolicy* pi;
NetworkPolicy* ni;
+ const SnortConfig* sc = snort_config;
if ( !sc )
sc = SnortConfig::get_conf();
+ assert(sc);
if ( dflt_only )
{
ni = get_default_network_policy(sc);
else
{
pi = get_inspection_policy();
+ // During reload, get_network_policy will return the network policy from the new snort config
+ // for a given tenant
ni = get_network_policy();
+ if (!snort_config)
+ {
+ // If no snort config is passed in, it means that this is either a normally running system with
+ // the correct network policy set or that get_inspector is being called from Inspector::configure
+ // and it is expecting the inspector from the running configuration and not the new snort config
+ if (ni)
+ {
+ PolicyMap* pm = sc->policy_map;
+ NetworkPolicy* np = pm->get_user_network(ni->user_policy_id);
+ if (np)
+ {
+ // If network policy is correct, then no need to change the inspection policy
+ if (np != ni && pi)
+ pi = np->get_user_inspection_policy(pi->user_policy_id);
+ ni = np;
+ }
+ else
+ pi = nullptr;
+ }
+ else
+ pi = nullptr;
+ }
}
if ( pi )
return nullptr;
}
-Inspector* InspectorManager::get_inspector(const char* key, Module::Usage usage,
- InspectorType type, const SnortConfig* sc)
+Inspector* InspectorManager::get_inspector(const char* key, Module::Usage usage, InspectorType type)
{
- if ( !sc )
- {
- sc = SnortConfig::get_conf();
- if (!sc)
- return nullptr;
- }
+ const SnortConfig* sc = SnortConfig::get_conf();
+ if (!sc)
+ return nullptr;
if (Module::GLOBAL == usage && IT_FILE == type)
{
else if (Module::CONTEXT == usage)
{
NetworkPolicy* np = get_network_policy();
+ if (!np)
+ return nullptr;
+ PolicyMap* pm = sc->policy_map;
+ np = pm->get_user_network(np->user_policy_id);
if (!np)
return nullptr;
TrafficPolicy* il = np->traffic_policy;
}
else
{
+ NetworkPolicy* orig_np = get_network_policy();
+ if (!orig_np)
+ return nullptr;
+ PolicyMap* pm = sc->policy_map;
+ NetworkPolicy* np = pm->get_user_network(orig_np->user_policy_id);
+ if (!np)
+ return nullptr;
InspectionPolicy* ip = get_inspection_policy();
if (!ip)
return nullptr;
+ // If network policy is correct, then no need to change the inspection policy
+ if (np != orig_np)
+ {
+ ip = np->get_user_inspection_policy(ip->user_policy_id);
+ if (!ip)
+ return nullptr;
+ }
FrameworkPolicy* il = ip->framework_policy;
assert(il);
PHInstance* p = il->get_instance_by_type(key, type);
static InspectSsnFunc get_session(uint16_t proto);
SO_PUBLIC static Inspector* get_file_inspector(const SnortConfig* = nullptr);
- SO_PUBLIC static Inspector* get_inspector(
- const char* key, bool dflt_only = false, const SnortConfig* = nullptr);
- SO_PUBLIC static Inspector* get_inspector(const char* key, Module::Usage, InspectorType,
- const SnortConfig* = nullptr);
+
+ // This assumes that, in a multi-tenant scenario, this is called with the correct network and inspection
+ // policies are set correctly
+ SO_PUBLIC static Inspector* get_inspector(const char* key, bool dflt_only = false, const SnortConfig* = nullptr);
+
+ // This cannot be called in or before the inspector configure phase for a new snort config during reload
+ SO_PUBLIC static Inspector* get_inspector(const char* key, Module::Usage, InspectorType);
static Inspector* get_service_inspector_by_service(const char*);
static Inspector* get_service_inspector_by_id(const SnortProtocolId);
THREAD_LOCAL const snort::Trace* snort_trace = nullptr;
std::shared_ptr<PolicyTuple> PolicyMap::get_policies(Shell*) { return nullptr; }
-NetworkPolicy* PolicyMap::get_user_network(uint64_t) const { return nullptr; }
void InspectionPolicy::configure() { }
void BinderModule::add(const char*, const char*) { }
void BinderModule::add(unsigned, const char*) { }
NetworkPolicy* snort::get_network_policy()
{ return (NetworkPolicy*)mock().getData("network_policy").getObjectPointer(); }
+NetworkPolicy* PolicyMap::get_user_network(uint64_t) const
+{ return (NetworkPolicy*)mock().getData("network_policy").getObjectPointer(); }
InspectionPolicy* snort::get_inspection_policy()
{ return (InspectionPolicy*)mock().getData("inspection_policy").getObjectPointer(); }
+InspectionPolicy* NetworkPolicy::get_user_inspection_policy(uint64_t) const
+{ return (InspectionPolicy*)mock().getData("inspection_policy").getObjectPointer(); }
InspectionPolicy::InspectionPolicy(PolicyId)
{ InspectorManager::new_policy(this, nullptr); }
default: name = nullptr; break;
}
if (name)
- default_ssn_inspectors[proto] = InspectorManager::get_inspector(name);
+ default_ssn_inspectors[proto] = InspectorManager::get_inspector(name, false, sc);
}
DataBus::subscribe(intrinsic_pub_key, IntrinsicEventIds::PKT_WITHOUT_FLOW, new NonFlowPacketHandler());
return rval;
// Verify that FTP client and FTP data inspectors are initialized.
- if(!InspectorManager::get_inspector(FTP_CLIENT_NAME, false))
+ if(!InspectorManager::get_inspector(FTP_CLIENT_NAME, false, sc))
{
ParseError("ftp_server requires that %s also be configured.", FTP_CLIENT_NAME);
return -1;
}
- if(!InspectorManager::get_inspector(FTP_DATA_NAME, false))
+ if(!InspectorManager::get_inspector(FTP_DATA_NAME, false, sc))
{
ParseError("ftp_server requires that %s also be configured.", FTP_DATA_NAME);
return -1;
return -1;
}
-int get_message_type(int version, const char* name)
+int get_message_type(int version, const char* name, snort::SnortConfig* sc)
{
- GtpInspect* ins = (GtpInspect*)InspectorManager::get_inspector(GTP_NAME);
+ GtpInspect* ins = (GtpInspect*)InspectorManager::get_inspector(GTP_NAME, false, sc);
if ( !ins )
return -1;
return -1;
}
-int get_info_type(int version, const char* name)
+int get_info_type(int version, const char* name, SnortConfig* sc)
{
- GtpInspect* ins = (GtpInspect*)InspectorManager::get_inspector(GTP_NAME);
+ GtpInspect* ins = (GtpInspect*)InspectorManager::get_inspector(GTP_NAME, false, sc);
if ( !ins )
return -1;
GTP_Roptions ropts;
};
-int get_message_type(int version, const char* name);
-int get_info_type(int version, const char* name);
+namespace snort
+{
+struct SnortConfig;
+}
+
+int get_message_type(int version, const char* name, snort::SnortConfig*);
+int get_info_type(int version, const char* name, snort::SnortConfig*);
struct GTP_IEData* get_infos();
bool set(const char*, Value&, SnortConfig*) override;
bool set_types(long);
- bool set_types(const char*);
+ bool set_types(const char*, SnortConfig*);
ProfileStats* get_profile() const override
{ return >p_info_prof; }
return true;
}
-bool GtpInfoModule::set_types(const char* name)
+bool GtpInfoModule::set_types(const char* name, SnortConfig* sc)
{
bool ok = false;
for ( int v = 0; v <= MAX_GTP_VERSION_CODE; ++v )
{
- int t = get_info_type(v, name);
+ int t = get_info_type(v, name, sc);
if ( t < 0 )
continue;
return ok;
}
-bool GtpInfoModule::set(const char*, Value& v, SnortConfig*)
+bool GtpInfoModule::set(const char*, Value& v, SnortConfig* sc)
{
assert(v.is("~"));
long n;
if ( v.strtol(n) )
return set_types(n);
- return set_types(v.get_string());
+ return set_types(v.get_string(), sc);
}
//-------------------------------------------------------------------------
bool set(const char*, Value&, SnortConfig*) override;
bool set_types(long);
- bool set_types(const char*);
+ bool set_types(const char*, SnortConfig*);
ProfileStats* get_profile() const override
{ return >p_type_prof; }
return true;
}
-bool GtpTypeModule::set_types(const char* name)
+bool GtpTypeModule::set_types(const char* name, SnortConfig* sc)
{
bool ok = false;
for ( int v = 0; v <= MAX_GTP_VERSION_CODE; ++v )
{
- int t = get_message_type(v, name);
+ int t = get_message_type(v, name, sc);
if ( t < 0 )
continue;
return ok;
}
-bool GtpTypeModule::set(const char*, Value& v, SnortConfig*)
+bool GtpTypeModule::set(const char*, Value& v, SnortConfig* sc)
{
assert(v.is("~"));
v.set_first_token();
if ( !set_types(n) )
return false;
}
- else if ( !set_types(tok.c_str()) )
+ else if ( !set_types(tok.c_str(), sc) )
return false;
}
return true;