From: Brandon Stultz (brastult) Date: Fri, 24 Oct 2025 09:00:55 +0000 (+0000) Subject: Pull request #4833: snort_ml: add mpse and lru cache X-Git-Tag: 3.9.7.0~14 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=72b83b5a93074f769fe60aca490809a7e0d48059;p=thirdparty%2Fsnort3.git Pull request #4833: snort_ml: add mpse and lru cache Merge in SNORT/snort3 from ~BRASTULT/snort3:snort_ml_pipeline to master Squashed commit of the following: commit 1f51dd1bee92a4995d960561b59a72e1a8903b53 Author: Brandon Stultz Date: Fri Jul 25 13:46:00 2025 -0400 build: only enable libml for supported versions commit 47a789fc3b637f95b11ba0b154af53440ed5b2f2 Author: Brandon Stultz Date: Fri Jul 25 13:32:01 2025 -0400 snort_ml: add mpse and lru cache commit 7c74729080cc2f1095dbbeee8e98bbbda00accf9 Author: Brandon Stultz Date: Fri Sep 5 17:00:03 2025 -0400 hash: add FNV-1a hash --- diff --git a/cmake/FindML.cmake b/cmake/FindML.cmake index c30553c95..15943ffad 100644 --- a/cmake/FindML.cmake +++ b/cmake/FindML.cmake @@ -3,12 +3,14 @@ pkg_check_modules(PC_ML libml_static>=2.0.0) find_path(ML_INCLUDE_DIRS libml.h - HINTS ${ML_INCLUDE_DIR_HINT} ${PC_ML_INCLUDEDIR} + PATHS ${ML_INCLUDE_DIR_HINT} ${PC_ML_INCLUDEDIR} + NO_DEFAULT_PATH ) find_library(ML_LIBRARIES NAMES ml_static - HINTS ${ML_LIBRARIES_DIR_HINT} ${PC_ML_LIBDIR} + PATHS ${ML_LIBRARIES_DIR_HINT} ${PC_ML_LIBDIR} + NO_DEFAULT_PATH ) include(FindPackageHandleStandardArgs) diff --git a/src/hash/fnv.h b/src/hash/fnv.h new file mode 100644 index 000000000..48c85916e --- /dev/null +++ b/src/hash/fnv.h @@ -0,0 +1,41 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2025-2025 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// fnv.h author Brandon Stultz +// based on https://datatracker.ietf.org/doc/html/draft-eastlake-fnv + +#ifndef FNV_H +#define FNV_H + +#define FNV_PRIME 0x00000100000001B3 +#define FNV_BASIS 0xCBF29CE484222325 + +//-------------------------------------------------------------------------- +// FNV-1a Hash +//-------------------------------------------------------------------------- + +inline uint64_t fnv1a(const char* buf, const size_t len) +{ + uint64_t result = FNV_BASIS; + + for (size_t i = 0; i < len; i++) + result = (result ^ static_cast(buf[i])) * FNV_PRIME; + + return result; +} + +#endif diff --git a/src/network_inspectors/snort_ml/snort_ml_engine.cc b/src/network_inspectors/snort_ml/snort_ml_engine.cc index 6b3c63c96..8d69174a6 100644 --- a/src/network_inspectors/snort_ml/snort_ml_engine.cc +++ b/src/network_inspectors/snort_ml/snort_ml_engine.cc @@ -25,13 +25,11 @@ #include "snort_ml_engine.h" #include +#include #include -#ifdef HAVE_LIBML -#include -#endif - #include "framework/decode_data.h" +#include "hash/fnv.h" #include "helpers/directory.h" #include "log/messages.h" #include "main/reload_tuner.h" @@ -43,36 +41,110 @@ using namespace snort; using namespace std; -static THREAD_LOCAL libml::BinaryClassifierSet* classifiers = nullptr; +static THREAD_LOCAL SnortMLEngineStats snort_ml_engine_stats; +static THREAD_LOCAL SnortMLContext* snort_ml_ctx = nullptr; -static bool build_classifiers(const vector& models, - libml::BinaryClassifierSet*& set) +static SnortMLContext* create_context(const SnortMLEngineConfig& conf) { - set = new libml::BinaryClassifierSet(); + SnortMLContext* ctx = new SnortMLContext(); + + if (!ctx->classifiers.build(conf.http_param_models)) + { + ErrorMessage("Could not build classifiers.\n"); + return ctx; + } + + if (conf.cache_memcap > 0) + { + ctx->cache = make_unique(conf.cache_memcap, + snort_ml_engine_stats); + } - return set->build(models); + return ctx; } //-------------------------------------------------------------------------- // module //-------------------------------------------------------------------------- +static const Parameter filter_param[] = +{ + { "filter_pattern", Parameter::PT_STRING, nullptr, nullptr, + "pattern that triggers ML classification" }, + { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } +}; + +static const Parameter ignore_param[] = +{ + { "ignore_pattern", Parameter::PT_STRING, nullptr, nullptr, + "pattern that skips ML classification" }, + { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } +}; + static const Parameter snort_ml_engine_params[] = { - { "http_param_model", Parameter::PT_STRING, nullptr, nullptr, "path to model file(s)" }, + { "http_param_model", Parameter::PT_STRING, nullptr, nullptr, + "path to model file(s)" }, + + { "http_param_filter", Parameter::PT_LIST, filter_param, nullptr, + "list of patterns that trigger ML classification" }, + + { "http_param_ignore", Parameter::PT_LIST, ignore_param, nullptr, + "list of patterns that skip ML classification" }, + + { "cache_memcap", Parameter::PT_INT, "0:maxSZ", "0", + "maximum memory for verdict cache in bytes, 0 = disabled" }, + { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } }; -SnortMLEngineModule::SnortMLEngineModule() : Module(SNORT_ML_ENGINE_NAME, SNORT_ML_ENGINE_HELP, snort_ml_engine_params) {} +static const PegInfo peg_names[] = +{ + LRU_CACHE_LOCAL_PEGS("snort_ml_engine"), + { CountType::SUM, "filter_searches", "total filter searches" }, + { CountType::SUM, "filter_matches", "total filter matches" }, + { CountType::SUM, "filter_allows", "total filter allows" }, + { CountType::SUM, "libml_calls", "total libml calls" }, + { CountType::END, nullptr, nullptr } +}; + +SnortMLEngineModule::SnortMLEngineModule() + : Module(SNORT_ML_ENGINE_NAME, SNORT_ML_ENGINE_HELP, snort_ml_engine_params) {} + +bool SnortMLEngineModule::begin(const char* fqn, int, SnortConfig*) +{ + if (!strcmp(SNORT_ML_ENGINE_NAME, fqn)) + conf = {}; + + return true; +} bool SnortMLEngineModule::set(const char*, Value& v, SnortConfig*) { if (v.is("http_param_model")) conf.http_param_model_path = v.get_string(); + else if (v.is("filter_pattern")) + conf.http_param_filters[v.get_string()] = true; + + else if (v.is("ignore_pattern")) + { + conf.http_param_filters[v.get_string()] = false; + conf.has_allow = true; + } + + else if (v.is("cache_memcap")) + conf.cache_memcap = v.get_size(); + return true; } +const PegInfo* SnortMLEngineModule::get_pegs() const +{ return peg_names; } + +PegCount* SnortMLEngineModule::get_counts() const +{ return reinterpret_cast(&snort_ml_engine_stats); } + //-------------------------------------------------------------------------- // reload tuner //-------------------------------------------------------------------------- @@ -80,9 +152,7 @@ bool SnortMLEngineModule::set(const char*, Value& v, SnortConfig*) class SnortMLReloadTuner : public snort::ReloadResourceTuner { public: - explicit SnortMLReloadTuner(const vector& models) - : http_param_models(models) {} - + explicit SnortMLReloadTuner(const SnortMLEngineConfig& c) : conf(c) {} ~SnortMLReloadTuner() override = default; const char* name() const override @@ -90,11 +160,8 @@ public: bool tinit() override { - delete classifiers; - - if (!build_classifiers(http_param_models, classifiers)) - ErrorMessage("Could not build classifiers.\n"); - + delete snort_ml_ctx; + snort_ml_ctx = create_context(conf); return false; } @@ -105,25 +172,48 @@ public: { return true; } private: - const vector& http_param_models; + const SnortMLEngineConfig& conf; }; //-------------------------------------------------------------------------- // inspector //-------------------------------------------------------------------------- -SnortMLEngine::SnortMLEngine(const SnortMLEngineConfig& c) : config(c) +bool SnortMLEngine::configure(SnortConfig*) { - if (!read_models() || !validate_models()) + if (!read_models()) + return false; + + libml::BinaryClassifierSet classifiers; + + if (!classifiers.build(conf.http_param_models)) + { ParseError("Could not build classifiers."); + return false; + } + + if (!conf.http_param_filters.empty()) + { + mpse = new SearchTool; + + for (auto& f : conf.http_param_filters) + mpse->add(f.first.c_str(), f.first.size(), (void*)&f); + + mpse->prep(); + } + + return true; } void SnortMLEngine::show(const SnortConfig*) const -{ ConfigLogger::log_value("http_param_model", config.http_param_model_path.c_str()); } +{ + ConfigLogger::log_value("http_param_model", conf.http_param_model_path.c_str()); + ConfigLogger::log_value("cache_memcap", conf.cache_memcap); +} bool SnortMLEngine::read_models() { - const char* hint = config.http_param_model_path.c_str(); + const char* hint = conf.http_param_model_path.c_str(); string path; if (!get_config_file(hint, path)) @@ -160,7 +250,13 @@ bool SnortMLEngine::read_models() } } - return !http_param_models.empty(); + if (conf.http_param_models.empty()) + { + ParseError("snort_ml_engine: no models found"); + return false; + } + + return true; } bool SnortMLEngine::read_model(const string& path) @@ -178,33 +274,80 @@ bool SnortMLEngine::read_model(const string& path) string buffer(size, '\0'); file.read(&buffer[0], streamsize(size)); - http_param_models.push_back(move(buffer)); + conf.http_param_models.push_back(std::move(buffer)); return true; } -bool SnortMLEngine::validate_models() -{ - libml::BinaryClassifierSet* set = nullptr; - bool res = build_classifiers(http_param_models, set); - delete set; - - return res; -} - void SnortMLEngine::tinit() -{ build_classifiers(http_param_models, classifiers); } +{ snort_ml_ctx = create_context(conf); } void SnortMLEngine::tterm() { - delete classifiers; - classifiers = nullptr; + delete snort_ml_ctx; + snort_ml_ctx = nullptr; } void SnortMLEngine::install_reload_handler(SnortConfig* sc) -{ sc->register_reload_handler(new SnortMLReloadTuner(http_param_models)); } +{ sc->register_reload_handler(new SnortMLReloadTuner(conf)); } + +static int filter_match_callback(void* f, void*, int, void* s, void*) +{ + auto filter = reinterpret_cast*>(f); + auto search = reinterpret_cast(s); + + search->match = true; + search->allow |= !filter->second; + + if (search->has_allow && !search->allow) + return 0; + + return 1; +} + +bool SnortMLEngine::scan(const char* buf, const size_t len, float& out) const +{ + if (!snort_ml_ctx) + return false; + + if (mpse) + { + snort_ml_engine_stats.filter_searches++; + + SnortMLSearch search; + search.has_allow = conf.has_allow; + + mpse->find_all(buf, len, filter_match_callback, + false, (void*)&search); -libml::BinaryClassifierSet* SnortMLEngine::get_classifiers() -{ return classifiers; } + if (!search.match) + return false; + + snort_ml_engine_stats.filter_matches++; + + if (search.allow) + { + snort_ml_engine_stats.filter_allows++; + return false; + } + } + + float res = 0; + bool is_new = true; + + float& result = (snort_ml_ctx->cache) ? + snort_ml_ctx->cache->find_else_create(fnv1a(buf, len), &is_new) : res; + + if (is_new) + { + snort_ml_engine_stats.libml_calls++; + + if (!snort_ml_ctx->classifiers.run(buf, len, result)) + return false; + } + + out = result; + return true; +} //-------------------------------------------------------------------------- // api stuff @@ -218,7 +361,7 @@ static void mod_dtor(Module* m) static Inspector* snort_ml_engine_ctor(Module* m) { - SnortMLEngineModule* mod = (SnortMLEngineModule*)m; + SnortMLEngineModule* mod = reinterpret_cast(m); return new SnortMLEngine(mod->get_config()); } @@ -274,8 +417,9 @@ const BaseApi* nin_snort_ml_engine[] = TEST_CASE("SnortML tuner name", "[snort_ml_module]") { - const vector models = { "model" }; - SnortMLReloadTuner tuner(models); + SnortMLEngineConfig conf; + conf.http_param_models = { "model" }; + SnortMLReloadTuner tuner(conf); REQUIRE(strcmp(tuner.name(), "SnortMLReloadTuner") == 0); } diff --git a/src/network_inspectors/snort_ml/snort_ml_engine.h b/src/network_inspectors/snort_ml/snort_ml_engine.h index 0d77d1664..beeae3389 100644 --- a/src/network_inspectors/snort_ml/snort_ml_engine.h +++ b/src/network_inspectors/snort_ml/snort_ml_engine.h @@ -21,20 +21,83 @@ #ifndef SNORT_ML_ENGINE_H #define SNORT_ML_ENGINE_H -#include "framework/module.h" +#ifdef HAVE_LIBML +#include +#endif + +#include +#include +#include + #include "framework/inspector.h" +#include "framework/module.h" +#include "hash/lru_cache_local.h" +#include "search_engines/search_tool.h" #define SNORT_ML_ENGINE_NAME "snort_ml_engine" #define SNORT_ML_ENGINE_HELP "configure machine learning engine settings" +// Mock BinaryClassifierSet for tests if LibML is absent +#ifndef HAVE_LIBML namespace libml { - class BinaryClassifierSet; + +class BinaryClassifierSet +{ +public: + bool build(const std::vector& models) + { + if (!models.empty()) + pattern = models[0]; + + return pattern != "error"; + } + + bool run(const char* ptr, size_t len, float& out) + { + std::string data(ptr, len); + out = data.find(pattern) == std::string::npos ? 0.0f : 1.0f; + return pattern != "fail"; + } + +private: + std::string pattern; +}; + } +#endif + +struct SnortMLEngineStats : public LruCacheLocalStats +{ + PegCount filter_searches; + PegCount filter_matches; + PegCount filter_allows; + PegCount libml_calls; +}; + +typedef LruCacheLocal> SnortMLCache; +typedef std::unordered_map SnortMLFilterMap; + +struct SnortMLContext +{ + libml::BinaryClassifierSet classifiers; + std::unique_ptr cache; +}; struct SnortMLEngineConfig { std::string http_param_model_path; + std::vector http_param_models; + SnortMLFilterMap http_param_filters; + bool has_allow = false; + size_t cache_memcap = 0; +}; + +struct SnortMLSearch +{ + bool match = false; + bool allow = false; + bool has_allow = false; }; class SnortMLEngineModule : public snort::Module @@ -42,13 +105,21 @@ class SnortMLEngineModule : public snort::Module public: SnortMLEngineModule(); + bool begin(const char*, int, snort::SnortConfig*) override; bool set(const char*, snort::Value&, snort::SnortConfig*) override; + const PegInfo* get_pegs() const override; + PegCount* get_counts() const override; + Usage get_usage() const override { return GLOBAL; } - const SnortMLEngineConfig& get_config() - { return conf; } + SnortMLEngineConfig get_config() + { + SnortMLEngineConfig out; + std::swap(conf, out); + return out; + } private: SnortMLEngineConfig conf; @@ -57,58 +128,26 @@ private: class SnortMLEngine : public snort::Inspector { public: - SnortMLEngine(const SnortMLEngineConfig&); + SnortMLEngine(SnortMLEngineConfig c) : conf(std::move(c)) {} + ~SnortMLEngine() override + { delete mpse; } + bool configure(snort::SnortConfig*) override; void show(const snort::SnortConfig*) const override; - void eval(snort::Packet*) override {} void tinit() override; void tterm() override; void install_reload_handler(snort::SnortConfig*) override; - static libml::BinaryClassifierSet* get_classifiers(); + bool scan(const char*, const size_t, float&) const; private: bool read_models(); bool read_model(const std::string&); - bool validate_models(); - - SnortMLEngineConfig config; - std::vector http_param_models; -}; - -// Mock BinaryClassifierSet for tests if LibML is absent. -// The code below won't be executed if REG_TEST is undefined. -// Check the plugin type provided in the snort_ml_engine.cc file. -#ifndef HAVE_LIBML -namespace libml -{ - -class BinaryClassifierSet -{ -public: - bool build(const std::vector& models) - { - if (!models.empty()) - pattern = models[0]; - - return pattern != "error"; - } - - bool run(const char* ptr, size_t len, float& out) - { - std::string data(ptr, len); - out = data.find(pattern) == std::string::npos ? 0.0f : 1.0f; - return pattern != "fail"; - } - -private: - std::string pattern; + SnortMLEngineConfig conf; + snort::SearchTool* mpse = nullptr; }; -} -#endif - #endif diff --git a/src/network_inspectors/snort_ml/snort_ml_inspector.cc b/src/network_inspectors/snort_ml/snort_ml_inspector.cc index 9d27a5817..64a9f62a7 100644 --- a/src/network_inspectors/snort_ml/snort_ml_inspector.cc +++ b/src/network_inspectors/snort_ml/snort_ml_inspector.cc @@ -25,10 +25,6 @@ #include -#ifdef HAVE_LIBML -#include -#endif - #include "detection/detection_engine.h" #include "log/messages.h" #include "managers/inspector_manager.h" @@ -51,13 +47,14 @@ THREAD_LOCAL ProfileStats snort_ml_prof; class HttpBodyHandler : public DataHandler { public: - HttpBodyHandler(SnortML& ml) - : DataHandler(SNORT_ML_NAME), inspector(ml) {} + HttpBodyHandler(const SnortMLEngine& eng, const SnortML& ins) + : DataHandler(SNORT_ML_NAME), engine(eng), inspector(ins) {} void handle(DataEvent& de, Flow*) override; private: - SnortML& inspector; + const SnortMLEngine& engine; + const SnortML& inspector; }; void HttpBodyHandler::handle(DataEvent& de, Flow*) @@ -65,9 +62,7 @@ void HttpBodyHandler::handle(DataEvent& de, Flow*) // cppcheck-suppress unreadVariable Profile profile(snort_ml_prof); - libml::BinaryClassifierSet* classifiers = SnortMLEngine::get_classifiers(); - SnortMLConfig config = inspector.get_config(); - HttpRequestBodyEvent* he = (HttpRequestBodyEvent*)&de; + HttpRequestBodyEvent* he = reinterpret_cast(&de); if (he->is_mime()) return; @@ -78,23 +73,23 @@ void HttpBodyHandler::handle(DataEvent& de, Flow*) if (!body || body_len <= 0) return; - const size_t len = std::min((size_t)config.client_body_depth, (size_t)body_len); - - assert(classifiers); - - float output = 0.0; + const SnortMLConfig& conf = inspector.get_config(); - snort_ml_stats.libml_calls++; + const size_t len = std::min((size_t)conf.client_body_depth, (size_t)body_len); - if (!classifiers->run(body, len, output)) + float output = 0; + if (!engine.scan(body, len, output)) return; snort_ml_stats.client_body_bytes += len; - debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "input (body): %.*s\n", (int)len, body); - debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast(output)); + debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, + "input (body): %.*s\n", (int)len, body); - if ((double)output > config.http_param_threshold) + debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, + "output: %f\n", static_cast(output)); + + if ((double)output > conf.http_param_threshold) { snort_ml_stats.client_body_alerts++; debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "\n"); @@ -109,13 +104,14 @@ void HttpBodyHandler::handle(DataEvent& de, Flow*) class HttpUriHandler : public DataHandler { public: - HttpUriHandler(SnortML& ml) - : DataHandler(SNORT_ML_NAME), inspector(ml) {} + HttpUriHandler(const SnortMLEngine& eng, const SnortML& ins) + : DataHandler(SNORT_ML_NAME), engine(eng), inspector(ins) {} void handle(DataEvent&, Flow*) override; private: - SnortML& inspector; + const SnortMLEngine& engine; + const SnortML& inspector; }; void HttpUriHandler::handle(DataEvent& de, Flow*) @@ -123,9 +119,7 @@ void HttpUriHandler::handle(DataEvent& de, Flow*) // cppcheck-suppress unreadVariable Profile profile(snort_ml_prof); - libml::BinaryClassifierSet* classifiers = SnortMLEngine::get_classifiers(); - SnortMLConfig config = inspector.get_config(); - HttpEvent* he = (HttpEvent*)&de; + HttpEvent* he = reinterpret_cast(&de); int32_t query_len = 0; const char* query = (const char*)he->get_uri_query(query_len); @@ -133,23 +127,23 @@ void HttpUriHandler::handle(DataEvent& de, Flow*) if (!query || query_len <= 0) return; - const size_t len = std::min((size_t)config.uri_depth, (size_t)query_len); - - assert(classifiers); + const SnortMLConfig& conf = inspector.get_config(); - float output = 0.0; + const size_t len = std::min((size_t)conf.uri_depth, (size_t)query_len); - snort_ml_stats.libml_calls++; - - if (!classifiers->run(query, len, output)) + float output = 0; + if (!engine.scan(query, len, output)) return; snort_ml_stats.uri_bytes += len; - debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "input (query): %.*s\n", (int)len, query); - debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast(output)); + debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, + "input (query): %.*s\n", (int)len, query); + + debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, + "output: %f\n", static_cast(output)); - if ((double)output > config.http_param_threshold) + if ((double)output > conf.http_param_threshold) { snort_ml_stats.uri_alerts++; debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "\n"); @@ -163,25 +157,36 @@ void HttpUriHandler::handle(DataEvent& de, Flow*) void SnortML::show(const SnortConfig*) const { - ConfigLogger::log_limit("uri_depth", config.uri_depth, -1); - ConfigLogger::log_limit("client_body_depth", config.client_body_depth, -1); - ConfigLogger::log_value("http_param_threshold", config.http_param_threshold); + ConfigLogger::log_limit("uri_depth", conf.uri_depth, -1); + ConfigLogger::log_limit("client_body_depth", conf.client_body_depth, -1); + ConfigLogger::log_value("http_param_threshold", conf.http_param_threshold); } bool SnortML::configure(SnortConfig* sc) { - if (config.uri_depth != 0) - DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_HEADER, new HttpUriHandler(*this)); + auto engine = reinterpret_cast( + InspectorManager::get_inspector(SNORT_ML_ENGINE_NAME, true, sc)); - if (config.client_body_depth != 0) - DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_BODY, new HttpBodyHandler(*this)); - - if(!InspectorManager::get_inspector(SNORT_ML_ENGINE_NAME, true, sc)) + if (!engine) { - ParseError("snort_ml requires %s to be configured in the global policy.", SNORT_ML_ENGINE_NAME); + ParseError("snort_ml requires %s to be configured in the global policy.", + SNORT_ML_ENGINE_NAME); + return false; } + if (conf.uri_depth != 0) + { + DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_HEADER, + new HttpUriHandler(*engine, *this)); + } + + if (conf.client_body_depth != 0) + { + DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_BODY, + new HttpBodyHandler(*engine, *this)); + } + return true; } @@ -197,8 +202,8 @@ static void mod_dtor(Module* m) static Inspector* snort_ml_ctor(Module* m) { - SnortMLModule* km = (SnortMLModule*)m; - return new SnortML(km->get_conf()); + const SnortMLModule* mod = reinterpret_cast(m); + return new SnortML(mod->get_config()); } static void snort_ml_dtor(Inspector* p) diff --git a/src/network_inspectors/snort_ml/snort_ml_inspector.h b/src/network_inspectors/snort_ml/snort_ml_inspector.h index f7f97fca9..72e28325f 100644 --- a/src/network_inspectors/snort_ml/snort_ml_inspector.h +++ b/src/network_inspectors/snort_ml/snort_ml_inspector.h @@ -30,16 +30,17 @@ class SnortML : public snort::Inspector { public: - SnortML(const SnortMLConfig& c) : config(c) { } + SnortML(const SnortMLConfig& c) : conf(c) {} void show(const snort::SnortConfig*) const override; void eval(snort::Packet*) override {} bool configure(snort::SnortConfig*) override; - const SnortMLConfig& get_config() - { return config; } + const SnortMLConfig& get_config() const + { return conf; } + private: - SnortMLConfig config; + SnortMLConfig conf; }; #endif diff --git a/src/network_inspectors/snort_ml/snort_ml_module.cc b/src/network_inspectors/snort_ml/snort_ml_module.cc index d0c4cd291..7280bf697 100644 --- a/src/network_inspectors/snort_ml/snort_ml_module.cc +++ b/src/network_inspectors/snort_ml/snort_ml_module.cc @@ -56,7 +56,6 @@ static const PegInfo peg_names[] = { CountType::SUM, "client_body_alerts", "total number of alerts triggered on HTTP client body" }, { CountType::SUM, "uri_bytes", "total number of HTTP URI bytes processed" }, { CountType::SUM, "client_body_bytes", "total number of HTTP client body bytes processed" }, - { CountType::SUM, "libml_calls", "total libml calls" }, { CountType::END, nullptr, nullptr } }; @@ -72,13 +71,17 @@ static const TraceOption snort_ml_trace_options[] = // module //-------------------------------------------------------------------------- -SnortMLModule::SnortMLModule() : Module(SNORT_ML_NAME, SNORT_ML_HELP, snort_ml_params) {} +SnortMLModule::SnortMLModule() + : Module(SNORT_ML_NAME, SNORT_ML_HELP, snort_ml_params) {} bool SnortMLModule::set(const char*, Value& v, SnortConfig*) { - static_assert(std::is_same::value, + static_assert(std::is_same::value, "Field::length maximum value should not exceed uri_depth type range"); - static_assert(std::is_same::value, + + static_assert(std::is_same::value, "Field::length maximum value should not exceed client_body_depth type range"); if (v.is("uri_depth")) @@ -107,7 +110,7 @@ const PegInfo* SnortMLModule::get_pegs() const { return peg_names; } PegCount* SnortMLModule::get_counts() const -{ return (PegCount*)&snort_ml_stats; } +{ return reinterpret_cast(&snort_ml_stats); } ProfileStats* SnortMLModule::get_profile() const { return &snort_ml_prof; } @@ -115,11 +118,7 @@ ProfileStats* SnortMLModule::get_profile() const void SnortMLModule::set_trace(const Trace* trace) const { snort_ml_trace = trace; } +#ifdef DEBUG_MSGS const TraceOption* SnortMLModule::get_trace_options() const -{ -#ifndef DEBUG_MSGS - return nullptr; -#else - return snort_ml_trace_options; +{ return snort_ml_trace_options; } #endif -} diff --git a/src/network_inspectors/snort_ml/snort_ml_module.h b/src/network_inspectors/snort_ml/snort_ml_module.h index b842a06ec..8d680ca9a 100644 --- a/src/network_inspectors/snort_ml/snort_ml_module.h +++ b/src/network_inspectors/snort_ml/snort_ml_module.h @@ -39,7 +39,6 @@ struct SnortMLStats PegCount client_body_alerts; PegCount uri_bytes; PegCount client_body_bytes; - PegCount libml_calls; }; extern THREAD_LOCAL SnortMLStats snort_ml_stats; @@ -48,7 +47,6 @@ extern THREAD_LOCAL const snort::Trace* snort_ml_trace; struct SnortMLConfig { - std::string http_param_model_path; double http_param_threshold; int32_t uri_depth; int32_t client_body_depth; @@ -62,7 +60,7 @@ public: bool set(const char*, snort::Value&, snort::SnortConfig*) override; bool end(const char*, int, snort::SnortConfig*) override; - const SnortMLConfig& get_conf() const + const SnortMLConfig& get_config() const { return conf; } unsigned get_gid() const override @@ -79,7 +77,10 @@ public: snort::ProfileStats* get_profile() const override; void set_trace(const snort::Trace*) const override; + +#ifdef DEBUG_MSGS const snort::TraceOption* get_trace_options() const override; +#endif private: SnortMLConfig conf = {};