]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4595: snort_ml: build models into a BinaryClassifierSet
authorBrandon Stultz (brastult) <brastult@cisco.com>
Tue, 11 Feb 2025 09:28:46 +0000 (09:28 +0000)
committerOleksii Shumeiko -X (oshumeik - SOFTSERVE INC at Cisco) <oshumeik@cisco.com>
Tue, 11 Feb 2025 09:28:46 +0000 (09:28 +0000)
Merge in SNORT/snort3 from ~BRASTULT/snort3:snort_ml to master

Squashed commit of the following:

commit e4f35d63b7bc2fa38176408466afe8576d0f77f0
Author: Brandon Stultz <brastult@cisco.com>
Date:   Fri Jan 31 02:43:25 2025 -0500

    snort_ml: build models into a BinaryClassifierSet

commit 7ac7827b65192d6319893498585b48e0c7809e1b
Author: Brandon Stultz <brastult@cisco.com>
Date:   Fri Jan 31 01:16:00 2025 -0500

    utils: add is_directory_path

commit e3897fe6bf08d2fba2406f612b4bf3b31e07cfea
Author: Brandon Stultz <brastult@cisco.com>
Date:   Thu Jan 30 11:57:53 2025 -0500

    network_inspectors: rename kaizen to snort_ml

18 files changed:
cmake/FindML.cmake
src/main/process.cc
src/main/shell.cc
src/network_inspectors/CMakeLists.txt
src/network_inspectors/dev_notes.txt
src/network_inspectors/kaizen/CMakeLists.txt [deleted file]
src/network_inspectors/kaizen/kaizen_engine.cc [deleted file]
src/network_inspectors/network_inspectors.cc
src/network_inspectors/snort_ml/CMakeLists.txt [new file with mode: 0644]
src/network_inspectors/snort_ml/dev_notes.txt [moved from src/network_inspectors/kaizen/dev_notes.txt with 100% similarity]
src/network_inspectors/snort_ml/snort_ml_engine.cc [new file with mode: 0644]
src/network_inspectors/snort_ml/snort_ml_engine.h [moved from src/network_inspectors/kaizen/kaizen_engine.h with 54% similarity]
src/network_inspectors/snort_ml/snort_ml_inspector.cc [moved from src/network_inspectors/kaizen/kaizen_inspector.cc with 64% similarity]
src/network_inspectors/snort_ml/snort_ml_inspector.h [moved from src/network_inspectors/kaizen/kaizen_inspector.h with 76% similarity]
src/network_inspectors/snort_ml/snort_ml_module.cc [moved from src/network_inspectors/kaizen/kaizen_module.cc with 74% similarity]
src/network_inspectors/snort_ml/snort_ml_module.h [moved from src/network_inspectors/kaizen/kaizen_module.h with 74% similarity]
src/utils/util.cc
src/utils/util.h

index 0282bcbefdd36a823f0816885bf4ae5814453d23..c30553c959448a458b43f1cc87734cfff29ca5fb 100644 (file)
@@ -1,12 +1,22 @@
-find_library(ML_LIBRARIES NAMES ml_static HINTS ${ML_LIBRARIES_DIR_HINT})
-find_path(ML_INCLUDE_DIRS libml.h HINTS ${ML_INCLUDE_DIR_HINT})
+find_package(PkgConfig)
+pkg_check_modules(PC_ML libml_static>=2.0.0)
+
+find_path(ML_INCLUDE_DIRS
+    libml.h
+    HINTS ${ML_INCLUDE_DIR_HINT} ${PC_ML_INCLUDEDIR}
+)
+
+find_library(ML_LIBRARIES
+    NAMES ml_static
+    HINTS ${ML_LIBRARIES_DIR_HINT} ${PC_ML_LIBDIR}
+)
 
 include(FindPackageHandleStandardArgs)
 
 find_package_handle_standard_args(ML
     DEFAULT_MSG
-    ML_LIBRARIES
     ML_INCLUDE_DIRS
+    ML_LIBRARIES
 )
 
 if (ML_FOUND AND NOT USE_LIBML_MOCK)
index 2b1f94fe4e4a351cfce3bd6122a11ae14b9bd703..f361a5962515a40b1e850d99278de87d0b86a1be 100644 (file)
 #include <jemalloc/jemalloc.h>
 #endif
 
+#ifdef HAVE_LIBML
+#include <libml.h>
+#endif
+
 #ifdef HAVE_LIBUNWIND
 #define UNW_LOCAL_ONLY
 #include <libunwind.h>
@@ -706,6 +710,9 @@ int DisplayBanner()
     size_t sz = sizeof(jv);
     mallctl("version", &jv,  &sz, NULL, 0);
     LogMessage("           Using Jemalloc version %s\n", jv);
+#endif
+#ifdef HAVE_LIBML
+    LogMessage("           Using LibML version %s\n", libml::version());
 #endif
     LogMessage("           Using %s\n", pcap_lib_version());
     LogMessage("           Using LuaJIT version %s\n", ljv);
index bb7b7134cee24ecd30fb34007d939760e30de7e3..21d1b753783a86b221cee284962dc00937f9a1e3 100644 (file)
@@ -171,7 +171,7 @@ static void install_dependencies_strings(Shell* sh, lua_State* L)
     vs.push_back(lzma_version_string());
 #endif
 #ifdef HAVE_LIBML
-    vs.push_back(libml_version());
+    vs.push_back(libml::version());
 #endif
 
     lua_createtable(L, 0, vs.size());
index fd8694b757bc271f448892cdb5e290b3eefd8e33..81ea889bd9de9ccab9d0bc973bd0b8a08c73a3f9 100644 (file)
@@ -5,7 +5,7 @@ add_subdirectory(binder)
 add_subdirectory(extractor)
 
 if ( HAVE_LIBML OR USE_LIBML_MOCK )
-    add_subdirectory(kaizen)
+    add_subdirectory(snort_ml)
 endif()
 
 add_subdirectory(normalize)
@@ -24,8 +24,8 @@ if(STATIC_INSPECTORS)
 endif()
 
 if ( HAVE_LIBML OR USE_LIBML_MOCK )
-    set(KAIZEN_STATIC_OBJ
-        $<TARGET_OBJECTS:kaizen>
+    set(SNORT_ML_STATIC_OBJ
+        $<TARGET_OBJECTS:snort_ml>
     )
 endif()
 
@@ -33,7 +33,7 @@ set(STATIC_NETWORK_INSPECTOR_PLUGINS
     $<TARGET_OBJECTS:appid>
     $<TARGET_OBJECTS:binder>
     $<TARGET_OBJECTS:extractor>
-    ${KAIZEN_STATIC_OBJ}
+    ${SNORT_ML_STATIC_OBJ}
     $<TARGET_OBJECTS:normalize>
     $<TARGET_OBJECTS:port_scan>
     $<TARGET_OBJECTS:reputation>
index fabcd45f882ae0cf36c04eec15154633f294626e..a3cf640b20e51e465918cd3a8278061fb79bf2a2 100644 (file)
@@ -20,9 +20,9 @@ normalizations.
 
 packet_capture - A tool for dumping the wire packets that Snort receives.
 
-kaizen - Machine learning based exploit detector capable of detecting novel
-attacks fitting known vulnerability types. Kaizen uses a neural network
-provided by a model file to detect exploit patterns. The Kaizen Snort module
+snort_ml - Machine learning based exploit detector capable of detecting novel
+attacks fitting known vulnerability types. SnortML uses a neural network
+provided by a model file to detect exploit patterns. The SnortML module
 subscribes to HTTP events published by the HTTP inspector, performs inference
 on HTTP queries/posts, and generates events if the neural network detects
 an exploit.
diff --git a/src/network_inspectors/kaizen/CMakeLists.txt b/src/network_inspectors/kaizen/CMakeLists.txt
deleted file mode 100644 (file)
index bc59057..0000000
+++ /dev/null
@@ -1,8 +0,0 @@
-add_library(kaizen OBJECT
-    kaizen_engine.cc
-    kaizen_engine.h
-    kaizen_inspector.cc
-    kaizen_inspector.h
-    kaizen_module.cc
-    kaizen_module.h
-)
diff --git a/src/network_inspectors/kaizen/kaizen_engine.cc b/src/network_inspectors/kaizen/kaizen_engine.cc
deleted file mode 100644 (file)
index 6e60a77..0000000
+++ /dev/null
@@ -1,250 +0,0 @@
-//--------------------------------------------------------------------------
-// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved.
-//
-// This program is free software; you can redistribute it and/or modify it
-// under the terms of the GNU General Public License Version 2 as published
-// by the Free Software Foundation.  You may not use, modify or distribute
-// this program under any other version of the GNU General Public License.
-//
-// This program is distributed in the hope that it will be useful, but
-// WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
-// General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License along
-// with this program; if not, write to the Free Software Foundation, Inc.,
-// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
-//--------------------------------------------------------------------------
-// kaizen_engine.cc author Vitalii Horbatov <vhorbato@cisco.com>
-
-#ifdef HAVE_CONFIG_H
-#include "config.h"
-#endif
-
-#include "kaizen_engine.h"
-
-#include <cassert>
-#include <fstream>
-
-#ifdef HAVE_LIBML
-#include <libml.h>
-#endif
-
-#include "framework/decode_data.h"
-#include "log/messages.h"
-#include "main/reload_tuner.h"
-#include "main/snort.h"
-#include "main/snort_config.h"
-#include "parser/parse_conf.h"
-#include "utils/util.h"
-
-using namespace snort;
-using namespace std;
-
-static THREAD_LOCAL BinaryClassifier* classifier = nullptr;
-
-static bool build_classifier(const string& model, BinaryClassifier*& dst)
-{
-    dst = new BinaryClassifier();
-
-    return dst->build(model);
-}
-
-//--------------------------------------------------------------------------
-// module
-//--------------------------------------------------------------------------
-
-static const Parameter kaizen_engine_params[] =
-{
-    { "http_param_model", Parameter::PT_STRING, nullptr, nullptr, "path to the model file" },
-    { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
-};
-
-KaizenEngineModule::KaizenEngineModule() : Module(KZ_ENGINE_NAME, KZ_ENGINE_HELP, kaizen_engine_params) {}
-
-bool KaizenEngineModule::set(const char*, Value& v, SnortConfig*)
-{
-    if (v.is("http_param_model"))
-        conf.http_param_model_path = v.get_string();
-
-    return true;
-}
-
-//--------------------------------------------------------------------------
-// reload tuner
-//--------------------------------------------------------------------------
-
-class KaizenReloadTuner : public snort::ReloadResourceTuner
-{
-public:
-    explicit KaizenReloadTuner(const string& http_param_model) : http_param_model(http_param_model) {}
-    ~KaizenReloadTuner() override = default;
-
-    const char* name() const override
-    { return "KaizenReloadTuner"; }
-
-    bool tinit() override
-    {
-        delete classifier;
-
-        if (!build_classifier(http_param_model, classifier))
-            ErrorMessage("Can't build the classifier model.\n");
-
-        return false;
-    }
-
-    bool tune_packet_context() override
-    { return true; }
-
-    bool tune_idle_context() override
-    { return true; }
-
-private:
-    const string& http_param_model;
-};
-
-//--------------------------------------------------------------------------
-// inspector
-//--------------------------------------------------------------------------
-
-KaizenEngine::KaizenEngine(const KaizenEngineConfig& c) : config(c)
-{
-    http_param_model = read_model();
-
-    if (!validate_model())
-        ParseError("Can't build the classifier model %s.", config.http_param_model_path.c_str());
-}
-
-void KaizenEngine::show(const SnortConfig*) const
-{ ConfigLogger::log_value("http_param_model", config.http_param_model_path.c_str()); }
-
-string KaizenEngine::read_model()
-{
-    const char* hint = config.http_param_model_path.c_str();
-    string path;
-    size_t size = 0;
-
-    if (!get_config_file(hint, path) || !get_file_size(path, size))
-    {
-        ParseError("snort_ml_engine: could not read model file: %s", hint);
-        return {};
-    }
-
-    ifstream file(path, ios::binary);
-
-    if (!file.is_open())
-    {
-        ParseError("snort_ml_engine: could not read model file: %s", hint);
-        return {};
-    }
-
-    if (size == 0)
-    {
-        ParseError("snort_ml_engine: empty model file: %s", hint);
-        return {};
-    }
-
-    string buffer(size, '\0');
-    file.read(&buffer[0], streamsize(size));
-    return buffer;
-}
-
-bool KaizenEngine::validate_model()
-{
-    BinaryClassifier* test_classifier = nullptr;
-    bool res = build_classifier(http_param_model, test_classifier);
-    delete test_classifier;
-
-    return res;
-}
-
-void KaizenEngine::tinit()
-{ build_classifier(http_param_model, classifier); }
-
-void KaizenEngine::tterm()
-{
-    delete classifier;
-    classifier = nullptr;
-}
-
-void KaizenEngine::install_reload_handler(SnortConfig* sc)
-{ sc->register_reload_handler(new KaizenReloadTuner(http_param_model)); }
-
-BinaryClassifier* KaizenEngine::get_classifier()
-{ return classifier; }
-
-//--------------------------------------------------------------------------
-// api stuff
-//--------------------------------------------------------------------------
-
-static Module* mod_ctor()
-{ return new KaizenEngineModule; }
-
-static void mod_dtor(Module* m)
-{ delete m; }
-
-static Inspector* kaizen_engine_ctor(Module* m)
-{
-    KaizenEngineModule* mod = (KaizenEngineModule*)m;
-    return new KaizenEngine(mod->get_config());
-}
-
-static void kaizen_engine_dtor(Inspector* p)
-{
-    assert(p);
-    delete p;
-}
-
-static const InspectApi kaizen_engine_api =
-{
-    {
-        PT_INSPECTOR,
-        sizeof(InspectApi),
-        INSAPI_VERSION,
-        0,
-        API_RESERVED,
-        API_OPTIONS,
-        KZ_ENGINE_NAME,
-        KZ_ENGINE_HELP,
-        mod_ctor,
-        mod_dtor
-    },
-    IT_PASSIVE,
-    PROTO_BIT__NONE,  // proto_bits;
-    nullptr,  // buffers
-    nullptr,  // service
-    nullptr,  // pinit
-    nullptr,  // pterm
-    nullptr,  // tinit
-    nullptr,  // tterm
-    kaizen_engine_ctor,
-    kaizen_engine_dtor,
-    nullptr,  // ssn
-    nullptr   // reset
-};
-
-#ifdef BUILDING_SO
-SO_PUBLIC const BaseApi* snort_plugins[] =
-#else
-const BaseApi* nin_kaizen_engine[] =
-#endif
-{
-    &kaizen_engine_api.base,
-    nullptr
-};
-
-#ifdef UNIT_TEST
-
-#include "catch/snort_catch.h"
-
-#include <memory.h>
-
-TEST_CASE("Kaizen tuner name", "[kaizen_module]")
-{
-    const string http_param_model("model");
-    KaizenReloadTuner tuner(http_param_model);
-
-    REQUIRE(strcmp(tuner.name(), "KaizenReloadTuner") == 0);
-}
-
-#endif
index 7abdadda522ed627e41999ac87bfda14eaba2728..57808b251dc28ab1866f839062cc82c64578c715 100644 (file)
@@ -34,8 +34,8 @@ extern const BaseApi* nin_appid[];
 extern const BaseApi* nin_extractor[];
 
 #if defined(HAVE_LIBML) || defined(USE_LIBML_MOCK)
-extern const BaseApi* nin_kaizen_engine[];
-extern const BaseApi* nin_kaizen[];
+extern const BaseApi* nin_snort_ml_engine[];
+extern const BaseApi* nin_snort_ml[];
 #endif
 
 extern const BaseApi* nin_port_scan[];
@@ -62,8 +62,8 @@ void load_network_inspectors()
     PluginManager::load_plugins(nin_extractor);
 
 #if defined(HAVE_LIBML) || defined(USE_LIBML_MOCK)
-    PluginManager::load_plugins(nin_kaizen_engine);
-    PluginManager::load_plugins(nin_kaizen);
+    PluginManager::load_plugins(nin_snort_ml_engine);
+    PluginManager::load_plugins(nin_snort_ml);
 #endif
 
     PluginManager::load_plugins(nin_port_scan);
diff --git a/src/network_inspectors/snort_ml/CMakeLists.txt b/src/network_inspectors/snort_ml/CMakeLists.txt
new file mode 100644 (file)
index 0000000..cf0e260
--- /dev/null
@@ -0,0 +1,8 @@
+add_library(snort_ml OBJECT
+    snort_ml_engine.cc
+    snort_ml_engine.h
+    snort_ml_inspector.cc
+    snort_ml_inspector.h
+    snort_ml_module.cc
+    snort_ml_module.h
+)
diff --git a/src/network_inspectors/snort_ml/snort_ml_engine.cc b/src/network_inspectors/snort_ml/snort_ml_engine.cc
new file mode 100644 (file)
index 0000000..6b3c63c
--- /dev/null
@@ -0,0 +1,283 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// snort_ml_engine.cc author Vitalii Horbatov <vhorbato@cisco.com>
+//                    author Brandon Stultz <brastult@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "snort_ml_engine.h"
+
+#include <cassert>
+#include <fstream>
+
+#ifdef HAVE_LIBML
+#include <libml.h>
+#endif
+
+#include "framework/decode_data.h"
+#include "helpers/directory.h"
+#include "log/messages.h"
+#include "main/reload_tuner.h"
+#include "main/snort.h"
+#include "main/snort_config.h"
+#include "parser/parse_conf.h"
+#include "utils/util.h"
+
+using namespace snort;
+using namespace std;
+
+static THREAD_LOCAL libml::BinaryClassifierSet* classifiers = nullptr;
+
+static bool build_classifiers(const vector<string>& models,
+    libml::BinaryClassifierSet*& set)
+{
+    set = new libml::BinaryClassifierSet();
+
+    return set->build(models);
+}
+
+//--------------------------------------------------------------------------
+// module
+//--------------------------------------------------------------------------
+
+static const Parameter snort_ml_engine_params[] =
+{
+    { "http_param_model", Parameter::PT_STRING, nullptr, nullptr, "path to model file(s)" },
+    { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
+};
+
+SnortMLEngineModule::SnortMLEngineModule() : Module(SNORT_ML_ENGINE_NAME, SNORT_ML_ENGINE_HELP, snort_ml_engine_params) {}
+
+bool SnortMLEngineModule::set(const char*, Value& v, SnortConfig*)
+{
+    if (v.is("http_param_model"))
+        conf.http_param_model_path = v.get_string();
+
+    return true;
+}
+
+//--------------------------------------------------------------------------
+// reload tuner
+//--------------------------------------------------------------------------
+
+class SnortMLReloadTuner : public snort::ReloadResourceTuner
+{
+public:
+    explicit SnortMLReloadTuner(const vector<string>& models)
+        : http_param_models(models) {}
+
+    ~SnortMLReloadTuner() override = default;
+
+    const char* name() const override
+    { return "SnortMLReloadTuner"; }
+
+    bool tinit() override
+    {
+        delete classifiers;
+
+        if (!build_classifiers(http_param_models, classifiers))
+            ErrorMessage("Could not build classifiers.\n");
+
+        return false;
+    }
+
+    bool tune_packet_context() override
+    { return true; }
+
+    bool tune_idle_context() override
+    { return true; }
+
+private:
+    const vector<string>& http_param_models;
+};
+
+//--------------------------------------------------------------------------
+// inspector
+//--------------------------------------------------------------------------
+
+SnortMLEngine::SnortMLEngine(const SnortMLEngineConfig& c) : config(c)
+{
+    if (!read_models() || !validate_models())
+        ParseError("Could not build classifiers.");
+}
+
+void SnortMLEngine::show(const SnortConfig*) const
+{ ConfigLogger::log_value("http_param_model", config.http_param_model_path.c_str()); }
+
+bool SnortMLEngine::read_models()
+{
+    const char* hint = config.http_param_model_path.c_str();
+    string path;
+
+    if (!get_config_file(hint, path))
+    {
+        ParseError("snort_ml_engine: could not read model file(s): %s", hint);
+        return false;
+    }
+
+    if (!is_directory_path(path))
+    {
+        if (!read_model(path))
+        {
+            ParseError("snort_ml_engine: could not read model file: %s", path.c_str());
+            return false;
+        }
+
+        return true;
+    }
+
+    Directory model_dir(path.c_str());
+
+    if (model_dir.error_on_open())
+    {
+        ParseError("snort_ml_engine: could not read model dir: %s", path.c_str());
+        return false;
+    }
+
+    while (const char* f = model_dir.next())
+    {
+        if (!read_model(f))
+        {
+            ParseError("snort_ml_engine: could not read model: %s", f);
+            return false;
+        }
+    }
+
+    return !http_param_models.empty();
+}
+
+bool SnortMLEngine::read_model(const string& path)
+{
+    size_t size = 0;
+
+    if (!get_file_size(path, size))
+        return false;
+
+    ifstream file(path, ios::binary);
+
+    if (!file.is_open() || size == 0)
+        return false;
+
+    string buffer(size, '\0');
+    file.read(&buffer[0], streamsize(size));
+
+    http_param_models.push_back(move(buffer));
+    return true;
+}
+
+bool SnortMLEngine::validate_models()
+{
+    libml::BinaryClassifierSet* set = nullptr;
+    bool res = build_classifiers(http_param_models, set);
+    delete set;
+
+    return res;
+}
+
+void SnortMLEngine::tinit()
+{ build_classifiers(http_param_models, classifiers); }
+
+void SnortMLEngine::tterm()
+{
+    delete classifiers;
+    classifiers = nullptr;
+}
+
+void SnortMLEngine::install_reload_handler(SnortConfig* sc)
+{ sc->register_reload_handler(new SnortMLReloadTuner(http_param_models)); }
+
+libml::BinaryClassifierSet* SnortMLEngine::get_classifiers()
+{ return classifiers; }
+
+//--------------------------------------------------------------------------
+// api stuff
+//--------------------------------------------------------------------------
+
+static Module* mod_ctor()
+{ return new SnortMLEngineModule; }
+
+static void mod_dtor(Module* m)
+{ delete m; }
+
+static Inspector* snort_ml_engine_ctor(Module* m)
+{
+    SnortMLEngineModule* mod = (SnortMLEngineModule*)m;
+    return new SnortMLEngine(mod->get_config());
+}
+
+static void snort_ml_engine_dtor(Inspector* p)
+{
+    assert(p);
+    delete p;
+}
+
+static const InspectApi snort_ml_engine_api =
+{
+    {
+        PT_INSPECTOR,
+        sizeof(InspectApi),
+        INSAPI_VERSION,
+        0,
+        API_RESERVED,
+        API_OPTIONS,
+        SNORT_ML_ENGINE_NAME,
+        SNORT_ML_ENGINE_HELP,
+        mod_ctor,
+        mod_dtor
+    },
+    IT_PASSIVE,
+    PROTO_BIT__NONE,  // proto_bits;
+    nullptr,  // buffers
+    nullptr,  // service
+    nullptr,  // pinit
+    nullptr,  // pterm
+    nullptr,  // tinit
+    nullptr,  // tterm
+    snort_ml_engine_ctor,
+    snort_ml_engine_dtor,
+    nullptr,  // ssn
+    nullptr   // reset
+};
+
+#ifdef BUILDING_SO
+SO_PUBLIC const BaseApi* snort_plugins[] =
+#else
+const BaseApi* nin_snort_ml_engine[] =
+#endif
+{
+    &snort_ml_engine_api.base,
+    nullptr
+};
+
+#ifdef UNIT_TEST
+
+#include "catch/snort_catch.h"
+
+#include <memory.h>
+
+TEST_CASE("SnortML tuner name", "[snort_ml_module]")
+{
+    const vector<string> models = { "model" };
+    SnortMLReloadTuner tuner(models);
+
+    REQUIRE(strcmp(tuner.name(), "SnortMLReloadTuner") == 0);
+}
+
+#endif
similarity index 54%
rename from src/network_inspectors/kaizen/kaizen_engine.h
rename to src/network_inspectors/snort_ml/snort_ml_engine.h
index 7c7381ba5244db5a17f81c2744138804d3f50739..0d77d16642f6d004a09e140cc564aa12d47023f5 100644 (file)
@@ -1,5 +1,5 @@
 //--------------------------------------------------------------------------
-// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved.
+// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved.
 //
 // This program is free software; you can redistribute it and/or modify it
 // under the terms of the GNU General Public License Version 2 as published
 // with this program; if not, write to the Free Software Foundation, Inc.,
 // 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
 //--------------------------------------------------------------------------
-// kaizen_engine.h author Vitalii Horbatov <vhorbato@cisco.com>
+// snort_ml_engine.h author Vitalii Horbatov <vhorbato@cisco.com>
+//                   author Brandon Stultz <brastult@cisco.com>
 
-#ifndef KAIZEN_ENGINE_H
-#define KAIZEN_ENGINE_H
+#ifndef SNORT_ML_ENGINE_H
+#define SNORT_ML_ENGINE_H
 
 #include "framework/module.h"
 #include "framework/inspector.h"
 
+#define SNORT_ML_ENGINE_NAME "snort_ml_engine"
+#define SNORT_ML_ENGINE_HELP "configure machine learning engine settings"
 
-#define KZ_ENGINE_NAME "snort_ml_engine"
-#define KZ_ENGINE_HELP "configure machine learning engine settings"
+namespace libml
+{
+    class BinaryClassifierSet;
+}
 
-class BinaryClassifier;
-struct KaizenEngineConfig
+struct SnortMLEngineConfig
 {
     std::string http_param_model_path;
 };
 
-class KaizenEngineModule : public snort::Module
+class SnortMLEngineModule : public snort::Module
 {
 public:
-    KaizenEngineModule();
+    SnortMLEngineModule();
 
     bool set(const char*, snort::Value&, snort::SnortConfig*) override;
 
     Usage get_usage() const override
     { return GLOBAL; }
 
-    const KaizenEngineConfig& get_config()
+    const SnortMLEngineConfig& get_config()
     { return conf; }
 
 private:
-    KaizenEngineConfig conf;
+    SnortMLEngineConfig conf;
 };
 
-
-class KaizenEngine : public snort::Inspector
+class SnortMLEngine : public snort::Inspector
 {
 public:
-    KaizenEngine(const KaizenEngineConfig&);
+    SnortMLEngine(const SnortMLEngineConfig&);
 
     void show(const snort::SnortConfig*) const override;
     void eval(snort::Packet*) override {}
@@ -64,40 +67,48 @@ public:
 
     void install_reload_handler(snort::SnortConfig*) override;
 
-    static BinaryClassifier* get_classifier();
+    static libml::BinaryClassifierSet* get_classifiers();
 
 private:
-    std::string read_model();
-    bool validate_model();
+    bool read_models();
+    bool read_model(const std::string&);
 
-    KaizenEngineConfig config;
-    std::string http_param_model;
-};
+    bool validate_models();
 
+    SnortMLEngineConfig config;
+    std::vector<std::string> http_param_models;
+};
 
-// Mock Classifier for tests if LibML absents.
-// However, when REG_TEST is undefined, the entire code below won't be executed.
-// Check the plugin type provided in kaizen_engine_api in the cc file
+// Mock BinaryClassifierSet for tests if LibML is absent.
+// The code below won't be executed if REG_TEST is undefined.
+// Check the plugin type provided in the snort_ml_engine.cc file.
 #ifndef HAVE_LIBML
-class BinaryClassifier
+namespace libml
+{
+
+class BinaryClassifierSet
 {
 public:
-    bool build(const std::string& model)
+    bool build(const std::vector<std::string>& models)
     {
-        pattern = model;
+        if (!models.empty())
+            pattern = models[0];
+
         return pattern != "error";
     }
 
-    bool run(const char* ptr, size_t len, float& threshold)
+    bool run(const char* ptr, size_t len, float& out)
     {
         std::string data(ptr, len);
-        threshold = std::string::npos == data.find(pattern) ? 0.0f : 1.0f;
+        out = data.find(pattern) == std::string::npos ? 0.0f : 1.0f;
         return pattern != "fail";
     }
 
 private:
     std::string pattern;
 };
+
+}
 #endif
 
 #endif
similarity index 64%
rename from src/network_inspectors/kaizen/kaizen_inspector.cc
rename to src/network_inspectors/snort_ml/snort_ml_inspector.cc
index 863eb59f37d9bbd9d737e188093116cc38d0896f..9d27a5817a93cc9f1517180e807d5ff7e9c0cda6 100644 (file)
@@ -1,5 +1,5 @@
 //--------------------------------------------------------------------------
-// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved.
+// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved.
 //
 // This program is free software; you can redistribute it and/or modify it
 // under the terms of the GNU General Public License Version 2 as published
 // with this program; if not, write to the Free Software Foundation, Inc.,
 // 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
 //--------------------------------------------------------------------------
-// kaizen_inspector.cc author Brandon Stultz <brastult@cisco.com>
+// snort_ml_inspector.cc author Brandon Stultz <brastult@cisco.com>
 
 #ifdef HAVE_CONFIG_H
 #include "config.h"
 #endif
 
-#include "kaizen_inspector.h"
+#include "snort_ml_inspector.h"
 
 #include <cassert>
 
 #include "pub_sub/http_request_body_event.h"
 #include "utils/util.h"
 
-#include "kaizen_engine.h"
+#include "snort_ml_engine.h"
 
 using namespace snort;
 using namespace std;
 
-THREAD_LOCAL KaizenStats kaizen_stats;
-THREAD_LOCAL ProfileStats kaizen_prof;
+THREAD_LOCAL SnortMLStats snort_ml_stats;
+THREAD_LOCAL ProfileStats snort_ml_prof;
 
 //--------------------------------------------------------------------------
 // HTTP body event handler
@@ -51,22 +51,22 @@ THREAD_LOCAL ProfileStats kaizen_prof;
 class HttpBodyHandler : public DataHandler
 {
 public:
-    HttpBodyHandler(Kaizen& kz)
-        : DataHandler(KZ_NAME), inspector(kz) {}
+    HttpBodyHandler(SnortML& ml)
+        : DataHandler(SNORT_ML_NAME), inspector(ml) {}
 
     void handle(DataEvent& de, Flow*) override;
 
 private:
-    Kaizen& inspector;
+    SnortML& inspector;
 };
 
 void HttpBodyHandler::handle(DataEvent& de, Flow*)
 {
     // cppcheck-suppress unreadVariable
-    Profile profile(kaizen_prof);
+    Profile profile(snort_ml_prof);
 
-    BinaryClassifier* classifier = KaizenEngine::get_classifier();
-    KaizenConfig config = inspector.get_config();
+    libml::BinaryClassifierSet* classifiers = SnortMLEngine::get_classifiers();
+    SnortMLConfig config = inspector.get_config();
     HttpRequestBodyEvent* he = (HttpRequestBodyEvent*)&de;
 
     if (he->is_mime())
@@ -80,25 +80,25 @@ void HttpBodyHandler::handle(DataEvent& de, Flow*)
 
     const size_t len = std::min((size_t)config.client_body_depth, (size_t)body_len);
 
-    assert(classifier);
+    assert(classifiers);
 
     float output = 0.0;
 
-    kaizen_stats.libml_calls++;
+    snort_ml_stats.libml_calls++;
 
-    if (!classifier->run(body, len, output))
+    if (!classifiers->run(body, len, output))
         return;
 
-    kaizen_stats.client_body_bytes += len;
+    snort_ml_stats.client_body_bytes += len;
 
-    debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "input (body): %.*s\n", (int)len, body);
-    debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast<double>(output));
+    debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "input (body): %.*s\n", (int)len, body);
+    debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast<double>(output));
 
     if ((double)output > config.http_param_threshold)
     {
-        kaizen_stats.client_body_alerts++;
-        debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "<ALERT>\n");
-        DetectionEngine::queue_event(KZ_GID, KZ_SID);
+        snort_ml_stats.client_body_alerts++;
+        debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "<ALERT>\n");
+        DetectionEngine::queue_event(SNORT_ML_GID, SNORT_ML_SID);
     }
 }
 
@@ -109,22 +109,22 @@ void HttpBodyHandler::handle(DataEvent& de, Flow*)
 class HttpUriHandler : public DataHandler
 {
 public:
-    HttpUriHandler(Kaizen& kz)
-        : DataHandler(KZ_NAME), inspector(kz) {}
+    HttpUriHandler(SnortML& ml)
+        : DataHandler(SNORT_ML_NAME), inspector(ml) {}
 
     void handle(DataEvent&, Flow*) override;
 
 private:
-    Kaizen& inspector;
+    SnortML& inspector;
 };
 
 void HttpUriHandler::handle(DataEvent& de, Flow*)
 {
     // cppcheck-suppress unreadVariable
-    Profile profile(kaizen_prof);
+    Profile profile(snort_ml_prof);
 
-    BinaryClassifier* classifier = KaizenEngine::get_classifier();
-    KaizenConfig config = inspector.get_config();
+    libml::BinaryClassifierSet* classifiers = SnortMLEngine::get_classifiers();
+    SnortMLConfig config = inspector.get_config();
     HttpEvent* he = (HttpEvent*)&de;
 
     int32_t query_len = 0;
@@ -135,25 +135,25 @@ void HttpUriHandler::handle(DataEvent& de, Flow*)
 
     const size_t len = std::min((size_t)config.uri_depth, (size_t)query_len);
 
-    assert(classifier);
+    assert(classifiers);
 
     float output = 0.0;
 
-    kaizen_stats.libml_calls++;
+    snort_ml_stats.libml_calls++;
 
-    if (!classifier->run(query, (size_t)len, output))
+    if (!classifiers->run(query, len, output))
         return;
 
-    kaizen_stats.uri_bytes += len;
+    snort_ml_stats.uri_bytes += len;
 
-    debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "input (query): %.*s\n", (int)len, query);
-    debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast<double>(output));
+    debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "input (query): %.*s\n", (int)len, query);
+    debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast<double>(output));
 
     if ((double)output > config.http_param_threshold)
     {
-        kaizen_stats.uri_alerts++;
-        debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "<ALERT>\n");
-        DetectionEngine::queue_event(KZ_GID, KZ_SID);
+        snort_ml_stats.uri_alerts++;
+        debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "<ALERT>\n");
+        DetectionEngine::queue_event(SNORT_ML_GID, SNORT_ML_SID);
     }
 }
 
@@ -161,14 +161,14 @@ void HttpUriHandler::handle(DataEvent& de, Flow*)
 // inspector
 //--------------------------------------------------------------------------
 
-void Kaizen::show(const SnortConfig*) const
+void SnortML::show(const SnortConfig*) const
 {
     ConfigLogger::log_limit("uri_depth", config.uri_depth, -1);
     ConfigLogger::log_limit("client_body_depth", config.client_body_depth, -1);
     ConfigLogger::log_value("http_param_threshold", config.http_param_threshold);
 }
 
-bool Kaizen::configure(SnortConfig* sc)
+bool SnortML::configure(SnortConfig* sc)
 {
     if (config.uri_depth != 0)
         DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_HEADER, new HttpUriHandler(*this));
@@ -176,9 +176,9 @@ bool Kaizen::configure(SnortConfig* sc)
     if (config.client_body_depth != 0)
         DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_BODY, new HttpBodyHandler(*this));
 
-    if(!InspectorManager::get_inspector(KZ_ENGINE_NAME, true, sc))
+    if(!InspectorManager::get_inspector(SNORT_ML_ENGINE_NAME, true, sc))
     {
-        ParseError("snort_ml requires %s to be configured in the global policy.", KZ_ENGINE_NAME);
+        ParseError("snort_ml requires %s to be configured in the global policy.", SNORT_ML_ENGINE_NAME);
         return false;
     }
 
@@ -190,24 +190,24 @@ bool Kaizen::configure(SnortConfig* sc)
 //--------------------------------------------------------------------------
 
 static Module* mod_ctor()
-{ return new KaizenModule; }
+{ return new SnortMLModule; }
 
 static void mod_dtor(Module* m)
 { delete m; }
 
-static Inspector* kaizen_ctor(Module* m)
+static Inspector* snort_ml_ctor(Module* m)
 {
-    KaizenModule* km = (KaizenModule*)m;
-    return new Kaizen(km->get_conf());
+    SnortMLModule* km = (SnortMLModule*)m;
+    return new SnortML(km->get_conf());
 }
 
-static void kaizen_dtor(Inspector* p)
+static void snort_ml_dtor(Inspector* p)
 {
     assert(p);
     delete p;
 }
 
-static const InspectApi kaizen_api =
+static const InspectApi snort_ml_api =
 {
     {
         PT_INSPECTOR,
@@ -216,8 +216,8 @@ static const InspectApi kaizen_api =
         0,
         API_RESERVED,
         API_OPTIONS,
-        KZ_NAME,
-        KZ_HELP,
+        SNORT_ML_NAME,
+        SNORT_ML_HELP,
         mod_ctor,
         mod_dtor
     },
@@ -229,8 +229,8 @@ static const InspectApi kaizen_api =
     nullptr,  // pterm
     nullptr,  // tinit
     nullptr,  // tterm
-    kaizen_ctor,
-    kaizen_dtor,
+    snort_ml_ctor,
+    snort_ml_dtor,
     nullptr,  // ssn
     nullptr   // reset
 };
@@ -238,9 +238,9 @@ static const InspectApi kaizen_api =
 #ifdef BUILDING_SO
 SO_PUBLIC const BaseApi* snort_plugins[] =
 #else
-const BaseApi* nin_kaizen[] =
+const BaseApi* nin_snort_ml[] =
 #endif
 {
-    &kaizen_api.base,
+    &snort_ml_api.base,
     nullptr
 };
similarity index 76%
rename from src/network_inspectors/kaizen/kaizen_inspector.h
rename to src/network_inspectors/snort_ml/snort_ml_inspector.h
index d54d72a62ffbf33ee8a326e5f7f5013db1e4cbdc..f7f97fca9214ab42bd75ed933d7074ebe778e24b 100644 (file)
@@ -1,5 +1,5 @@
 //--------------------------------------------------------------------------
-// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved.
+// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved.
 //
 // This program is free software; you can redistribute it and/or modify it
 // under the terms of the GNU General Public License Version 2 as published
 // with this program; if not, write to the Free Software Foundation, Inc.,
 // 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
 //--------------------------------------------------------------------------
-// kaizen_inspector.h author Brandon Stultz <brastult@cisco.com>
+// snort_ml_inspector.h author Brandon Stultz <brastult@cisco.com>
 
-#ifndef KAIZEN_INSPECTOR_H
-#define KAIZEN_INSPECTOR_H
+#ifndef SNORT_ML_INSPECTOR_H
+#define SNORT_ML_INSPECTOR_H
 
 #include <string>
 #include <utility>
 
 #include "framework/inspector.h"
 
-#include "kaizen_module.h"
+#include "snort_ml_module.h"
 
-class Kaizen : public snort::Inspector
+class SnortML : public snort::Inspector
 {
 public:
-    Kaizen(const KaizenConfig& c) : config(c) { }
+    SnortML(const SnortMLConfig& c) : config(c) { }
 
     void show(const snort::SnortConfig*) const override;
     void eval(snort::Packet*) override {}
     bool configure(snort::SnortConfig*) override;
 
-    const KaizenConfig& get_config()
+    const SnortMLConfig& get_config()
     { return config; }
 private:
-    KaizenConfig config;
+    SnortMLConfig config;
 };
 
 #endif
similarity index 74%
rename from src/network_inspectors/kaizen/kaizen_module.cc
rename to src/network_inspectors/snort_ml/snort_ml_module.cc
index 554b8ed41fd49f5a8061eae27b00e0eed4ff6a6e..d0c4cd2917e925812af31eca7d2aaad83cbb333b 100644 (file)
@@ -1,5 +1,5 @@
 //--------------------------------------------------------------------------
-// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved.
+// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved.
 //
 // This program is free software; you can redistribute it and/or modify it
 // under the terms of the GNU General Public License Version 2 as published
 // with this program; if not, write to the Free Software Foundation, Inc.,
 // 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
 //--------------------------------------------------------------------------
-// kaizen_module.cc author Brandon Stultz <brastult@cisco.com>
+// snort_ml_module.cc author Brandon Stultz <brastult@cisco.com>
 
 #ifdef HAVE_CONFIG_H
 #include "config.h"
 #endif
 
-#include "kaizen_module.h"
+#include "snort_ml_module.h"
 
 #include "log/messages.h"
 #include "service_inspectors/http_inspect/http_field.h"
 
 using namespace snort;
 
-THREAD_LOCAL const Trace* kaizen_trace = nullptr;
+THREAD_LOCAL const Trace* snort_ml_trace = nullptr;
 
-static const Parameter kaizen_params[] =
+static const Parameter snort_ml_params[] =
 {
     { "uri_depth", Parameter::PT_INT, "-1:max31", "-1",
       "number of input HTTP URI bytes to scan (-1 unlimited)" },
@@ -44,9 +44,9 @@ static const Parameter kaizen_params[] =
     { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
 };
 
-static const RuleMap kaizen_rules[] =
+static const RuleMap snort_ml_rules[] =
 {
-    { KZ_SID, "potential threat found in HTTP parameters via Neural Network Based Exploit Detection" },
+    { SNORT_ML_SID, "potential threat found in HTTP parameters via Neural Network Based Exploit Detection" },
     { 0, nullptr }
 };
 
@@ -61,7 +61,7 @@ static const PegInfo peg_names[] =
 };
 
 #ifdef DEBUG_MSGS
-static const TraceOption kaizen_trace_options[] =
+static const TraceOption snort_ml_trace_options[] =
 {
     { "classifier", TRACE_CLASSIFIER, "enable Snort ML classifier trace logging" },
     { nullptr, 0, nullptr }
@@ -72,9 +72,9 @@ static const TraceOption kaizen_trace_options[] =
 // module
 //--------------------------------------------------------------------------
 
-KaizenModule::KaizenModule() : Module(KZ_NAME, KZ_HELP, kaizen_params) {}
+SnortMLModule::SnortMLModule() : Module(SNORT_ML_NAME, SNORT_ML_HELP, snort_ml_params) {}
 
-bool KaizenModule::set(const char*, Value& v, SnortConfig*)
+bool SnortMLModule::set(const char*, Value& v, SnortConfig*)
 {
     static_assert(std::is_same<decltype((Field().length())), decltype(conf.uri_depth)>::value,
         "Field::length maximum value should not exceed uri_depth type range");
@@ -91,7 +91,7 @@ bool KaizenModule::set(const char*, Value& v, SnortConfig*)
     return true;
 }
 
-bool KaizenModule::end(const char*, int, snort::SnortConfig*)
+bool SnortMLModule::end(const char*, int, snort::SnortConfig*)
 {
     if (!conf.uri_depth && !conf.client_body_depth)
         ParseWarning(WARN_CONF,
@@ -100,26 +100,26 @@ bool KaizenModule::end(const char*, int, snort::SnortConfig*)
     return true;
 }
 
-const RuleMap* KaizenModule::get_rules() const
-{ return kaizen_rules; }
+const RuleMap* SnortMLModule::get_rules() const
+{ return snort_ml_rules; }
 
-const PegInfo* KaizenModule::get_pegs() const
+const PegInfo* SnortMLModule::get_pegs() const
 { return peg_names; }
 
-PegCount* KaizenModule::get_counts() const
-{ return (PegCount*)&kaizen_stats; }
+PegCount* SnortMLModule::get_counts() const
+{ return (PegCount*)&snort_ml_stats; }
 
-ProfileStats* KaizenModule::get_profile() const
-{ return &kaizen_prof; }
+ProfileStats* SnortMLModule::get_profile() const
+{ return &snort_ml_prof; }
 
-void KaizenModule::set_trace(const Trace* trace) const
-{ kaizen_trace = trace; }
+void SnortMLModule::set_trace(const Trace* trace) const
+{ snort_ml_trace = trace; }
 
-const TraceOption* KaizenModule::get_trace_options() const
+const TraceOption* SnortMLModule::get_trace_options() const
 {
 #ifndef DEBUG_MSGS
     return nullptr;
 #else
-    return kaizen_trace_options;
+    return snort_ml_trace_options;
 #endif
 }
similarity index 74%
rename from src/network_inspectors/kaizen/kaizen_module.h
rename to src/network_inspectors/snort_ml/snort_ml_module.h
index 59624bddf3e4652d0ffc4aef76a2d42fd47721f7..b842a06ecaf34de8025a3391e5f9f93957dca3a8 100644 (file)
@@ -1,5 +1,5 @@
 //--------------------------------------------------------------------------
-// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved.
+// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved.
 //
 // This program is free software; you can redistribute it and/or modify it
 // under the terms of the GNU General Public License Version 2 as published
 // with this program; if not, write to the Free Software Foundation, Inc.,
 // 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
 //--------------------------------------------------------------------------
-// kaizen_module.h author Brandon Stultz <brastult@cisco.com>
+// snort_ml_module.h author Brandon Stultz <brastult@cisco.com>
 
-#ifndef KAIZEN_MODULE_H
-#define KAIZEN_MODULE_H
+#ifndef SNORT_ML_MODULE_H
+#define SNORT_ML_MODULE_H
 
 #include "framework/module.h"
 #include "main/thread.h"
 #include "profiler/profiler.h"
 #include "trace/trace_api.h"
 
-#define KZ_GID 411
-#define KZ_SID 1
+#define SNORT_ML_GID 411
+#define SNORT_ML_SID 1
 
-#define KZ_NAME "snort_ml"
-#define KZ_HELP "machine learning based exploit detector"
+#define SNORT_ML_NAME "snort_ml"
+#define SNORT_ML_HELP "machine learning based exploit detector"
 
 enum { TRACE_CLASSIFIER };
 
-struct KaizenStats
+struct SnortMLStats
 {
     PegCount uri_alerts;
     PegCount client_body_alerts;
@@ -42,11 +42,11 @@ struct KaizenStats
     PegCount libml_calls;
 };
 
-extern THREAD_LOCAL KaizenStats kaizen_stats;
-extern THREAD_LOCAL snort::ProfileStats kaizen_prof;
-extern THREAD_LOCAL const snort::Trace* kaizen_trace;
+extern THREAD_LOCAL SnortMLStats snort_ml_stats;
+extern THREAD_LOCAL snort::ProfileStats snort_ml_prof;
+extern THREAD_LOCAL const snort::Trace* snort_ml_trace;
 
-struct KaizenConfig
+struct SnortMLConfig
 {
     std::string http_param_model_path;
     double http_param_threshold;
@@ -54,19 +54,19 @@ struct KaizenConfig
     int32_t client_body_depth;
 };
 
-class KaizenModule : public snort::Module
+class SnortMLModule : public snort::Module
 {
 public:
-    KaizenModule();
+    SnortMLModule();
 
     bool set(const char*, snort::Value&, snort::SnortConfig*) override;
     bool end(const char*, int, snort::SnortConfig*) override;
 
-    const KaizenConfig& get_conf() const
+    const SnortMLConfig& get_conf() const
     { return conf; }
 
     unsigned get_gid() const override
-    { return KZ_GID; }
+    { return SNORT_ML_GID; }
 
     const snort::RuleMap* get_rules() const override;
 
@@ -82,7 +82,7 @@ public:
     const snort::TraceOption* get_trace_options() const override;
 
 private:
-    KaizenConfig conf = {};
+    SnortMLConfig conf = {};
 };
 
 #endif
index 5175ebb0c6e79b026760720c6574a31c96f94653..ad6c2d27ec784140c1bedcac88196110dc890c83 100644 (file)
 
 #ifdef _WIN32
 #define ISREG(m) (((m) & _S_IFMT) == _S_IFREG)
+#define ISDIR(m) (((m) & _S_IFMT) == _S_IFDIR)
 #else
 #define ISREG(m) S_ISREG(m)
+#define ISDIR(m) S_ISDIR(m)
 #endif
 
 using namespace snort;
@@ -83,6 +85,16 @@ bool get_file_size(const std::string& path, size_t& size)
     return true;
 }
 
+bool is_directory_path(const std::string& path)
+{
+    struct STAT sb;
+
+    if (STAT(path.c_str(), &sb))
+        return false;
+
+    return ISDIR(sb.st_mode);
+}
+
 namespace snort
 {
 const char* get_error(int errnum)
index c6916d6c01a0c7e10086ac2c3be47321e6984df5..9f91be245f7dbe031e678e13041454fa8722c7d2 100644 (file)
@@ -65,6 +65,7 @@
 
 unsigned int get_random_seed();
 bool get_file_size(const std::string&, size_t&);
+bool is_directory_path(const std::string&);
 
 namespace
 {