git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4833: snort_ml: add mpse and lru cache
author Brandon Stultz (brastult) <brastult@cisco.com>
Fri, 24 Oct 2025 09:00:55 +0000 (09:00 +0000)
committer Oleksii Shumeiko -X (oshumeik - SOFTSERVE INC at Cisco) <oshumeik@cisco.com>
Fri, 24 Oct 2025 09:00:55 +0000 (09:00 +0000)
Merge in SNORT/snort3 from ~BRASTULT/snort3:snort_ml_pipeline to master

Squashed commit of the following:

commit 1f51dd1bee92a4995d960561b59a72e1a8903b53
Author: Brandon Stultz <brastult@cisco.com>
Date:   Fri Jul 25 13:46:00 2025 -0400

    build: only enable libml for supported versions

commit 47a789fc3b637f95b11ba0b154af53440ed5b2f2
Author: Brandon Stultz <brastult@cisco.com>
Date:   Fri Jul 25 13:32:01 2025 -0400

    snort_ml: add mpse and lru cache

commit 7c74729080cc2f1095dbbeee8e98bbbda00accf9
Author: Brandon Stultz <brastult@cisco.com>
Date:   Fri Sep 5 17:00:03 2025 -0400

    hash: add FNV-1a hash

cmake/FindML.cmake
src/hash/fnv.h [new file with mode: 0644]
src/network_inspectors/snort_ml/snort_ml_engine.cc
src/network_inspectors/snort_ml/snort_ml_engine.h
src/network_inspectors/snort_ml/snort_ml_inspector.cc
src/network_inspectors/snort_ml/snort_ml_inspector.h
src/network_inspectors/snort_ml/snort_ml_module.cc
src/network_inspectors/snort_ml/snort_ml_module.h

index c30553c959448a458b43f1cc87734cfff29ca5fb..15943ffadfe97dc290d842bfc48c9b0a0fb06c7c 100644 (file)
@@ -3,12 +3,14 @@ pkg_check_modules(PC_ML libml_static>=2.0.0)
 
 find_path(ML_INCLUDE_DIRS
     libml.h
-    HINTS ${ML_INCLUDE_DIR_HINT} ${PC_ML_INCLUDEDIR}
+    PATHS ${ML_INCLUDE_DIR_HINT} ${PC_ML_INCLUDEDIR}  # candidate dirs: ML_INCLUDE_DIR_HINT and the include dir reported by pkg-config (PC_ML)
+    NO_DEFAULT_PATH  # search only the dirs named in this call, not CMake's default locations
 )
 
 find_library(ML_LIBRARIES
     NAMES ml_static
-    HINTS ${ML_LIBRARIES_DIR_HINT} ${PC_ML_LIBDIR}
+    PATHS ${ML_LIBRARIES_DIR_HINT} ${PC_ML_LIBDIR}  # candidate dirs: ML_LIBRARIES_DIR_HINT and the lib dir reported by pkg-config (PC_ML)
+    NO_DEFAULT_PATH  # search only the dirs named in this call, not CMake's default locations
 )
 
 include(FindPackageHandleStandardArgs)
diff --git a/src/hash/fnv.h b/src/hash/fnv.h
new file mode 100644 (file)
index 0000000..48c8591
--- /dev/null
@@ -0,0 +1,41 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2025-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// fnv.h author Brandon Stultz <brastult@cisco.com>
+// based on https://datatracker.ietf.org/doc/html/draft-eastlake-fnv
+
+#ifndef FNV_H
+#define FNV_H
+
+#define FNV_PRIME 0x00000100000001B3  // multiplier applied after each byte in fnv1a()
+#define FNV_BASIS 0xCBF29CE484222325  // initial hash value in fnv1a(); also its result for len == 0
+
+//--------------------------------------------------------------------------
+// FNV-1a Hash
+//--------------------------------------------------------------------------
+
+inline uint64_t fnv1a(const char* buf, const size_t len)  // hashes the first len bytes of buf to a 64-bit value
+{
+    uint64_t result = FNV_BASIS;  // buf is never dereferenced when len is 0
+    // NOTE(review): this header uses uint64_t/uint8_t/size_t but includes nothing itself - confirm every includer provides <cstdint>/<cstddef>
+    for (size_t i = 0; i < len; i++)
+        result = (result ^ static_cast<uint8_t>(buf[i])) * FNV_PRIME;  // xor the byte in first, then multiply; the uint8_t cast keeps a negative char from sign-extending
+
+    return result;
+}
+
+#endif
index 6b3c63c96e5261093a7b4647f0a0e4e3a2638fc7..8d69174a65ca3ceb1924db2cf0b23be68606ff54 100644 (file)
 #include "snort_ml_engine.h"
 
 #include <cassert>
+#include <cstring>
 #include <fstream>
 
-#ifdef HAVE_LIBML
-#include <libml.h>
-#endif
-
 #include "framework/decode_data.h"
+#include "hash/fnv.h"
 #include "helpers/directory.h"
 #include "log/messages.h"
 #include "main/reload_tuner.h"
 using namespace snort;
 using namespace std;
 
-static THREAD_LOCAL libml::BinaryClassifierSet* classifiers = nullptr;
+static THREAD_LOCAL SnortMLEngineStats snort_ml_engine_stats;
+static THREAD_LOCAL SnortMLContext* snort_ml_ctx = nullptr;
 
-static bool build_classifiers(const vector<string>& models,
-    libml::BinaryClassifierSet*& set)
+static SnortMLContext* create_context(const SnortMLEngineConfig& conf)  // allocates a context (classifiers plus optional cache) from conf; callers in this file delete it
 {
-    set = new libml::BinaryClassifierSet();
+    SnortMLContext* ctx = new SnortMLContext();
+
+    if (!ctx->classifiers.build(conf.http_param_models))
+    {
+        ErrorMessage("Could not build classifiers.\n");
+        return ctx;  // the context is still returned on build failure, just without a cache
+    }
+
+    if (conf.cache_memcap > 0)  // a memcap of 0 leaves ctx->cache empty, which scan() treats as "no cache"
+    {
+        ctx->cache = make_unique<SnortMLCache>(conf.cache_memcap,
+            snort_ml_engine_stats);
+    }
 
-    return set->build(models);
+    return ctx;
 }
 
 //--------------------------------------------------------------------------
 // module
 //--------------------------------------------------------------------------
 
+static const Parameter filter_param[] =  // element layout of the "http_param_filter" list in snort_ml_engine_params
+{
+    { "filter_pattern", Parameter::PT_STRING, nullptr, nullptr,
+      "pattern that triggers ML classification" },  // stored by set() as a map entry with value true
+    { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
+};
+
+static const Parameter ignore_param[] =  // element layout of the "http_param_ignore" list in snort_ml_engine_params
+{
+    { "ignore_pattern", Parameter::PT_STRING, nullptr, nullptr,
+      "pattern that skips ML classification" },  // stored by set() as a map entry with value false
+    { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
+};
+
 static const Parameter snort_ml_engine_params[] =
 {
-    { "http_param_model", Parameter::PT_STRING, nullptr, nullptr, "path to model file(s)" },
+    { "http_param_model", Parameter::PT_STRING, nullptr, nullptr,
+      "path to model file(s)" },
+
+    { "http_param_filter", Parameter::PT_LIST, filter_param, nullptr,
+      "list of patterns that trigger ML classification" },  // list elements are described by filter_param
+
+    { "http_param_ignore", Parameter::PT_LIST, ignore_param, nullptr,
+      "list of patterns that skip ML classification" },  // list elements are described by ignore_param
+
+    { "cache_memcap", Parameter::PT_INT, "0:maxSZ", "0",
+      "maximum memory for verdict cache in bytes, 0 = disabled" },  // default "0": no cache is created unless this is set
+
     { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
 };
 
-SnortMLEngineModule::SnortMLEngineModule() : Module(SNORT_ML_ENGINE_NAME, SNORT_ML_ENGINE_HELP, snort_ml_engine_params) {}
+static const PegInfo peg_names[] =  // names for the counters that get_counts() exposes from SnortMLEngineStats
+{
+    LRU_CACHE_LOCAL_PEGS("snort_ml_engine"),  // presumably the pegs of the LruCacheLocalStats base - confirm against lru_cache_local.h
+    { CountType::SUM, "filter_searches", "total filter searches" },  // the four entries below follow the field order of SnortMLEngineStats
+    { CountType::SUM, "filter_matches", "total filter matches" },
+    { CountType::SUM, "filter_allows", "total filter allows" },
+    { CountType::SUM, "libml_calls", "total libml calls" },
+    { CountType::END, nullptr, nullptr }
+};
+
+SnortMLEngineModule::SnortMLEngineModule()
+    : Module(SNORT_ML_ENGINE_NAME, SNORT_ML_ENGINE_HELP, snort_ml_engine_params) {}  // passes the module name, help text and parameter table to the base class
+
+bool SnortMLEngineModule::begin(const char* fqn, int, SnortConfig*)  // clears conf when fqn is exactly the module name
+{
+    if (!strcmp(SNORT_ML_ENGINE_NAME, fqn))  // any other fqn leaves the values collected so far untouched
+        conf = {};
+
+    return true;
+}
 
 bool SnortMLEngineModule::set(const char*, Value& v, SnortConfig*)
 {
     if (v.is("http_param_model"))
         conf.http_param_model_path = v.get_string();
 
+    else if (v.is("filter_pattern"))
+        conf.http_param_filters[v.get_string()] = true;  // true: this pattern triggers classification
+
+    else if (v.is("ignore_pattern"))
+    {
+        conf.http_param_filters[v.get_string()] = false;  // false: this pattern skips classification; both kinds share one map, so the later assignment wins for a repeated pattern
+        conf.has_allow = true;  // records that at least one ignore pattern exists
+    }
+
+    else if (v.is("cache_memcap"))
+        conf.cache_memcap = v.get_size();
+
     return true;
 }
 
+const PegInfo* SnortMLEngineModule::get_pegs() const  // returns the file-local peg_names table
+{ return peg_names; }
+
+PegCount* SnortMLEngineModule::get_counts() const  // exposes the THREAD_LOCAL snort_ml_engine_stats struct as a PegCount array
+{ return reinterpret_cast<PegCount*>(&snort_ml_engine_stats); }
+
 //--------------------------------------------------------------------------
 // reload tuner
 //--------------------------------------------------------------------------
@@ -80,9 +152,7 @@ bool SnortMLEngineModule::set(const char*, Value& v, SnortConfig*)
 class SnortMLReloadTuner : public snort::ReloadResourceTuner
 {
 public:
-    explicit SnortMLReloadTuner(const vector<string>& models)
-        : http_param_models(models) {}
-
+    explicit SnortMLReloadTuner(const SnortMLEngineConfig& c) : conf(c) {}
     ~SnortMLReloadTuner() override = default;
 
     const char* name() const override
@@ -90,11 +160,8 @@ public:
 
     bool tinit() override
     {
-        delete classifiers;
-
-        if (!build_classifiers(http_param_models, classifiers))
-            ErrorMessage("Could not build classifiers.\n");
-
+        delete snort_ml_ctx;
+        snort_ml_ctx = create_context(conf);
         return false;
     }
 
@@ -105,25 +172,48 @@ public:
     { return true; }
 
 private:
-    const vector<string>& http_param_models;
+    const SnortMLEngineConfig& conf;
 };
 
 //--------------------------------------------------------------------------
 // inspector
 //--------------------------------------------------------------------------
 
-SnortMLEngine::SnortMLEngine(const SnortMLEngineConfig& c) : config(c)
+bool SnortMLEngine::configure(SnortConfig*)  // reads the models, checks that they build, and prepares the optional pattern search; false on any failure
 {
-    if (!read_models() || !validate_models())
+    if (!read_models())
+        return false;
+
+    libml::BinaryClassifierSet classifiers;  // local set used only to check that the models build; it is discarded on return
+
+    if (!classifiers.build(conf.http_param_models))
+    {
         ParseError("Could not build classifiers.");
+        return false;
+    }
+
+    if (!conf.http_param_filters.empty())  // with no filter/ignore patterns, mpse stays nullptr and scan() skips the search
+    {
+        mpse = new SearchTool;
+
+        for (auto& f : conf.http_param_filters)
+            mpse->add(f.first.c_str(), f.first.size(), (void*)&f);  // the user pointer is the address of the map entry (pattern, flag) inside conf
+
+        mpse->prep();
+    }
+
+    return true;
 }
 
 void SnortMLEngine::show(const SnortConfig*) const
-{ ConfigLogger::log_value("http_param_model", config.http_param_model_path.c_str()); }
+{
+    ConfigLogger::log_value("http_param_model", conf.http_param_model_path.c_str());
+    ConfigLogger::log_value("cache_memcap", conf.cache_memcap);  // the filter/ignore patterns are not logged here
+}
 
 bool SnortMLEngine::read_models()
 {
-    const char* hint = config.http_param_model_path.c_str();
+    const char* hint = conf.http_param_model_path.c_str();
     string path;
 
     if (!get_config_file(hint, path))
@@ -160,7 +250,13 @@ bool SnortMLEngine::read_models()
         }
     }
 
-    return !http_param_models.empty();
+    if (conf.http_param_models.empty())
+    {
+        ParseError("snort_ml_engine: no models found");
+        return false;
+    }
+
+    return true;
 }
 
 bool SnortMLEngine::read_model(const string& path)
@@ -178,33 +274,80 @@ bool SnortMLEngine::read_model(const string& path)
     string buffer(size, '\0');
     file.read(&buffer[0], streamsize(size));
 
-    http_param_models.push_back(move(buffer));
+    conf.http_param_models.push_back(std::move(buffer));
     return true;
 }
 
-bool SnortMLEngine::validate_models()
-{
-    libml::BinaryClassifierSet* set = nullptr;
-    bool res = build_classifiers(http_param_models, set);
-    delete set;
-
-    return res;
-}
-
 void SnortMLEngine::tinit()
-{ build_classifiers(http_param_models, classifiers); }
+{ snort_ml_ctx = create_context(conf); }  // snort_ml_ctx is THREAD_LOCAL, so this sets the calling thread's context
 
 void SnortMLEngine::tterm()
 {
-    delete classifiers;
-    classifiers = nullptr;
+    delete snort_ml_ctx;  // frees the calling thread's context (classifiers and cache)
+    snort_ml_ctx = nullptr;  // scan() returns false while the pointer is null
 }
 
 void SnortMLEngine::install_reload_handler(SnortConfig* sc)
-{ sc->register_reload_handler(new SnortMLReloadTuner(http_param_models)); }
+{ sc->register_reload_handler(new SnortMLReloadTuner(conf)); }  // the tuner stores a reference to this inspector's conf, not a copy
+
+static int filter_match_callback(void* f, void*, int, void* s, void*)  // match callback handed to mpse->find_all() by scan(); f is the pattern's user pointer, s the SnortMLSearch
+{
+    auto filter = reinterpret_cast<const pair<string, bool>*>(f);  // NOTE(review): configure() registers the address of a SnortMLFilterMap entry, whose element type is pair<const string, bool> - confirm the type named in this cast
+    auto search = reinterpret_cast<SnortMLSearch*>(s);
+
+    search->match = true;
+    search->allow |= !filter->second;  // second == false marks an ignore pattern (see set()), so it raises allow
+
+    if (search->has_allow && !search->allow)
+        return 0;  // ignore patterns exist but none has matched yet; presumably 0 continues the search and 1 stops it - confirm against SearchTool
+
+    return 1;
+}
+
+bool SnortMLEngine::scan(const char* buf, const size_t len, float& out) const  // applies the pattern filters, then the cache and classifier; out is written only when true is returned
+{
+    if (!snort_ml_ctx)
+        return false;
+
+    if (mpse)  // non-null only when filter/ignore patterns were configured
+    {
+        snort_ml_engine_stats.filter_searches++;
+
+        SnortMLSearch search;
+        search.has_allow = conf.has_allow;
+
+        mpse->find_all(buf, len, filter_match_callback,
+            false, (void*)&search);
 
-libml::BinaryClassifierSet* SnortMLEngine::get_classifiers()
-{ return classifiers; }
+        if (!search.match)
+            return false;  // no configured pattern matched: classification is skipped
+
+        snort_ml_engine_stats.filter_matches++;
+
+        if (search.allow)
+        {
+            snort_ml_engine_stats.filter_allows++;
+            return false;  // an ignore pattern matched: classification is skipped
+        }
+    }
+
+    float res = 0;
+    bool is_new = true;  // stays true when there is no cache, so the classifier runs on every call
+
+    float& result = (snort_ml_ctx->cache) ?
+        snort_ml_ctx->cache->find_else_create(fnv1a(buf, len), &is_new) : res;  // the cache key is the 64-bit hash of the input, not the input bytes
+
+    if (is_new)
+    {
+        snort_ml_engine_stats.libml_calls++;
+
+        if (!snort_ml_ctx->classifiers.run(buf, len, result))
+            return false;  // NOTE(review): with a cache, the entry obtained above presumably stays in place after a failed run - confirm a later lookup of the same key should be treated as a hit
+    }
+
+    out = result;
+    return true;
+}
 
 //--------------------------------------------------------------------------
 // api stuff
@@ -218,7 +361,7 @@ static void mod_dtor(Module* m)
 
 static Inspector* snort_ml_engine_ctor(Module* m)
 {
-    SnortMLEngineModule* mod = (SnortMLEngineModule*)m;
+    SnortMLEngineModule* mod = reinterpret_cast<SnortMLEngineModule*>(m);  // NOTE(review): this is a base-to-derived conversion (Module* to SnortMLEngineModule*), for which static_cast is the usual cast - confirm
     return new SnortMLEngine(mod->get_config());
 }
 
@@ -274,8 +417,9 @@ const BaseApi* nin_snort_ml_engine[] =
 
 TEST_CASE("SnortML tuner name", "[snort_ml_module]")
 {
-    const vector<string> models = { "model" };
-    SnortMLReloadTuner tuner(models);
+    SnortMLEngineConfig conf;  // must stay alive for the whole test: the tuner keeps a reference to it
+    conf.http_param_models = { "model" };
+    SnortMLReloadTuner tuner(conf);
 
     REQUIRE(strcmp(tuner.name(), "SnortMLReloadTuner") == 0);
 }
index 0d77d16642f6d004a09e140cc564aa12d47023f5..beeae3389e82cba5f3b4436d4ba81a7bd2d9c2d5 100644 (file)
 #ifndef SNORT_ML_ENGINE_H
 #define SNORT_ML_ENGINE_H
 
-#include "framework/module.h"
+#ifdef HAVE_LIBML
+#include <libml.h>
+#endif
+
+#include <memory>
+#include <unordered_map>
+#include <utility>
+
 #include "framework/inspector.h"
+#include "framework/module.h"
+#include "hash/lru_cache_local.h"
+#include "search_engines/search_tool.h"
 
 #define SNORT_ML_ENGINE_NAME "snort_ml_engine"
 #define SNORT_ML_ENGINE_HELP "configure machine learning engine settings"
 
+// Mock BinaryClassifierSet for tests if LibML is absent
+#ifndef HAVE_LIBML
 namespace libml
 {
-    class BinaryClassifierSet;
+
+class BinaryClassifierSet  // provides the build()/run() calls that the engine code makes on a classifier set
+{
+public:
+    bool build(const std::vector<std::string>& models)
+    {
+        if (!models.empty())
+            pattern = models[0];  // only the first entry is used; it is kept as a plain search string
+
+        return pattern != "error";  // the literal model "error" makes build() report failure
+    }
+
+    bool run(const char* ptr, size_t len, float& out)
+    {
+        std::string data(ptr, len);
+        out = data.find(pattern) == std::string::npos ? 0.0f : 1.0f;  // 1.0 when the pattern occurs in the input, 0.0 otherwise
+        return pattern != "fail";  // the literal model "fail" makes run() report failure (out is still written)
+    }
+
+private:
+    std::string pattern;
+};
+
 }
+#endif
+
+struct SnortMLEngineStats : public LruCacheLocalStats  // engine counters appended to the cache stats of the base struct
+{
+    PegCount filter_searches;  // incremented once per pattern search in scan()
+    PegCount filter_matches;  // searches in which at least one pattern matched
+    PegCount filter_allows;  // matches in which an ignore pattern fired
+    PegCount libml_calls;  // classifier runs; cache hits do not increment it
+};
+
+typedef LruCacheLocal<uint64_t, float, std::hash<uint64_t>> SnortMLCache;  // key: fnv1a hash of the scanned input, value: classifier output
+typedef std::unordered_map<std::string, bool> SnortMLFilterMap;  // pattern -> true for filter_pattern, false for ignore_pattern
+
+struct SnortMLContext  // state created per thread by create_context()
+{
+    libml::BinaryClassifierSet classifiers;
+    std::unique_ptr<SnortMLCache> cache;  // empty when cache_memcap is 0 or the classifiers failed to build
+};
 
 struct SnortMLEngineConfig
 {
     std::string http_param_model_path;
+    std::vector<std::string> http_param_models;  // model file contents, appended by read_model()
+    SnortMLFilterMap http_param_filters;  // filter and ignore patterns share this one map
+    bool has_allow = false;  // set once any ignore_pattern is configured
+    size_t cache_memcap = 0;  // 0 means no verdict cache is created
+};
+
+struct SnortMLSearch  // per-scan state that scan() passes to filter_match_callback
+{
+    bool match = false;  // some pattern matched
+    bool allow = false;  // an ignore pattern matched
+    bool has_allow = false;  // copied from SnortMLEngineConfig::has_allow before the search
 };
 
 class SnortMLEngineModule : public snort::Module
@@ -42,13 +105,21 @@ class SnortMLEngineModule : public snort::Module
 public:
     SnortMLEngineModule();
 
+    bool begin(const char*, int, snort::SnortConfig*) override;
     bool set(const char*, snort::Value&, snort::SnortConfig*) override;
 
+    const PegInfo* get_pegs() const override;
+    PegCount* get_counts() const override;
+
     Usage get_usage() const override
     { return GLOBAL; }
 
-    const SnortMLEngineConfig& get_config()
-    { return conf; }
+    SnortMLEngineConfig get_config()  // hands the collected config to the caller by value
+    {
+        SnortMLEngineConfig out;
+        std::swap(conf, out);  // the module's conf is left default-constructed after this call
+        return out;
+    }
 
 private:
     SnortMLEngineConfig conf;
@@ -57,58 +128,26 @@ private:
 class SnortMLEngine : public snort::Inspector
 {
 public:
-    SnortMLEngine(const SnortMLEngineConfig&);
+    SnortMLEngine(SnortMLEngineConfig c) : conf(std::move(c)) {}  // takes the config by value and moves it into the member
+    ~SnortMLEngine() override
+    { delete mpse; }  // mpse is allocated in configure(); it is still nullptr when no patterns were configured
 
+    bool configure(snort::SnortConfig*) override;
     void show(const snort::SnortConfig*) const override;
-    void eval(snort::Packet*) override {}
 
     void tinit() override;
     void tterm() override;
 
     void install_reload_handler(snort::SnortConfig*) override;
 
-    static libml::BinaryClassifierSet* get_classifiers();
+    bool scan(const char*, const size_t, float&) const;  // true when the float& was set from the classifier or the cache
 
 private:
     bool read_models();
     bool read_model(const std::string&);
 
-    bool validate_models();
-
-    SnortMLEngineConfig config;
-    std::vector<std::string> http_param_models;
-};
-
-// Mock BinaryClassifierSet for tests if LibML is absent.
-// The code below won't be executed if REG_TEST is undefined.
-// Check the plugin type provided in the snort_ml_engine.cc file.
-#ifndef HAVE_LIBML
-namespace libml
-{
-
-class BinaryClassifierSet
-{
-public:
-    bool build(const std::vector<std::string>& models)
-    {
-        if (!models.empty())
-            pattern = models[0];
-
-        return pattern != "error";
-    }
-
-    bool run(const char* ptr, size_t len, float& out)
-    {
-        std::string data(ptr, len);
-        out = data.find(pattern) == std::string::npos ? 0.0f : 1.0f;
-        return pattern != "fail";
-    }
-
-private:
-    std::string pattern;
+    SnortMLEngineConfig conf;  // now also holds the loaded model contents and the filter map
+    snort::SearchTool* mpse = nullptr;  // pattern search over conf.http_param_filters; its user pointers refer to entries of that map
 };
 
-}
-#endif
-
 #endif
index 9d27a5817a93cc9f1517180e807d5ff7e9c0cda6..64a9f62a700812dd7988cf389fdd35dd501b8f3c 100644 (file)
 
 #include <cassert>
 
-#ifdef HAVE_LIBML
-#include <libml.h>
-#endif
-
 #include "detection/detection_engine.h"
 #include "log/messages.h"
 #include "managers/inspector_manager.h"
@@ -51,13 +47,14 @@ THREAD_LOCAL ProfileStats snort_ml_prof;
 class HttpBodyHandler : public DataHandler
 {
 public:
-    HttpBodyHandler(SnortML& ml)
-        : DataHandler(SNORT_ML_NAME), inspector(ml) {}
+    HttpBodyHandler(const SnortMLEngine& eng, const SnortML& ins)
+        : DataHandler(SNORT_ML_NAME), engine(eng), inspector(ins) {}  // stores references only, so both objects must outlive the handler
 
     void handle(DataEvent& de, Flow*) override;
 
 private:
-    SnortML& inspector;
+    const SnortMLEngine& engine;  // handle() calls engine.scan() on the request body
+    const SnortML& inspector;  // handle() reads the depth and threshold via inspector.get_config()
 };
 
 void HttpBodyHandler::handle(DataEvent& de, Flow*)
@@ -65,9 +62,7 @@ void HttpBodyHandler::handle(DataEvent& de, Flow*)
     // cppcheck-suppress unreadVariable
     Profile profile(snort_ml_prof);
 
-    libml::BinaryClassifierSet* classifiers = SnortMLEngine::get_classifiers();
-    SnortMLConfig config = inspector.get_config();
-    HttpRequestBodyEvent* he = (HttpRequestBodyEvent*)&de;
+    HttpRequestBodyEvent* he = reinterpret_cast<HttpRequestBodyEvent*>(&de);
 
     if (he->is_mime())
         return;
@@ -78,23 +73,23 @@ void HttpBodyHandler::handle(DataEvent& de, Flow*)
     if (!body || body_len <= 0)
         return;
 
-    const size_t len = std::min((size_t)config.client_body_depth, (size_t)body_len);
-
-    assert(classifiers);
-
-    float output = 0.0;
+    const SnortMLConfig& conf = inspector.get_config();
 
-    snort_ml_stats.libml_calls++;
+    const size_t len = std::min((size_t)conf.client_body_depth, (size_t)body_len);
 
-    if (!classifiers->run(body, len, output))
+    float output = 0;
+    if (!engine.scan(body, len, output))
         return;
 
     snort_ml_stats.client_body_bytes += len;
 
-    debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "input (body): %.*s\n", (int)len, body);
-    debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast<double>(output));
+    debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr,
+        "input (body): %.*s\n", (int)len, body);
 
-    if ((double)output > config.http_param_threshold)
+    debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr,
+        "output: %f\n", static_cast<double>(output));
+
+    if ((double)output > conf.http_param_threshold)
     {
         snort_ml_stats.client_body_alerts++;
         debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "<ALERT>\n");
@@ -109,13 +104,14 @@ void HttpBodyHandler::handle(DataEvent& de, Flow*)
 class HttpUriHandler : public DataHandler
 {
 public:
-    HttpUriHandler(SnortML& ml)
-        : DataHandler(SNORT_ML_NAME), inspector(ml) {}
+    HttpUriHandler(const SnortMLEngine& eng, const SnortML& ins)
+        : DataHandler(SNORT_ML_NAME), engine(eng), inspector(ins) {}  // stores references only, so both objects must outlive the handler
 
     void handle(DataEvent&, Flow*) override;
 
 private:
-    SnortML& inspector;
+    const SnortMLEngine& engine;  // handle() calls engine.scan() on the URI query
+    const SnortML& inspector;  // handle() reads the depth and threshold via inspector.get_config()
 };
 
 void HttpUriHandler::handle(DataEvent& de, Flow*)
@@ -123,9 +119,7 @@ void HttpUriHandler::handle(DataEvent& de, Flow*)
     // cppcheck-suppress unreadVariable
     Profile profile(snort_ml_prof);
 
-    libml::BinaryClassifierSet* classifiers = SnortMLEngine::get_classifiers();
-    SnortMLConfig config = inspector.get_config();
-    HttpEvent* he = (HttpEvent*)&de;
+    HttpEvent* he = reinterpret_cast<HttpEvent*>(&de);
 
     int32_t query_len = 0;
     const char* query = (const char*)he->get_uri_query(query_len);
@@ -133,23 +127,23 @@ void HttpUriHandler::handle(DataEvent& de, Flow*)
     if (!query || query_len <= 0)
         return;
 
-    const size_t len = std::min((size_t)config.uri_depth, (size_t)query_len);
-
-    assert(classifiers);
+    const SnortMLConfig& conf = inspector.get_config();
 
-    float output = 0.0;
+    const size_t len = std::min((size_t)conf.uri_depth, (size_t)query_len);
 
-    snort_ml_stats.libml_calls++;
-
-    if (!classifiers->run(query, len, output))
+    float output = 0;
+    if (!engine.scan(query, len, output))
         return;
 
     snort_ml_stats.uri_bytes += len;
 
-    debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "input (query): %.*s\n", (int)len, query);
-    debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast<double>(output));
+    debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr,
+        "input (query): %.*s\n", (int)len, query);
+
+    debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr,
+        "output: %f\n", static_cast<double>(output));
 
-    if ((double)output > config.http_param_threshold)
+    if ((double)output > conf.http_param_threshold)
     {
         snort_ml_stats.uri_alerts++;
         debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "<ALERT>\n");
@@ -163,25 +157,36 @@ void HttpUriHandler::handle(DataEvent& de, Flow*)
 
 void SnortML::show(const SnortConfig*) const
 {
-    ConfigLogger::log_limit("uri_depth", config.uri_depth, -1);
-    ConfigLogger::log_limit("client_body_depth", config.client_body_depth, -1);
-    ConfigLogger::log_value("http_param_threshold", config.http_param_threshold);
+    ConfigLogger::log_limit("uri_depth", conf.uri_depth, -1);  // both depths go through log_limit with -1 as the third argument
+    ConfigLogger::log_limit("client_body_depth", conf.client_body_depth, -1);
+    ConfigLogger::log_value("http_param_threshold", conf.http_param_threshold);  // the threshold is logged as a plain value
 }
 
 bool SnortML::configure(SnortConfig* sc)
 {
-    if (config.uri_depth != 0)
-        DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_HEADER, new HttpUriHandler(*this));
+    auto engine = reinterpret_cast<const SnortMLEngine*>(
+        InspectorManager::get_inspector(SNORT_ML_ENGINE_NAME, true, sc));  // looked up before subscribing because both handlers take the engine by reference
 
-    if (config.client_body_depth != 0)
-        DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_BODY, new HttpBodyHandler(*this));
-
-    if(!InspectorManager::get_inspector(SNORT_ML_ENGINE_NAME, true, sc))
+    if (!engine)
     {
-        ParseError("snort_ml requires %s to be configured in the global policy.", SNORT_ML_ENGINE_NAME);
+        ParseError("snort_ml requires %s to be configured in the global policy.",
+            SNORT_ML_ENGINE_NAME);
+
         return false;
     }
 
+    if (conf.uri_depth != 0)  // a uri_depth of 0 means no URI handler is subscribed
+    {
+        DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_HEADER,
+            new HttpUriHandler(*engine, *this));
+    }
+
+    if (conf.client_body_depth != 0)  // a client_body_depth of 0 means no body handler is subscribed
+    {
+        DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_BODY,
+            new HttpBodyHandler(*engine, *this));
+    }
+
     return true;
 }
 
@@ -197,8 +202,8 @@ static void mod_dtor(Module* m)
 
 static Inspector* snort_ml_ctor(Module* m)
 {
-    SnortMLModule* km = (SnortMLModule*)m;
-    return new SnortML(km->get_conf());
+    const SnortMLModule* mod = reinterpret_cast<const SnortMLModule*>(m);  // NOTE(review): base-to-derived conversion (Module* to SnortMLModule*), for which static_cast is the usual cast - confirm
+    return new SnortML(mod->get_config());  // the SnortML constructor copies the config it is given
 }
 
 static void snort_ml_dtor(Inspector* p)
index f7f97fca9214ab42bd75ed933d7074ebe778e24b..72e28325f7c216b6ebf022dcc608d02a473d4009 100644 (file)
 class SnortML : public snort::Inspector
 {
 public:
-    SnortML(const SnortMLConfig& c) : config(c) { }
+    SnortML(const SnortMLConfig& c) : conf(c) {}  // keeps its own copy of the config
 
     void show(const snort::SnortConfig*) const override;
     void eval(snort::Packet*) override {}
     bool configure(snort::SnortConfig*) override;
 
-    const SnortMLConfig& get_config()
-    { return config; }
+    const SnortMLConfig& get_config() const  // const-qualified so the handlers can call it through their const SnortML& member
+    { return conf; }
+
 private:
-    SnortMLConfig config;
+    SnortMLConfig conf;
 };
 
 #endif
index d0c4cd2917e925812af31eca7d2aaad83cbb333b..7280bf6973efeef196516ea5fd6fc567aaa5292f 100644 (file)
@@ -56,7 +56,6 @@ static const PegInfo peg_names[] =
     { CountType::SUM, "client_body_alerts", "total number of alerts triggered on HTTP client body" },
     { CountType::SUM, "uri_bytes", "total number of HTTP URI bytes processed" },
     { CountType::SUM, "client_body_bytes", "total number of HTTP client body bytes processed" },
-    { CountType::SUM, "libml_calls", "total libml calls" },
     { CountType::END, nullptr, nullptr }
 };
 
@@ -72,13 +71,17 @@ static const TraceOption snort_ml_trace_options[] =
 // module
 //--------------------------------------------------------------------------
 
-SnortMLModule::SnortMLModule() : Module(SNORT_ML_NAME, SNORT_ML_HELP, snort_ml_params) {}
+SnortMLModule::SnortMLModule()
+    : Module(SNORT_ML_NAME, SNORT_ML_HELP, snort_ml_params) {}  // passes the module name, help text and parameter table to the base class
 
 bool SnortMLModule::set(const char*, Value& v, SnortConfig*)
 {
-    static_assert(std::is_same<decltype((Field().length())), decltype(conf.uri_depth)>::value,
+    static_assert(std::is_same<decltype((Field().length())),
+        decltype(conf.uri_depth)>::value,
         "Field::length maximum value should not exceed uri_depth type range");
-    static_assert(std::is_same<decltype((Field().length())), decltype(conf.client_body_depth)>::value,
+
+    static_assert(std::is_same<decltype((Field().length())),
+        decltype(conf.client_body_depth)>::value,
         "Field::length maximum value should not exceed client_body_depth type range");
 
     if (v.is("uri_depth"))
@@ -107,7 +110,7 @@ const PegInfo* SnortMLModule::get_pegs() const
 { return peg_names; }
 
 PegCount* SnortMLModule::get_counts() const
-{ return (PegCount*)&snort_ml_stats; }
+{ return reinterpret_cast<PegCount*>(&snort_ml_stats); }  // exposes the snort_ml_stats struct as a PegCount array
 
 ProfileStats* SnortMLModule::get_profile() const
 { return &snort_ml_prof; }
@@ -115,11 +118,7 @@ ProfileStats* SnortMLModule::get_profile() const
 void SnortMLModule::set_trace(const Trace* trace) const
 { snort_ml_trace = trace; }
 
+#ifdef DEBUG_MSGS
 const TraceOption* SnortMLModule::get_trace_options() const
-{
-#ifndef DEBUG_MSGS
-    return nullptr;
-#else
-    return snort_ml_trace_options;
+{ return snort_ml_trace_options; }  // the function is compiled only when DEBUG_MSGS is defined; there is no nullptr-returning variant any more
 #endif
-}
index b842a06ecaf34de8025a3391e5f9f93957dca3a8..8d680ca9a17415770a5f8fd3999c7d6de566b3e2 100644 (file)
@@ -39,7 +39,6 @@ struct SnortMLStats
     PegCount client_body_alerts;
     PegCount uri_bytes;
     PegCount client_body_bytes;
-    PegCount libml_calls;
 };
 
 extern THREAD_LOCAL SnortMLStats snort_ml_stats;
@@ -48,7 +47,6 @@ extern THREAD_LOCAL const snort::Trace* snort_ml_trace;
 
 struct SnortMLConfig
 {
-    std::string http_param_model_path;
     double http_param_threshold;
     int32_t uri_depth;
     int32_t client_body_depth;
@@ -62,7 +60,7 @@ public:
     bool set(const char*, snort::Value&, snort::SnortConfig*) override;
     bool end(const char*, int, snort::SnortConfig*) override;
 
-    const SnortMLConfig& get_conf() const
+    const SnortMLConfig& get_config() const  // returns a reference to the module's conf; snort_ml_ctor passes it to the SnortML constructor
     { return conf; }
 
     unsigned get_gid() const override
@@ -79,7 +77,10 @@ public:
     snort::ProfileStats* get_profile() const override;
 
     void set_trace(const snort::Trace*) const override;
+
+#ifdef DEBUG_MSGS
     const snort::TraceOption* get_trace_options() const override;
+#endif
 
 private:
     SnortMLConfig conf = {};