find_path(ML_INCLUDE_DIRS
libml.h
- HINTS ${ML_INCLUDE_DIR_HINT} ${PC_ML_INCLUDEDIR}
+ PATHS ${ML_INCLUDE_DIR_HINT} ${PC_ML_INCLUDEDIR}
+ NO_DEFAULT_PATH
)
find_library(ML_LIBRARIES
NAMES ml_static
- HINTS ${ML_LIBRARIES_DIR_HINT} ${PC_ML_LIBDIR}
+ PATHS ${ML_LIBRARIES_DIR_HINT} ${PC_ML_LIBDIR}
+ NO_DEFAULT_PATH
)
include(FindPackageHandleStandardArgs)
--- /dev/null
+//--------------------------------------------------------------------------
+// Copyright (C) 2025-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation. You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+//--------------------------------------------------------------------------
+// fnv.h author Brandon Stultz <brastult@cisco.com>
+// based on https://datatracker.ietf.org/doc/html/draft-eastlake-fnv
+
+#ifndef FNV_H
+#define FNV_H
+
+#define FNV_PRIME 0x00000100000001B3
+#define FNV_BASIS 0xCBF29CE484222325
+
+//--------------------------------------------------------------------------
+// FNV-1a Hash
+//--------------------------------------------------------------------------
+
+inline uint64_t fnv1a(const char* buf, const size_t len)
+{
+ uint64_t result = FNV_BASIS;
+
+ for (size_t i = 0; i < len; i++)
+ result = (result ^ static_cast<uint8_t>(buf[i])) * FNV_PRIME;
+
+ return result;
+}
+
+#endif
#include "snort_ml_engine.h"
#include <cassert>
+#include <cstring>
#include <fstream>
-#ifdef HAVE_LIBML
-#include <libml.h>
-#endif
-
#include "framework/decode_data.h"
+#include "hash/fnv.h"
#include "helpers/directory.h"
#include "log/messages.h"
#include "main/reload_tuner.h"
using namespace snort;
using namespace std;
-static THREAD_LOCAL libml::BinaryClassifierSet* classifiers = nullptr;
+static THREAD_LOCAL SnortMLEngineStats snort_ml_engine_stats;
+static THREAD_LOCAL SnortMLContext* snort_ml_ctx = nullptr;
-static bool build_classifiers(const vector<string>& models,
- libml::BinaryClassifierSet*& set)
+static SnortMLContext* create_context(const SnortMLEngineConfig& conf)
{
- set = new libml::BinaryClassifierSet();
+ SnortMLContext* ctx = new SnortMLContext();
+
+ if (!ctx->classifiers.build(conf.http_param_models))
+ {
+ ErrorMessage("Could not build classifiers.\n");
+ return ctx;
+ }
+
+ if (conf.cache_memcap > 0)
+ {
+ ctx->cache = make_unique<SnortMLCache>(conf.cache_memcap,
+ snort_ml_engine_stats);
+ }
- return set->build(models);
+ return ctx;
}
//--------------------------------------------------------------------------
// module
//--------------------------------------------------------------------------
+static const Parameter filter_param[] =
+{
+ { "filter_pattern", Parameter::PT_STRING, nullptr, nullptr,
+ "pattern that triggers ML classification" },
+ { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
+};
+
+static const Parameter ignore_param[] =
+{
+ { "ignore_pattern", Parameter::PT_STRING, nullptr, nullptr,
+ "pattern that skips ML classification" },
+ { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
+};
+
static const Parameter snort_ml_engine_params[] =
{
- { "http_param_model", Parameter::PT_STRING, nullptr, nullptr, "path to model file(s)" },
+ { "http_param_model", Parameter::PT_STRING, nullptr, nullptr,
+ "path to model file(s)" },
+
+ { "http_param_filter", Parameter::PT_LIST, filter_param, nullptr,
+ "list of patterns that trigger ML classification" },
+
+ { "http_param_ignore", Parameter::PT_LIST, ignore_param, nullptr,
+ "list of patterns that skip ML classification" },
+
+ { "cache_memcap", Parameter::PT_INT, "0:maxSZ", "0",
+ "maximum memory for verdict cache in bytes, 0 = disabled" },
+
{ nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
};
-SnortMLEngineModule::SnortMLEngineModule() : Module(SNORT_ML_ENGINE_NAME, SNORT_ML_ENGINE_HELP, snort_ml_engine_params) {}
+static const PegInfo peg_names[] =
+{
+ LRU_CACHE_LOCAL_PEGS("snort_ml_engine"),
+ { CountType::SUM, "filter_searches", "total filter searches" },
+ { CountType::SUM, "filter_matches", "total filter matches" },
+ { CountType::SUM, "filter_allows", "total filter allows" },
+ { CountType::SUM, "libml_calls", "total libml calls" },
+ { CountType::END, nullptr, nullptr }
+};
+
+SnortMLEngineModule::SnortMLEngineModule()
+ : Module(SNORT_ML_ENGINE_NAME, SNORT_ML_ENGINE_HELP, snort_ml_engine_params) {}
+
+bool SnortMLEngineModule::begin(const char* fqn, int, SnortConfig*)
+{
+ if (!strcmp(SNORT_ML_ENGINE_NAME, fqn))
+ conf = {};
+
+ return true;
+}
bool SnortMLEngineModule::set(const char*, Value& v, SnortConfig*)
{
if (v.is("http_param_model"))
conf.http_param_model_path = v.get_string();
+ else if (v.is("filter_pattern"))
+ conf.http_param_filters[v.get_string()] = true;
+
+ else if (v.is("ignore_pattern"))
+ {
+ conf.http_param_filters[v.get_string()] = false;
+ conf.has_allow = true;
+ }
+
+ else if (v.is("cache_memcap"))
+ conf.cache_memcap = v.get_size();
+
return true;
}
+const PegInfo* SnortMLEngineModule::get_pegs() const
+{ return peg_names; }
+
+PegCount* SnortMLEngineModule::get_counts() const
+{ return reinterpret_cast<PegCount*>(&snort_ml_engine_stats); }
+
//--------------------------------------------------------------------------
// reload tuner
//--------------------------------------------------------------------------
class SnortMLReloadTuner : public snort::ReloadResourceTuner
{
public:
- explicit SnortMLReloadTuner(const vector<string>& models)
- : http_param_models(models) {}
-
+ explicit SnortMLReloadTuner(const SnortMLEngineConfig& c) : conf(c) {}
~SnortMLReloadTuner() override = default;
const char* name() const override
bool tinit() override
{
- delete classifiers;
-
- if (!build_classifiers(http_param_models, classifiers))
- ErrorMessage("Could not build classifiers.\n");
-
+ delete snort_ml_ctx;
+ snort_ml_ctx = create_context(conf);
return false;
}
{ return true; }
private:
- const vector<string>& http_param_models;
+ const SnortMLEngineConfig& conf;
};
//--------------------------------------------------------------------------
// inspector
//--------------------------------------------------------------------------
-SnortMLEngine::SnortMLEngine(const SnortMLEngineConfig& c) : config(c)
+bool SnortMLEngine::configure(SnortConfig*)
{
- if (!read_models() || !validate_models())
+ if (!read_models())
+ return false;
+
+ libml::BinaryClassifierSet classifiers;
+
+ if (!classifiers.build(conf.http_param_models))
+ {
ParseError("Could not build classifiers.");
+ return false;
+ }
+
+ if (!conf.http_param_filters.empty())
+ {
+ mpse = new SearchTool;
+
+ for (auto& f : conf.http_param_filters)
+ mpse->add(f.first.c_str(), f.first.size(), (void*)&f);
+
+ mpse->prep();
+ }
+
+ return true;
}
void SnortMLEngine::show(const SnortConfig*) const
-{ ConfigLogger::log_value("http_param_model", config.http_param_model_path.c_str()); }
+{
+ ConfigLogger::log_value("http_param_model", conf.http_param_model_path.c_str());
+ ConfigLogger::log_value("cache_memcap", conf.cache_memcap);
+}
bool SnortMLEngine::read_models()
{
- const char* hint = config.http_param_model_path.c_str();
+ const char* hint = conf.http_param_model_path.c_str();
string path;
if (!get_config_file(hint, path))
}
}
- return !http_param_models.empty();
+ if (conf.http_param_models.empty())
+ {
+ ParseError("snort_ml_engine: no models found");
+ return false;
+ }
+
+ return true;
}
bool SnortMLEngine::read_model(const string& path)
string buffer(size, '\0');
file.read(&buffer[0], streamsize(size));
- http_param_models.push_back(move(buffer));
+ conf.http_param_models.push_back(std::move(buffer));
return true;
}
-bool SnortMLEngine::validate_models()
-{
- libml::BinaryClassifierSet* set = nullptr;
- bool res = build_classifiers(http_param_models, set);
- delete set;
-
- return res;
-}
-
void SnortMLEngine::tinit()
-{ build_classifiers(http_param_models, classifiers); }
+{ snort_ml_ctx = create_context(conf); }
void SnortMLEngine::tterm()
{
- delete classifiers;
- classifiers = nullptr;
+ delete snort_ml_ctx;
+ snort_ml_ctx = nullptr;
}
void SnortMLEngine::install_reload_handler(SnortConfig* sc)
-{ sc->register_reload_handler(new SnortMLReloadTuner(http_param_models)); }
+{ sc->register_reload_handler(new SnortMLReloadTuner(conf)); }
+
+static int filter_match_callback(void* f, void*, int, void* s, void*)
+{
+ auto filter = reinterpret_cast<const pair<string, bool>*>(f);
+ auto search = reinterpret_cast<SnortMLSearch*>(s);
+
+ search->match = true;
+ search->allow |= !filter->second;
+
+ if (search->has_allow && !search->allow)
+ return 0;
+
+ return 1;
+}
+
+bool SnortMLEngine::scan(const char* buf, const size_t len, float& out) const
+{
+ if (!snort_ml_ctx)
+ return false;
+
+ if (mpse)
+ {
+ snort_ml_engine_stats.filter_searches++;
+
+ SnortMLSearch search;
+ search.has_allow = conf.has_allow;
+
+ mpse->find_all(buf, len, filter_match_callback,
+ false, (void*)&search);
-libml::BinaryClassifierSet* SnortMLEngine::get_classifiers()
-{ return classifiers; }
+ if (!search.match)
+ return false;
+
+ snort_ml_engine_stats.filter_matches++;
+
+ if (search.allow)
+ {
+ snort_ml_engine_stats.filter_allows++;
+ return false;
+ }
+ }
+
+ float res = 0;
+ bool is_new = true;
+
+ float& result = (snort_ml_ctx->cache) ?
+ snort_ml_ctx->cache->find_else_create(fnv1a(buf, len), &is_new) : res;
+
+ if (is_new)
+ {
+ snort_ml_engine_stats.libml_calls++;
+
+ if (!snort_ml_ctx->classifiers.run(buf, len, result))
+ return false;
+ }
+
+ out = result;
+ return true;
+}
//--------------------------------------------------------------------------
// api stuff
static Inspector* snort_ml_engine_ctor(Module* m)
{
- SnortMLEngineModule* mod = (SnortMLEngineModule*)m;
+ SnortMLEngineModule* mod = reinterpret_cast<SnortMLEngineModule*>(m);
return new SnortMLEngine(mod->get_config());
}
TEST_CASE("SnortML tuner name", "[snort_ml_module]")
{
- const vector<string> models = { "model" };
- SnortMLReloadTuner tuner(models);
+ SnortMLEngineConfig conf;
+ conf.http_param_models = { "model" };
+ SnortMLReloadTuner tuner(conf);
REQUIRE(strcmp(tuner.name(), "SnortMLReloadTuner") == 0);
}
#ifndef SNORT_ML_ENGINE_H
#define SNORT_ML_ENGINE_H
-#include "framework/module.h"
+#ifdef HAVE_LIBML
+#include <libml.h>
+#endif
+
+#include <memory>
+#include <unordered_map>
+#include <utility>
+
#include "framework/inspector.h"
+#include "framework/module.h"
+#include "hash/lru_cache_local.h"
+#include "search_engines/search_tool.h"
#define SNORT_ML_ENGINE_NAME "snort_ml_engine"
#define SNORT_ML_ENGINE_HELP "configure machine learning engine settings"
+// Mock BinaryClassifierSet for tests if LibML is absent
+#ifndef HAVE_LIBML
namespace libml
{
- class BinaryClassifierSet;
+
+class BinaryClassifierSet
+{
+public:
+ bool build(const std::vector<std::string>& models)
+ {
+ if (!models.empty())
+ pattern = models[0];
+
+ return pattern != "error";
+ }
+
+ bool run(const char* ptr, size_t len, float& out)
+ {
+ std::string data(ptr, len);
+ out = data.find(pattern) == std::string::npos ? 0.0f : 1.0f;
+ return pattern != "fail";
+ }
+
+private:
+ std::string pattern;
+};
+
}
+#endif
+
+struct SnortMLEngineStats : public LruCacheLocalStats
+{
+ PegCount filter_searches;
+ PegCount filter_matches;
+ PegCount filter_allows;
+ PegCount libml_calls;
+};
+
+typedef LruCacheLocal<uint64_t, float, std::hash<uint64_t>> SnortMLCache;
+typedef std::unordered_map<std::string, bool> SnortMLFilterMap;
+
+struct SnortMLContext
+{
+ libml::BinaryClassifierSet classifiers;
+ std::unique_ptr<SnortMLCache> cache;
+};
struct SnortMLEngineConfig
{
std::string http_param_model_path;
+ std::vector<std::string> http_param_models;
+ SnortMLFilterMap http_param_filters;
+ bool has_allow = false;
+ size_t cache_memcap = 0;
+};
+
+struct SnortMLSearch
+{
+ bool match = false;
+ bool allow = false;
+ bool has_allow = false;
};
class SnortMLEngineModule : public snort::Module
public:
SnortMLEngineModule();
+ bool begin(const char*, int, snort::SnortConfig*) override;
bool set(const char*, snort::Value&, snort::SnortConfig*) override;
+ const PegInfo* get_pegs() const override;
+ PegCount* get_counts() const override;
+
Usage get_usage() const override
{ return GLOBAL; }
- const SnortMLEngineConfig& get_config()
- { return conf; }
+ SnortMLEngineConfig get_config()
+ {
+ SnortMLEngineConfig out;
+ std::swap(conf, out);
+ return out;
+ }
private:
SnortMLEngineConfig conf;
class SnortMLEngine : public snort::Inspector
{
public:
- SnortMLEngine(const SnortMLEngineConfig&);
+ SnortMLEngine(SnortMLEngineConfig c) : conf(std::move(c)) {}
+ ~SnortMLEngine() override
+ { delete mpse; }
+ bool configure(snort::SnortConfig*) override;
void show(const snort::SnortConfig*) const override;
- void eval(snort::Packet*) override {}
void tinit() override;
void tterm() override;
void install_reload_handler(snort::SnortConfig*) override;
- static libml::BinaryClassifierSet* get_classifiers();
+ bool scan(const char*, const size_t, float&) const;
private:
bool read_models();
bool read_model(const std::string&);
- bool validate_models();
-
- SnortMLEngineConfig config;
- std::vector<std::string> http_param_models;
-};
-
-// Mock BinaryClassifierSet for tests if LibML is absent.
-// The code below won't be executed if REG_TEST is undefined.
-// Check the plugin type provided in the snort_ml_engine.cc file.
-#ifndef HAVE_LIBML
-namespace libml
-{
-
-class BinaryClassifierSet
-{
-public:
- bool build(const std::vector<std::string>& models)
- {
- if (!models.empty())
- pattern = models[0];
-
- return pattern != "error";
- }
-
- bool run(const char* ptr, size_t len, float& out)
- {
- std::string data(ptr, len);
- out = data.find(pattern) == std::string::npos ? 0.0f : 1.0f;
- return pattern != "fail";
- }
-
-private:
- std::string pattern;
+ SnortMLEngineConfig conf;
+ snort::SearchTool* mpse = nullptr;
};
-}
-#endif
-
#endif
#include <cassert>
-#ifdef HAVE_LIBML
-#include <libml.h>
-#endif
-
#include "detection/detection_engine.h"
#include "log/messages.h"
#include "managers/inspector_manager.h"
class HttpBodyHandler : public DataHandler
{
public:
- HttpBodyHandler(SnortML& ml)
- : DataHandler(SNORT_ML_NAME), inspector(ml) {}
+ HttpBodyHandler(const SnortMLEngine& eng, const SnortML& ins)
+ : DataHandler(SNORT_ML_NAME), engine(eng), inspector(ins) {}
void handle(DataEvent& de, Flow*) override;
private:
- SnortML& inspector;
+ const SnortMLEngine& engine;
+ const SnortML& inspector;
};
void HttpBodyHandler::handle(DataEvent& de, Flow*)
// cppcheck-suppress unreadVariable
Profile profile(snort_ml_prof);
- libml::BinaryClassifierSet* classifiers = SnortMLEngine::get_classifiers();
- SnortMLConfig config = inspector.get_config();
- HttpRequestBodyEvent* he = (HttpRequestBodyEvent*)&de;
+ HttpRequestBodyEvent* he = reinterpret_cast<HttpRequestBodyEvent*>(&de);
if (he->is_mime())
return;
if (!body || body_len <= 0)
return;
- const size_t len = std::min((size_t)config.client_body_depth, (size_t)body_len);
-
- assert(classifiers);
-
- float output = 0.0;
+ const SnortMLConfig& conf = inspector.get_config();
- snort_ml_stats.libml_calls++;
+ const size_t len = std::min((size_t)conf.client_body_depth, (size_t)body_len);
- if (!classifiers->run(body, len, output))
+ float output = 0;
+ if (!engine.scan(body, len, output))
return;
snort_ml_stats.client_body_bytes += len;
- debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "input (body): %.*s\n", (int)len, body);
- debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast<double>(output));
+ debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr,
+ "input (body): %.*s\n", (int)len, body);
- if ((double)output > config.http_param_threshold)
+ debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr,
+ "output: %f\n", static_cast<double>(output));
+
+ if ((double)output > conf.http_param_threshold)
{
snort_ml_stats.client_body_alerts++;
debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "<ALERT>\n");
class HttpUriHandler : public DataHandler
{
public:
- HttpUriHandler(SnortML& ml)
- : DataHandler(SNORT_ML_NAME), inspector(ml) {}
+ HttpUriHandler(const SnortMLEngine& eng, const SnortML& ins)
+ : DataHandler(SNORT_ML_NAME), engine(eng), inspector(ins) {}
void handle(DataEvent&, Flow*) override;
private:
- SnortML& inspector;
+ const SnortMLEngine& engine;
+ const SnortML& inspector;
};
void HttpUriHandler::handle(DataEvent& de, Flow*)
// cppcheck-suppress unreadVariable
Profile profile(snort_ml_prof);
- libml::BinaryClassifierSet* classifiers = SnortMLEngine::get_classifiers();
- SnortMLConfig config = inspector.get_config();
- HttpEvent* he = (HttpEvent*)&de;
+ HttpEvent* he = reinterpret_cast<HttpEvent*>(&de);
int32_t query_len = 0;
const char* query = (const char*)he->get_uri_query(query_len);
if (!query || query_len <= 0)
return;
- const size_t len = std::min((size_t)config.uri_depth, (size_t)query_len);
-
- assert(classifiers);
+ const SnortMLConfig& conf = inspector.get_config();
- float output = 0.0;
+ const size_t len = std::min((size_t)conf.uri_depth, (size_t)query_len);
- snort_ml_stats.libml_calls++;
-
- if (!classifiers->run(query, len, output))
+ float output = 0;
+ if (!engine.scan(query, len, output))
return;
snort_ml_stats.uri_bytes += len;
- debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "input (query): %.*s\n", (int)len, query);
- debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast<double>(output));
+ debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr,
+ "input (query): %.*s\n", (int)len, query);
+
+ debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr,
+ "output: %f\n", static_cast<double>(output));
- if ((double)output > config.http_param_threshold)
+ if ((double)output > conf.http_param_threshold)
{
snort_ml_stats.uri_alerts++;
debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "<ALERT>\n");
void SnortML::show(const SnortConfig*) const
{
- ConfigLogger::log_limit("uri_depth", config.uri_depth, -1);
- ConfigLogger::log_limit("client_body_depth", config.client_body_depth, -1);
- ConfigLogger::log_value("http_param_threshold", config.http_param_threshold);
+ ConfigLogger::log_limit("uri_depth", conf.uri_depth, -1);
+ ConfigLogger::log_limit("client_body_depth", conf.client_body_depth, -1);
+ ConfigLogger::log_value("http_param_threshold", conf.http_param_threshold);
}
bool SnortML::configure(SnortConfig* sc)
{
- if (config.uri_depth != 0)
- DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_HEADER, new HttpUriHandler(*this));
+ auto engine = reinterpret_cast<const SnortMLEngine*>(
+ InspectorManager::get_inspector(SNORT_ML_ENGINE_NAME, true, sc));
- if (config.client_body_depth != 0)
- DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_BODY, new HttpBodyHandler(*this));
-
- if(!InspectorManager::get_inspector(SNORT_ML_ENGINE_NAME, true, sc))
+ if (!engine)
{
- ParseError("snort_ml requires %s to be configured in the global policy.", SNORT_ML_ENGINE_NAME);
+ ParseError("snort_ml requires %s to be configured in the global policy.",
+ SNORT_ML_ENGINE_NAME);
+
return false;
}
+ if (conf.uri_depth != 0)
+ {
+ DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_HEADER,
+ new HttpUriHandler(*engine, *this));
+ }
+
+ if (conf.client_body_depth != 0)
+ {
+ DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_BODY,
+ new HttpBodyHandler(*engine, *this));
+ }
+
return true;
}
static Inspector* snort_ml_ctor(Module* m)
{
- SnortMLModule* km = (SnortMLModule*)m;
- return new SnortML(km->get_conf());
+ const SnortMLModule* mod = reinterpret_cast<const SnortMLModule*>(m);
+ return new SnortML(mod->get_config());
}
static void snort_ml_dtor(Inspector* p)
class SnortML : public snort::Inspector
{
public:
- SnortML(const SnortMLConfig& c) : config(c) { }
+ SnortML(const SnortMLConfig& c) : conf(c) {}
void show(const snort::SnortConfig*) const override;
void eval(snort::Packet*) override {}
bool configure(snort::SnortConfig*) override;
- const SnortMLConfig& get_config()
- { return config; }
+ const SnortMLConfig& get_config() const
+ { return conf; }
+
private:
- SnortMLConfig config;
+ SnortMLConfig conf;
};
#endif
{ CountType::SUM, "client_body_alerts", "total number of alerts triggered on HTTP client body" },
{ CountType::SUM, "uri_bytes", "total number of HTTP URI bytes processed" },
{ CountType::SUM, "client_body_bytes", "total number of HTTP client body bytes processed" },
- { CountType::SUM, "libml_calls", "total libml calls" },
{ CountType::END, nullptr, nullptr }
};
// module
//--------------------------------------------------------------------------
-SnortMLModule::SnortMLModule() : Module(SNORT_ML_NAME, SNORT_ML_HELP, snort_ml_params) {}
+SnortMLModule::SnortMLModule()
+ : Module(SNORT_ML_NAME, SNORT_ML_HELP, snort_ml_params) {}
bool SnortMLModule::set(const char*, Value& v, SnortConfig*)
{
- static_assert(std::is_same<decltype((Field().length())), decltype(conf.uri_depth)>::value,
+ static_assert(std::is_same<decltype((Field().length())),
+ decltype(conf.uri_depth)>::value,
"Field::length maximum value should not exceed uri_depth type range");
- static_assert(std::is_same<decltype((Field().length())), decltype(conf.client_body_depth)>::value,
+
+ static_assert(std::is_same<decltype((Field().length())),
+ decltype(conf.client_body_depth)>::value,
"Field::length maximum value should not exceed client_body_depth type range");
if (v.is("uri_depth"))
{ return peg_names; }
PegCount* SnortMLModule::get_counts() const
-{ return (PegCount*)&snort_ml_stats; }
+{ return reinterpret_cast<PegCount*>(&snort_ml_stats); }
ProfileStats* SnortMLModule::get_profile() const
{ return &snort_ml_prof; }
void SnortMLModule::set_trace(const Trace* trace) const
{ snort_ml_trace = trace; }
+#ifdef DEBUG_MSGS
const TraceOption* SnortMLModule::get_trace_options() const
-{
-#ifndef DEBUG_MSGS
- return nullptr;
-#else
- return snort_ml_trace_options;
+{ return snort_ml_trace_options; }
#endif
-}
PegCount client_body_alerts;
PegCount uri_bytes;
PegCount client_body_bytes;
- PegCount libml_calls;
};
extern THREAD_LOCAL SnortMLStats snort_ml_stats;
struct SnortMLConfig
{
- std::string http_param_model_path;
double http_param_threshold;
int32_t uri_depth;
int32_t client_body_depth;
bool set(const char*, snort::Value&, snort::SnortConfig*) override;
bool end(const char*, int, snort::SnortConfig*) override;
- const SnortMLConfig& get_conf() const
+ const SnortMLConfig& get_config() const
{ return conf; }
unsigned get_gid() const override
snort::ProfileStats* get_profile() const override;
void set_trace(const snort::Trace*) const override;
+
+#ifdef DEBUG_MSGS
const snort::TraceOption* get_trace_options() const override;
+#endif
private:
SnortMLConfig conf = {};