bool setup(SnortConfig*) override;
void cleanup(SnortConfig*) override;
-
+ void update(SnortConfig*) override
+ { }
bool allocate(hs_database_t*);
hs_scratch_t* get()
virtual bool setup(SnortConfig*) = 0;
virtual void cleanup(SnortConfig*) = 0;
+ virtual void update(SnortConfig*) = 0;
int get_id() { return id; }
typedef bool (* ScratchSetup)(SnortConfig*);
typedef void (* ScratchCleanup)(SnortConfig*);
+typedef void (* ScratchUpdate)(SnortConfig*);
class SO_PUBLIC SimpleScratchAllocator : public ScratchAllocator
{
public:
- SimpleScratchAllocator(ScratchSetup fs, ScratchCleanup fc) : fsetup(fs), fcleanup(fc) { }
+ SimpleScratchAllocator(ScratchSetup fs, ScratchCleanup fc, ScratchUpdate fu = nullptr)
+ : fsetup(fs), fcleanup(fc), fupdate(fu) { }
bool setup(SnortConfig* sc) override
{ return fsetup(sc); }
void cleanup(SnortConfig* sc) override
{ fcleanup(sc); }
+ void update(SnortConfig* sc) override
+ {
+ if (fupdate)
+ fupdate(sc);
+ }
+
private:
ScratchSetup fsetup;
ScratchCleanup fcleanup;
+ ScratchUpdate fupdate;
};
}
LogMessage("== daq module reload complete\n");
}
+bool ACScratchUpdate::execute(Analyzer&, void**)
+{
+ for ( auto* s : handlers )
+ {
+ s->update(sc);
+ }
+ return true;
+}
+
+ACScratchUpdate::~ACScratchUpdate()
+{
+ LogMessage("== scratch update complete\n");
+ ReloadTracker::end(ctrlcon);
+}
+
SFDAQInstance* AnalyzerCommand::get_daq_instance(Analyzer& analyzer)
{
return analyzer.get_daq_instance();
namespace snort
{
+class ScratchAllocator;
+struct SnortConfig;
+
class SFDAQInstance;
class AnalyzerCommand
~ACDAQSwap() override;
};
+class ACScratchUpdate : public snort::AnalyzerCommand
+{
+public:
+ ACScratchUpdate(snort::SnortConfig* sc, std::vector<snort::ScratchAllocator*>& handlers,
+ ControlConn* conn) : AnalyzerCommand(conn), sc(sc), handlers(handlers)
+ { }
+ bool execute(Analyzer&, void**) override;
+ const char* stringify() override { return "SCRATCH_UPDATE"; }
+ ~ACScratchUpdate() override;
+private:
+ snort::SnortConfig* sc;
+ std::vector<snort::ScratchAllocator*>& handlers;
+};
+
namespace snort
{
// from main.cc
}
}
+void SnortConfig::update_scratch(ControlConn* ctrlcon)
+{
+ main_broadcast_command(new ACScratchUpdate(this, scratch_handlers, ctrlcon));
+}
+
void SnortConfig::clone(const SnortConfig* const conf)
{
*this = *conf;
};
class ConfigOutput;
+class ControlConn;
class FastPatternConfig;
class RuleStateMap;
class TraceConfig;
void setup();
void post_setup();
+ void update_scratch(ControlConn*);
bool verify() const;
void merge(const SnortConfig*);
file_path = std::string(ctxt.config.app_detector_dir) + "/../userappid.conf";
ctxt.get_odp_ctxt().get_app_info_mgr().dump_appid_configurations(file_path);
}
- ReloadTracker::end(ctrlcon);
+ SnortConfig::get_main_conf()->update_scratch(ctrlcon);
log_message("== reload detectors complete\n");
}
hs_scratch_t** ss = (hs_scratch_t**) &sc->state[i][scratch_index];
hs_clone_scratch(max, ss);
}
- hs_free_scratch(max);
+ s_scratch[get_instance_id()] = max;
return true;
}
}
}
+static bool need_update(SnortConfig* sc)
+{
+ if ( s_scratch.size() )
+ {
+ size_t max_sz, instance_sz;
+ hs_scratch_size(s_scratch[0], &max_sz);
+ hs_scratch_size((hs_scratch_t*)sc->state[get_instance_id()][scratch_index], &instance_sz);
+ if ( max_sz > instance_sz )
+ return true;
+ else
+ return false;
+ }
+ else
+ return false;
+}
+
+void static scratch_update(SnortConfig* sc)
+{
+ if ( !need_update(sc) )
+ return;
+
+ hs_scratch_t** ss = (hs_scratch_t**) &sc->state[get_instance_id()][scratch_index];
+ hs_free_scratch(*ss);
+ *ss = nullptr;
+ hs_clone_scratch(s_scratch[0], ss);
+}
+
class HyperscanModule : public Module
{
public:
HyperscanModule() : Module(s_name, s_help)
{
- scratcher = new SimpleScratchAllocator(scratch_setup, scratch_cleanup);
+ scratcher = new SimpleScratchAllocator(scratch_setup, scratch_cleanup, scratch_update);
scratch_index = scratcher->get_id();
}
~HyperscanModule() override
- { delete scratcher; }
+ {
+ delete scratcher;
+
+ for ( auto& ss : s_scratch )
+ {
+ if ( ss )
+ {
+ hs_free_scratch(ss);
+ ss = nullptr;
+ }
+ }
+ }
};
//-------------------------------------------------------------------------
#endif
#include <cstring>
+#include <hs_runtime.h>
#include "framework/base_api.h"
#include "framework/counts.h"
CHECK(hits == 1);
}
+//-------------------------------------------------------------------------
+// scratch update test
+//-------------------------------------------------------------------------
+
+TEST_GROUP(mpse_hs_scratch)
+{
+ Module* mod = nullptr;
+ Mpse* hs1 = nullptr;
+ Mpse* hs2 = nullptr;
+ const MpseApi* mpse_api = (const MpseApi*)se_hyperscan;
+
+ void setup() override
+ {
+ CHECK(se_hyperscan);
+ mod = mpse_api->base.mod_ctor();
+ hs1 = mpse_api->ctor(snort_conf, nullptr, nullptr);
+ hs2 = mpse_api->ctor(snort_conf, nullptr, nullptr);
+ CHECK(hs1);
+ CHECK(hs2);
+ }
+ void teardown() override
+ {
+ mpse_api->dtor(hs1);
+ mpse_api->dtor(hs2);
+ scratcher->cleanup(snort_conf);
+ mpse_api->base.mod_dtor(mod);
+ }
+};
+
+TEST(mpse_hs_scratch, scratch_update)
+{
+ Mpse::PatternDescriptor desc;
+ CHECK(hs1->add_pattern((const uint8_t*)"foo", 3, desc, s_user) == 0);
+ CHECK(hs1->prep_patterns(snort_conf) == 0);
+ CHECK(hs1->get_pattern_count() == 1);
+
+ scratcher->setup(snort_conf);
+
+ size_t instance_sz1;
+ hs_scratch_size((hs_scratch_t*)snort_conf->state[0][0], &instance_sz1);
+
+ for (unsigned i = 0; i < 30; i++)
+ CHECK(hs2->add_pattern((const uint8_t*)"bar", 3, desc, s_user) == 0);
+ CHECK(hs2->prep_patterns(snort_conf) == 0);
+ CHECK(hs2->get_pattern_count() == 30);
+
+ scratcher->update(snort_conf);
+
+ size_t instance_sz2;
+ hs_scratch_size((hs_scratch_t*)snort_conf->state[0][0], &instance_sz2);
+ CHECK(instance_sz2 > instance_sz1);
+}
+
//-------------------------------------------------------------------------
// main
//-------------------------------------------------------------------------