#include "log/messages.h"
#include "main/snort.h"
#include "main/snort_config.h"
+#include "managers/module_manager.h"
#include "packet_io/active.h"
#include "trace/trace.h"
else if ( v.is("decompress_buffer_size") )
FileService::decode_conf.set_decompress_buffer_size(v.get_uint32());
+
else if ( v.is("rules_file") )
{
- std::string s = "include ";
- s += v.get_string();
- parser_append_rules_special(s.c_str());
+ magic_file = "include ";
+ magic_file += v.get_string();
}
return true;
}
+bool FileIdModule::end(const char*, int, SnortConfig*)
+{
+ const char* inc = ModuleManager::get_includer("file_id");
+ parser_append_rules_special(magic_file.c_str(), inc);
+ return true;
+}
+
void FileIdModule::load_config(FileConfig*& dst)
{
dst = fc;
#ifndef FILE_MODULE_H
#define FILE_MODULE_H
+#include <string>
+
#include "framework/module.h"
#include "file_config.h"
~FileIdModule() override;
bool set(const char*, snort::Value&, snort::SnortConfig*) override;
+ bool end(const char*, int, snort::SnortConfig*) override;
snort::ProfileStats* get_profile() const override;
const PegInfo* get_pegs() const override;
private:
FileMeta rule;
FileConfig *fc = nullptr;
+ std::string magic_file;
};
enum FileSid
dofile(fname)
end
- local iname = path_top()
- if ( (ips ~= nil) and (ips.includer == nil) and (iname ~= nil) ) then
- ips.includer = iname
+ if ( (ips ~= nil) and (ips.includer == nil) ) then
+ ips.includer = fname
+ end
+
+ if ( file_id ~= nil and file_id.includer == nil ) then
+ file_id.includer = fname
end
path_pop()
bool set_bool(const char*, bool);
bool set_number(const char*, double);
bool set_string(const char*, const char*);
+bool set_includer(const char*, const char*);
bool set_alias(const char*, const char*);
void clear_alias();
]]
ffi.C.set_number(name, val)
elseif ( what == 'string' ) then
- ffi.C.set_string(name, val)
+ if ( key == "includer" ) then
+ ffi.C.set_includer(name, val)
+ else
+ ffi.C.set_string(name, val)
+ end
elseif ( what == 'table' ) then
if ( ffi.C.open_table(name, idx) ) then
{ "include", Parameter::PT_STRING, nullptr, nullptr,
"snort rules and includes" },
- { "includer", Parameter::PT_STRING, "(optional)", nullptr,
- "for internal use; where includes are included from" },
-
// FIXIT-L no default; it breaks initialization by -Q
{ "mode", Parameter::PT_ENUM, "tap | inline | inline-test", nullptr,
"set policy mode" },
else if ( v.is("include") )
p->include = v.get_string();
- else if ( v.is("includer") )
- p->includer = v.get_string();
-
else if ( v.is("mode") )
p->policy_mode = (PolicyMode)v.get_uint8();
else if (!idx and !strcmp(fqn, "ips"))
{
IpsPolicy* p = get_ips_policy();
+ p->includer = ModuleManager::get_includer("ips");
sc->policy_map->set_user_ips(p);
}
return true;
static string s_current;
static string s_aliased_name;
static string s_aliased_type;
+static string s_ips_includer;
+static string s_file_id_includer;
// for callbacks from Lua
static SnortConfig* s_config = nullptr;
bool set_alias(const char* from, const char* to);
void clear_alias();
- const char* push_include_path(const char* file);
+ bool set_includer(const char* fqn, const char* val);
+ const char* push_include_path(const char*);
void pop_include_path();
+
void snort_whitelist_append(const char*);
void snort_whitelist_add_prefix(const char*);
}
pop_parse_location();
}
+// cppcheck-suppress unusedFunction
+SO_PUBLIC bool set_includer(const char* fqn, const char* s)
+{
+ if ( !strcmp(fqn, "ips.includer") )
+ s_ips_includer = s;
+ else
+ {
+ assert(!strcmp(fqn, "file_id.includer"));
+ s_file_id_includer = s;
+ }
+ return true;
+}
+
//-------------------------------------------------------------------------
// ffi methods - also called internally so no cppcheck suppressions
//-------------------------------------------------------------------------
unsigned ModuleManager::get_errors()
{ return s_errors; }
+const char* ModuleManager::get_includer(const char* mod)
+{
+ assert(!strcmp(mod, "ips") or !strcmp(mod, "file_id"));
+
+ if ( !strcmp(mod, "ips") )
+ return s_ips_includer.c_str();
+
+ return s_file_id_includer.c_str();
+}
+
void ModuleManager::list_modules(const char* s)
{
PlugType pt = s ? PluginManager::get_type(s) : PT_MAX;
SO_PUBLIC static std::list<Module*> get_all_modules();
static const char* get_lua_coreinit();
+ static const char* get_includer(const char* module);
static void list_modules(const char* = nullptr);
static void dump_modules();
static std::stack<Location> files;
static int rules_file_depth = 0;
+static bool s_ips_policy = true;
const char* get_parse_file()
{
{
assert(arg);
std::string conf = ExpandVars(arg);
- std::string file = !rules_file_depth ? get_ips_policy()->includer : get_parse_file();
+ std::string file;
+
+ if ( rules_file_depth )
+ file = get_parse_file();
+
+ else if ( s_ips_policy )
+ file = get_ips_policy()->includer;
+
+ else
+ file = parser_get_special_includer();
const char* code = get_config_file(conf.c_str(), file);
--rules_file_depth;
}
-void parse_rules_string(SnortConfig* sc, const char* s)
+void parse_rules_string(SnortConfig* sc, const char* s, bool ips_policy)
{
+ s_ips_policy = ips_policy;
std::string rules = s;
std::stringstream ss(rules);
parse_stream(ss, sc);
+ s_ips_policy = true;
}
const char* get_config_file(const char* arg, std::string& file);
void parse_rules_file(snort::SnortConfig*, const char* fname);
-void parse_rules_string(snort::SnortConfig*, const char* str);
+void parse_rules_string(snort::SnortConfig*, const char* str, bool ips_policy = true);
void ParseIpVar(const char* name, const char* value);
void parse_include(snort::SnortConfig*, const char*);
static std::string s_aux_rules;
static std::string s_special_rules;
+static std::string s_special_includer;
class RuleTreeHashKeyOps : public HashKeyOperations
{
if (!idx and !s_special_rules.empty())
{
- push_parse_location("W", "./", "rule args");
- parse_rules_string(sc, s_special_rules.c_str());
+ push_parse_location("W", "./", "file_id.rules_file");
+ parse_rules_string(sc, s_special_rules.c_str(), false);
pop_parse_location();
s_special_rules.clear();
}
s_aux_rules += "\n";
}
-void parser_append_rules_special(const char *s)
+void parser_append_rules_special(const char *s, const char* inc)
{
s_special_rules += s;
s_special_rules += "\n";
+ s_special_includer = inc;
}
+const char* parser_get_special_includer()
+{ return s_special_includer.c_str(); }
+
void parser_append_includes(const char* d)
{
Directory dir(d);
void parser_append_rules(const char*);
void parser_append_includes(const char*);
-void parser_append_rules_special(const char *);
+void parser_append_rules_special(const char* file, const char* includer);
+const char* parser_get_special_includer();
int ParseBool(const char* arg);