-find_library(ML_LIBRARIES NAMES ml_static HINTS ${ML_LIBRARIES_DIR_HINT})
-find_path(ML_INCLUDE_DIRS libml.h HINTS ${ML_INCLUDE_DIR_HINT})
+find_package(PkgConfig)
+pkg_check_modules(PC_ML libml_static>=2.0.0)
+
+find_path(ML_INCLUDE_DIRS
+ libml.h
+ HINTS ${ML_INCLUDE_DIR_HINT} ${PC_ML_INCLUDEDIR}
+)
+
+find_library(ML_LIBRARIES
+ NAMES ml_static
+ HINTS ${ML_LIBRARIES_DIR_HINT} ${PC_ML_LIBDIR}
+)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(ML
DEFAULT_MSG
- ML_LIBRARIES
ML_INCLUDE_DIRS
+ ML_LIBRARIES
)
if (ML_FOUND AND NOT USE_LIBML_MOCK)
#include <jemalloc/jemalloc.h>
#endif
+#ifdef HAVE_LIBML
+#include <libml.h>
+#endif
+
#ifdef HAVE_LIBUNWIND
#define UNW_LOCAL_ONLY
#include <libunwind.h>
size_t sz = sizeof(jv);
mallctl("version", &jv, &sz, NULL, 0);
LogMessage(" Using Jemalloc version %s\n", jv);
+#endif
+#ifdef HAVE_LIBML
+ LogMessage(" Using LibML version %s\n", libml::version());
#endif
LogMessage(" Using %s\n", pcap_lib_version());
LogMessage(" Using LuaJIT version %s\n", ljv);
vs.push_back(lzma_version_string());
#endif
#ifdef HAVE_LIBML
- vs.push_back(libml_version());
+ vs.push_back(libml::version());
#endif
lua_createtable(L, 0, vs.size());
add_subdirectory(extractor)
if ( HAVE_LIBML OR USE_LIBML_MOCK )
- add_subdirectory(kaizen)
+ add_subdirectory(snort_ml)
endif()
add_subdirectory(normalize)
endif()
if ( HAVE_LIBML OR USE_LIBML_MOCK )
- set(KAIZEN_STATIC_OBJ
- $<TARGET_OBJECTS:kaizen>
+ set(SNORT_ML_STATIC_OBJ
+ $<TARGET_OBJECTS:snort_ml>
)
endif()
$<TARGET_OBJECTS:appid>
$<TARGET_OBJECTS:binder>
$<TARGET_OBJECTS:extractor>
- ${KAIZEN_STATIC_OBJ}
+ ${SNORT_ML_STATIC_OBJ}
$<TARGET_OBJECTS:normalize>
$<TARGET_OBJECTS:port_scan>
$<TARGET_OBJECTS:reputation>
packet_capture - A tool for dumping the wire packets that Snort receives.
-kaizen - Machine learning based exploit detector capable of detecting novel
-attacks fitting known vulnerability types. Kaizen uses a neural network
-provided by a model file to detect exploit patterns. The Kaizen Snort module
+snort_ml - Machine learning based exploit detector capable of detecting novel
+attacks fitting known vulnerability types. SnortML uses a neural network
+provided by a model file to detect exploit patterns. The SnortML module
subscribes to HTTP events published by the HTTP inspector, performs inference
on HTTP queries/posts, and generates events if the neural network detects
an exploit.
+++ /dev/null
-add_library(kaizen OBJECT
- kaizen_engine.cc
- kaizen_engine.h
- kaizen_inspector.cc
- kaizen_inspector.h
- kaizen_module.cc
- kaizen_module.h
-)
+++ /dev/null
-//--------------------------------------------------------------------------
-// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved.
-//
-// This program is free software; you can redistribute it and/or modify it
-// under the terms of the GNU General Public License Version 2 as published
-// by the Free Software Foundation. You may not use, modify or distribute
-// this program under any other version of the GNU General Public License.
-//
-// This program is distributed in the hope that it will be useful, but
-// WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-// General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License along
-// with this program; if not, write to the Free Software Foundation, Inc.,
-// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
-//--------------------------------------------------------------------------
-// kaizen_engine.cc author Vitalii Horbatov <vhorbato@cisco.com>
-
-#ifdef HAVE_CONFIG_H
-#include "config.h"
-#endif
-
-#include "kaizen_engine.h"
-
-#include <cassert>
-#include <fstream>
-
-#ifdef HAVE_LIBML
-#include <libml.h>
-#endif
-
-#include "framework/decode_data.h"
-#include "log/messages.h"
-#include "main/reload_tuner.h"
-#include "main/snort.h"
-#include "main/snort_config.h"
-#include "parser/parse_conf.h"
-#include "utils/util.h"
-
-using namespace snort;
-using namespace std;
-
-static THREAD_LOCAL BinaryClassifier* classifier = nullptr;
-
-static bool build_classifier(const string& model, BinaryClassifier*& dst)
-{
- dst = new BinaryClassifier();
-
- return dst->build(model);
-}
-
-//--------------------------------------------------------------------------
-// module
-//--------------------------------------------------------------------------
-
-static const Parameter kaizen_engine_params[] =
-{
- { "http_param_model", Parameter::PT_STRING, nullptr, nullptr, "path to the model file" },
- { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
-};
-
-KaizenEngineModule::KaizenEngineModule() : Module(KZ_ENGINE_NAME, KZ_ENGINE_HELP, kaizen_engine_params) {}
-
-bool KaizenEngineModule::set(const char*, Value& v, SnortConfig*)
-{
- if (v.is("http_param_model"))
- conf.http_param_model_path = v.get_string();
-
- return true;
-}
-
-//--------------------------------------------------------------------------
-// reload tuner
-//--------------------------------------------------------------------------
-
-class KaizenReloadTuner : public snort::ReloadResourceTuner
-{
-public:
- explicit KaizenReloadTuner(const string& http_param_model) : http_param_model(http_param_model) {}
- ~KaizenReloadTuner() override = default;
-
- const char* name() const override
- { return "KaizenReloadTuner"; }
-
- bool tinit() override
- {
- delete classifier;
-
- if (!build_classifier(http_param_model, classifier))
- ErrorMessage("Can't build the classifier model.\n");
-
- return false;
- }
-
- bool tune_packet_context() override
- { return true; }
-
- bool tune_idle_context() override
- { return true; }
-
-private:
- const string& http_param_model;
-};
-
-//--------------------------------------------------------------------------
-// inspector
-//--------------------------------------------------------------------------
-
-KaizenEngine::KaizenEngine(const KaizenEngineConfig& c) : config(c)
-{
- http_param_model = read_model();
-
- if (!validate_model())
- ParseError("Can't build the classifier model %s.", config.http_param_model_path.c_str());
-}
-
-void KaizenEngine::show(const SnortConfig*) const
-{ ConfigLogger::log_value("http_param_model", config.http_param_model_path.c_str()); }
-
-string KaizenEngine::read_model()
-{
- const char* hint = config.http_param_model_path.c_str();
- string path;
- size_t size = 0;
-
- if (!get_config_file(hint, path) || !get_file_size(path, size))
- {
- ParseError("snort_ml_engine: could not read model file: %s", hint);
- return {};
- }
-
- ifstream file(path, ios::binary);
-
- if (!file.is_open())
- {
- ParseError("snort_ml_engine: could not read model file: %s", hint);
- return {};
- }
-
- if (size == 0)
- {
- ParseError("snort_ml_engine: empty model file: %s", hint);
- return {};
- }
-
- string buffer(size, '\0');
- file.read(&buffer[0], streamsize(size));
- return buffer;
-}
-
-bool KaizenEngine::validate_model()
-{
- BinaryClassifier* test_classifier = nullptr;
- bool res = build_classifier(http_param_model, test_classifier);
- delete test_classifier;
-
- return res;
-}
-
-void KaizenEngine::tinit()
-{ build_classifier(http_param_model, classifier); }
-
-void KaizenEngine::tterm()
-{
- delete classifier;
- classifier = nullptr;
-}
-
-void KaizenEngine::install_reload_handler(SnortConfig* sc)
-{ sc->register_reload_handler(new KaizenReloadTuner(http_param_model)); }
-
-BinaryClassifier* KaizenEngine::get_classifier()
-{ return classifier; }
-
-//--------------------------------------------------------------------------
-// api stuff
-//--------------------------------------------------------------------------
-
-static Module* mod_ctor()
-{ return new KaizenEngineModule; }
-
-static void mod_dtor(Module* m)
-{ delete m; }
-
-static Inspector* kaizen_engine_ctor(Module* m)
-{
- KaizenEngineModule* mod = (KaizenEngineModule*)m;
- return new KaizenEngine(mod->get_config());
-}
-
-static void kaizen_engine_dtor(Inspector* p)
-{
- assert(p);
- delete p;
-}
-
-static const InspectApi kaizen_engine_api =
-{
- {
- PT_INSPECTOR,
- sizeof(InspectApi),
- INSAPI_VERSION,
- 0,
- API_RESERVED,
- API_OPTIONS,
- KZ_ENGINE_NAME,
- KZ_ENGINE_HELP,
- mod_ctor,
- mod_dtor
- },
- IT_PASSIVE,
- PROTO_BIT__NONE, // proto_bits;
- nullptr, // buffers
- nullptr, // service
- nullptr, // pinit
- nullptr, // pterm
- nullptr, // tinit
- nullptr, // tterm
- kaizen_engine_ctor,
- kaizen_engine_dtor,
- nullptr, // ssn
- nullptr // reset
-};
-
-#ifdef BUILDING_SO
-SO_PUBLIC const BaseApi* snort_plugins[] =
-#else
-const BaseApi* nin_kaizen_engine[] =
-#endif
-{
- &kaizen_engine_api.base,
- nullptr
-};
-
-#ifdef UNIT_TEST
-
-#include "catch/snort_catch.h"
-
-#include <memory.h>
-
-TEST_CASE("Kaizen tuner name", "[kaizen_module]")
-{
- const string http_param_model("model");
- KaizenReloadTuner tuner(http_param_model);
-
- REQUIRE(strcmp(tuner.name(), "KaizenReloadTuner") == 0);
-}
-
-#endif
extern const BaseApi* nin_extractor[];
#if defined(HAVE_LIBML) || defined(USE_LIBML_MOCK)
-extern const BaseApi* nin_kaizen_engine[];
-extern const BaseApi* nin_kaizen[];
+extern const BaseApi* nin_snort_ml_engine[];
+extern const BaseApi* nin_snort_ml[];
#endif
extern const BaseApi* nin_port_scan[];
PluginManager::load_plugins(nin_extractor);
#if defined(HAVE_LIBML) || defined(USE_LIBML_MOCK)
- PluginManager::load_plugins(nin_kaizen_engine);
- PluginManager::load_plugins(nin_kaizen);
+ PluginManager::load_plugins(nin_snort_ml_engine);
+ PluginManager::load_plugins(nin_snort_ml);
#endif
PluginManager::load_plugins(nin_port_scan);
--- /dev/null
+add_library(snort_ml OBJECT
+ snort_ml_engine.cc
+ snort_ml_engine.h
+ snort_ml_inspector.cc
+ snort_ml_inspector.h
+ snort_ml_module.cc
+ snort_ml_module.h
+)
--- /dev/null
+//--------------------------------------------------------------------------
+// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation. You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+//--------------------------------------------------------------------------
+// snort_ml_engine.cc author Vitalii Horbatov <vhorbato@cisco.com>
+// author Brandon Stultz <brastult@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "snort_ml_engine.h"
+
+#include <cassert>
+#include <fstream>
+
+#ifdef HAVE_LIBML
+#include <libml.h>
+#endif
+
+#include "framework/decode_data.h"
+#include "helpers/directory.h"
+#include "log/messages.h"
+#include "main/reload_tuner.h"
+#include "main/snort.h"
+#include "main/snort_config.h"
+#include "parser/parse_conf.h"
+#include "utils/util.h"
+
+using namespace snort;
+using namespace std;
+
+static THREAD_LOCAL libml::BinaryClassifierSet* classifiers = nullptr;
+
+static bool build_classifiers(const vector<string>& models,
+ libml::BinaryClassifierSet*& set)
+{
+ set = new libml::BinaryClassifierSet();
+
+ return set->build(models);
+}
+
+//--------------------------------------------------------------------------
+// module
+//--------------------------------------------------------------------------
+
+static const Parameter snort_ml_engine_params[] =
+{
+ { "http_param_model", Parameter::PT_STRING, nullptr, nullptr, "path to model file(s)" },
+ { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
+};
+
+SnortMLEngineModule::SnortMLEngineModule() : Module(SNORT_ML_ENGINE_NAME, SNORT_ML_ENGINE_HELP, snort_ml_engine_params) {}
+
+bool SnortMLEngineModule::set(const char*, Value& v, SnortConfig*)
+{
+ if (v.is("http_param_model"))
+ conf.http_param_model_path = v.get_string();
+
+ return true;
+}
+
+//--------------------------------------------------------------------------
+// reload tuner
+//--------------------------------------------------------------------------
+
+class SnortMLReloadTuner : public snort::ReloadResourceTuner
+{
+public:
+ explicit SnortMLReloadTuner(const vector<string>& models)
+ : http_param_models(models) {}
+
+ ~SnortMLReloadTuner() override = default;
+
+ const char* name() const override
+ { return "SnortMLReloadTuner"; }
+
+ bool tinit() override
+ {
+ delete classifiers;
+
+ if (!build_classifiers(http_param_models, classifiers))
+ ErrorMessage("Could not build classifiers.\n");
+
+ return false;
+ }
+
+ bool tune_packet_context() override
+ { return true; }
+
+ bool tune_idle_context() override
+ { return true; }
+
+private:
+ const vector<string>& http_param_models;
+};
+
+//--------------------------------------------------------------------------
+// inspector
+//--------------------------------------------------------------------------
+
+SnortMLEngine::SnortMLEngine(const SnortMLEngineConfig& c) : config(c)
+{
+ if (!read_models() || !validate_models())
+ ParseError("Could not build classifiers.");
+}
+
+void SnortMLEngine::show(const SnortConfig*) const
+{ ConfigLogger::log_value("http_param_model", config.http_param_model_path.c_str()); }
+
+bool SnortMLEngine::read_models()
+{
+ const char* hint = config.http_param_model_path.c_str();
+ string path;
+
+ if (!get_config_file(hint, path))
+ {
+ ParseError("snort_ml_engine: could not read model file(s): %s", hint);
+ return false;
+ }
+
+ if (!is_directory_path(path))
+ {
+ if (!read_model(path))
+ {
+ ParseError("snort_ml_engine: could not read model file: %s", path.c_str());
+ return false;
+ }
+
+ return true;
+ }
+
+ Directory model_dir(path.c_str());
+
+ if (model_dir.error_on_open())
+ {
+ ParseError("snort_ml_engine: could not read model dir: %s", path.c_str());
+ return false;
+ }
+
+ while (const char* f = model_dir.next())
+ {
+ if (!read_model(f))
+ {
+ ParseError("snort_ml_engine: could not read model: %s", f);
+ return false;
+ }
+ }
+
+ return !http_param_models.empty();
+}
+
+bool SnortMLEngine::read_model(const string& path)
+{
+ size_t size = 0;
+
+ if (!get_file_size(path, size))
+ return false;
+
+ ifstream file(path, ios::binary);
+
+ if (!file.is_open() || size == 0)
+ return false;
+
+ string buffer(size, '\0');
+ file.read(&buffer[0], streamsize(size));
+
+ http_param_models.push_back(move(buffer));
+ return true;
+}
+
+bool SnortMLEngine::validate_models()
+{
+ libml::BinaryClassifierSet* set = nullptr;
+ bool res = build_classifiers(http_param_models, set);
+ delete set;
+
+ return res;
+}
+
+void SnortMLEngine::tinit()
+{ build_classifiers(http_param_models, classifiers); }
+
+void SnortMLEngine::tterm()
+{
+ delete classifiers;
+ classifiers = nullptr;
+}
+
+void SnortMLEngine::install_reload_handler(SnortConfig* sc)
+{ sc->register_reload_handler(new SnortMLReloadTuner(http_param_models)); }
+
+libml::BinaryClassifierSet* SnortMLEngine::get_classifiers()
+{ return classifiers; }
+
+//--------------------------------------------------------------------------
+// api stuff
+//--------------------------------------------------------------------------
+
+static Module* mod_ctor()
+{ return new SnortMLEngineModule; }
+
+static void mod_dtor(Module* m)
+{ delete m; }
+
+static Inspector* snort_ml_engine_ctor(Module* m)
+{
+ SnortMLEngineModule* mod = (SnortMLEngineModule*)m;
+ return new SnortMLEngine(mod->get_config());
+}
+
+static void snort_ml_engine_dtor(Inspector* p)
+{
+ assert(p);
+ delete p;
+}
+
+static const InspectApi snort_ml_engine_api =
+{
+ {
+ PT_INSPECTOR,
+ sizeof(InspectApi),
+ INSAPI_VERSION,
+ 0,
+ API_RESERVED,
+ API_OPTIONS,
+ SNORT_ML_ENGINE_NAME,
+ SNORT_ML_ENGINE_HELP,
+ mod_ctor,
+ mod_dtor
+ },
+ IT_PASSIVE,
+ PROTO_BIT__NONE, // proto_bits;
+ nullptr, // buffers
+ nullptr, // service
+ nullptr, // pinit
+ nullptr, // pterm
+ nullptr, // tinit
+ nullptr, // tterm
+ snort_ml_engine_ctor,
+ snort_ml_engine_dtor,
+ nullptr, // ssn
+ nullptr // reset
+};
+
+#ifdef BUILDING_SO
+SO_PUBLIC const BaseApi* snort_plugins[] =
+#else
+const BaseApi* nin_snort_ml_engine[] =
+#endif
+{
+ &snort_ml_engine_api.base,
+ nullptr
+};
+
+#ifdef UNIT_TEST
+
+#include "catch/snort_catch.h"
+
+#include <memory.h>
+
+TEST_CASE("SnortML tuner name", "[snort_ml_module]")
+{
+ const vector<string> models = { "model" };
+ SnortMLReloadTuner tuner(models);
+
+ REQUIRE(strcmp(tuner.name(), "SnortMLReloadTuner") == 0);
+}
+
+#endif
//--------------------------------------------------------------------------
-// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved.
+// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved.
//
// This program is free software; you can redistribute it and/or modify it
// under the terms of the GNU General Public License Version 2 as published
// with this program; if not, write to the Free Software Foundation, Inc.,
// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
//--------------------------------------------------------------------------
-// kaizen_engine.h author Vitalii Horbatov <vhorbato@cisco.com>
+// snort_ml_engine.h author Vitalii Horbatov <vhorbato@cisco.com>
+// author Brandon Stultz <brastult@cisco.com>
-#ifndef KAIZEN_ENGINE_H
-#define KAIZEN_ENGINE_H
+#ifndef SNORT_ML_ENGINE_H
+#define SNORT_ML_ENGINE_H
#include "framework/module.h"
#include "framework/inspector.h"
+#define SNORT_ML_ENGINE_NAME "snort_ml_engine"
+#define SNORT_ML_ENGINE_HELP "configure machine learning engine settings"
-#define KZ_ENGINE_NAME "snort_ml_engine"
-#define KZ_ENGINE_HELP "configure machine learning engine settings"
+namespace libml
+{
+ class BinaryClassifierSet;
+}
-class BinaryClassifier;
-struct KaizenEngineConfig
+struct SnortMLEngineConfig
{
std::string http_param_model_path;
};
-class KaizenEngineModule : public snort::Module
+class SnortMLEngineModule : public snort::Module
{
public:
- KaizenEngineModule();
+ SnortMLEngineModule();
bool set(const char*, snort::Value&, snort::SnortConfig*) override;
Usage get_usage() const override
{ return GLOBAL; }
- const KaizenEngineConfig& get_config()
+ const SnortMLEngineConfig& get_config()
{ return conf; }
private:
- KaizenEngineConfig conf;
+ SnortMLEngineConfig conf;
};
-
-class KaizenEngine : public snort::Inspector
+class SnortMLEngine : public snort::Inspector
{
public:
- KaizenEngine(const KaizenEngineConfig&);
+ SnortMLEngine(const SnortMLEngineConfig&);
void show(const snort::SnortConfig*) const override;
void eval(snort::Packet*) override {}
void install_reload_handler(snort::SnortConfig*) override;
- static BinaryClassifier* get_classifier();
+ static libml::BinaryClassifierSet* get_classifiers();
private:
- std::string read_model();
- bool validate_model();
+ bool read_models();
+ bool read_model(const std::string&);
- KaizenEngineConfig config;
- std::string http_param_model;
-};
+ bool validate_models();
+ SnortMLEngineConfig config;
+ std::vector<std::string> http_param_models;
+};
-// Mock Classifier for tests if LibML absents.
-// However, when REG_TEST is undefined, the entire code below won't be executed.
-// Check the plugin type provided in kaizen_engine_api in the cc file
+// Mock BinaryClassifierSet for tests if LibML is absent.
+// The code below won't be executed if REG_TEST is undefined.
+// Check the plugin type provided in the snort_ml_engine.cc file.
#ifndef HAVE_LIBML
-class BinaryClassifier
+namespace libml
+{
+
+class BinaryClassifierSet
{
public:
- bool build(const std::string& model)
+ bool build(const std::vector<std::string>& models)
{
- pattern = model;
+ if (!models.empty())
+ pattern = models[0];
+
return pattern != "error";
}
- bool run(const char* ptr, size_t len, float& threshold)
+ bool run(const char* ptr, size_t len, float& out)
{
std::string data(ptr, len);
- threshold = std::string::npos == data.find(pattern) ? 0.0f : 1.0f;
+ out = data.find(pattern) == std::string::npos ? 0.0f : 1.0f;
return pattern != "fail";
}
private:
std::string pattern;
};
+
+}
#endif
#endif
//--------------------------------------------------------------------------
-// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved.
+// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved.
//
// This program is free software; you can redistribute it and/or modify it
// under the terms of the GNU General Public License Version 2 as published
// with this program; if not, write to the Free Software Foundation, Inc.,
// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
//--------------------------------------------------------------------------
-// kaizen_inspector.cc author Brandon Stultz <brastult@cisco.com>
+// snort_ml_inspector.cc author Brandon Stultz <brastult@cisco.com>
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
-#include "kaizen_inspector.h"
+#include "snort_ml_inspector.h"
#include <cassert>
#include "pub_sub/http_request_body_event.h"
#include "utils/util.h"
-#include "kaizen_engine.h"
+#include "snort_ml_engine.h"
using namespace snort;
using namespace std;
-THREAD_LOCAL KaizenStats kaizen_stats;
-THREAD_LOCAL ProfileStats kaizen_prof;
+THREAD_LOCAL SnortMLStats snort_ml_stats;
+THREAD_LOCAL ProfileStats snort_ml_prof;
//--------------------------------------------------------------------------
// HTTP body event handler
class HttpBodyHandler : public DataHandler
{
public:
- HttpBodyHandler(Kaizen& kz)
- : DataHandler(KZ_NAME), inspector(kz) {}
+ HttpBodyHandler(SnortML& ml)
+ : DataHandler(SNORT_ML_NAME), inspector(ml) {}
void handle(DataEvent& de, Flow*) override;
private:
- Kaizen& inspector;
+ SnortML& inspector;
};
void HttpBodyHandler::handle(DataEvent& de, Flow*)
{
// cppcheck-suppress unreadVariable
- Profile profile(kaizen_prof);
+ Profile profile(snort_ml_prof);
- BinaryClassifier* classifier = KaizenEngine::get_classifier();
- KaizenConfig config = inspector.get_config();
+ libml::BinaryClassifierSet* classifiers = SnortMLEngine::get_classifiers();
+ SnortMLConfig config = inspector.get_config();
HttpRequestBodyEvent* he = (HttpRequestBodyEvent*)&de;
if (he->is_mime())
const size_t len = std::min((size_t)config.client_body_depth, (size_t)body_len);
- assert(classifier);
+ assert(classifiers);
float output = 0.0;
- kaizen_stats.libml_calls++;
+ snort_ml_stats.libml_calls++;
- if (!classifier->run(body, len, output))
+ if (!classifiers->run(body, len, output))
return;
- kaizen_stats.client_body_bytes += len;
+ snort_ml_stats.client_body_bytes += len;
- debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "input (body): %.*s\n", (int)len, body);
- debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast<double>(output));
+ debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "input (body): %.*s\n", (int)len, body);
+ debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast<double>(output));
if ((double)output > config.http_param_threshold)
{
- kaizen_stats.client_body_alerts++;
- debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "<ALERT>\n");
- DetectionEngine::queue_event(KZ_GID, KZ_SID);
+ snort_ml_stats.client_body_alerts++;
+ debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "<ALERT>\n");
+ DetectionEngine::queue_event(SNORT_ML_GID, SNORT_ML_SID);
}
}
class HttpUriHandler : public DataHandler
{
public:
- HttpUriHandler(Kaizen& kz)
- : DataHandler(KZ_NAME), inspector(kz) {}
+ HttpUriHandler(SnortML& ml)
+ : DataHandler(SNORT_ML_NAME), inspector(ml) {}
void handle(DataEvent&, Flow*) override;
private:
- Kaizen& inspector;
+ SnortML& inspector;
};
void HttpUriHandler::handle(DataEvent& de, Flow*)
{
// cppcheck-suppress unreadVariable
- Profile profile(kaizen_prof);
+ Profile profile(snort_ml_prof);
- BinaryClassifier* classifier = KaizenEngine::get_classifier();
- KaizenConfig config = inspector.get_config();
+ libml::BinaryClassifierSet* classifiers = SnortMLEngine::get_classifiers();
+ SnortMLConfig config = inspector.get_config();
HttpEvent* he = (HttpEvent*)&de;
int32_t query_len = 0;
const size_t len = std::min((size_t)config.uri_depth, (size_t)query_len);
- assert(classifier);
+ assert(classifiers);
float output = 0.0;
- kaizen_stats.libml_calls++;
+ snort_ml_stats.libml_calls++;
- if (!classifier->run(query, (size_t)len, output))
+ if (!classifiers->run(query, len, output))
return;
- kaizen_stats.uri_bytes += len;
+ snort_ml_stats.uri_bytes += len;
- debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "input (query): %.*s\n", (int)len, query);
- debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast<double>(output));
+ debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "input (query): %.*s\n", (int)len, query);
+ debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast<double>(output));
if ((double)output > config.http_param_threshold)
{
- kaizen_stats.uri_alerts++;
- debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "<ALERT>\n");
- DetectionEngine::queue_event(KZ_GID, KZ_SID);
+ snort_ml_stats.uri_alerts++;
+ debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "<ALERT>\n");
+ DetectionEngine::queue_event(SNORT_ML_GID, SNORT_ML_SID);
}
}
// inspector
//--------------------------------------------------------------------------
-void Kaizen::show(const SnortConfig*) const
+void SnortML::show(const SnortConfig*) const
{
ConfigLogger::log_limit("uri_depth", config.uri_depth, -1);
ConfigLogger::log_limit("client_body_depth", config.client_body_depth, -1);
ConfigLogger::log_value("http_param_threshold", config.http_param_threshold);
}
-bool Kaizen::configure(SnortConfig* sc)
+bool SnortML::configure(SnortConfig* sc)
{
if (config.uri_depth != 0)
DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_HEADER, new HttpUriHandler(*this));
if (config.client_body_depth != 0)
DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_BODY, new HttpBodyHandler(*this));
- if(!InspectorManager::get_inspector(KZ_ENGINE_NAME, true, sc))
+ if(!InspectorManager::get_inspector(SNORT_ML_ENGINE_NAME, true, sc))
{
- ParseError("snort_ml requires %s to be configured in the global policy.", KZ_ENGINE_NAME);
+ ParseError("snort_ml requires %s to be configured in the global policy.", SNORT_ML_ENGINE_NAME);
return false;
}
//--------------------------------------------------------------------------
static Module* mod_ctor()
-{ return new KaizenModule; }
+{ return new SnortMLModule; }
static void mod_dtor(Module* m)
{ delete m; }
-static Inspector* kaizen_ctor(Module* m)
+static Inspector* snort_ml_ctor(Module* m)
{
- KaizenModule* km = (KaizenModule*)m;
- return new Kaizen(km->get_conf());
+ SnortMLModule* km = (SnortMLModule*)m;
+ return new SnortML(km->get_conf());
}
-static void kaizen_dtor(Inspector* p)
+static void snort_ml_dtor(Inspector* p)
{
assert(p);
delete p;
}
-static const InspectApi kaizen_api =
+static const InspectApi snort_ml_api =
{
{
PT_INSPECTOR,
0,
API_RESERVED,
API_OPTIONS,
- KZ_NAME,
- KZ_HELP,
+ SNORT_ML_NAME,
+ SNORT_ML_HELP,
mod_ctor,
mod_dtor
},
nullptr, // pterm
nullptr, // tinit
nullptr, // tterm
- kaizen_ctor,
- kaizen_dtor,
+ snort_ml_ctor,
+ snort_ml_dtor,
nullptr, // ssn
nullptr // reset
};
#ifdef BUILDING_SO
SO_PUBLIC const BaseApi* snort_plugins[] =
#else
-const BaseApi* nin_kaizen[] =
+const BaseApi* nin_snort_ml[] =
#endif
{
- &kaizen_api.base,
+ &snort_ml_api.base,
nullptr
};
//--------------------------------------------------------------------------
-// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved.
+// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved.
//
// This program is free software; you can redistribute it and/or modify it
// under the terms of the GNU General Public License Version 2 as published
// with this program; if not, write to the Free Software Foundation, Inc.,
// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
//--------------------------------------------------------------------------
-// kaizen_inspector.h author Brandon Stultz <brastult@cisco.com>
+// snort_ml_inspector.h author Brandon Stultz <brastult@cisco.com>
-#ifndef KAIZEN_INSPECTOR_H
-#define KAIZEN_INSPECTOR_H
+#ifndef SNORT_ML_INSPECTOR_H
+#define SNORT_ML_INSPECTOR_H
#include <string>
#include <utility>
#include "framework/inspector.h"
-#include "kaizen_module.h"
+#include "snort_ml_module.h"
-class Kaizen : public snort::Inspector
+class SnortML : public snort::Inspector
{
public:
- Kaizen(const KaizenConfig& c) : config(c) { }
+ SnortML(const SnortMLConfig& c) : config(c) { }
void show(const snort::SnortConfig*) const override;
void eval(snort::Packet*) override {}
bool configure(snort::SnortConfig*) override;
- const KaizenConfig& get_config()
+ const SnortMLConfig& get_config()
{ return config; }
private:
- KaizenConfig config;
+ SnortMLConfig config;
};
#endif
//--------------------------------------------------------------------------
-// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved.
+// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved.
//
// This program is free software; you can redistribute it and/or modify it
// under the terms of the GNU General Public License Version 2 as published
// with this program; if not, write to the Free Software Foundation, Inc.,
// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
//--------------------------------------------------------------------------
-// kaizen_module.cc author Brandon Stultz <brastult@cisco.com>
+// snort_ml_module.cc author Brandon Stultz <brastult@cisco.com>
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
-#include "kaizen_module.h"
+#include "snort_ml_module.h"
#include "log/messages.h"
#include "service_inspectors/http_inspect/http_field.h"
using namespace snort;
-THREAD_LOCAL const Trace* kaizen_trace = nullptr;
+THREAD_LOCAL const Trace* snort_ml_trace = nullptr;
-static const Parameter kaizen_params[] =
+static const Parameter snort_ml_params[] =
{
{ "uri_depth", Parameter::PT_INT, "-1:max31", "-1",
"number of input HTTP URI bytes to scan (-1 unlimited)" },
{ nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
};
-static const RuleMap kaizen_rules[] =
+static const RuleMap snort_ml_rules[] =
{
- { KZ_SID, "potential threat found in HTTP parameters via Neural Network Based Exploit Detection" },
+ { SNORT_ML_SID, "potential threat found in HTTP parameters via Neural Network Based Exploit Detection" },
{ 0, nullptr }
};
};
#ifdef DEBUG_MSGS
-static const TraceOption kaizen_trace_options[] =
+static const TraceOption snort_ml_trace_options[] =
{
{ "classifier", TRACE_CLASSIFIER, "enable Snort ML classifier trace logging" },
{ nullptr, 0, nullptr }
// module
//--------------------------------------------------------------------------
-KaizenModule::KaizenModule() : Module(KZ_NAME, KZ_HELP, kaizen_params) {}
+SnortMLModule::SnortMLModule() : Module(SNORT_ML_NAME, SNORT_ML_HELP, snort_ml_params) {}
-bool KaizenModule::set(const char*, Value& v, SnortConfig*)
+bool SnortMLModule::set(const char*, Value& v, SnortConfig*)
{
static_assert(std::is_same<decltype((Field().length())), decltype(conf.uri_depth)>::value,
"Field::length maximum value should not exceed uri_depth type range");
return true;
}
-bool KaizenModule::end(const char*, int, snort::SnortConfig*)
+bool SnortMLModule::end(const char*, int, snort::SnortConfig*)
{
if (!conf.uri_depth && !conf.client_body_depth)
ParseWarning(WARN_CONF,
return true;
}
-const RuleMap* KaizenModule::get_rules() const
-{ return kaizen_rules; }
+const RuleMap* SnortMLModule::get_rules() const
+{ return snort_ml_rules; }
-const PegInfo* KaizenModule::get_pegs() const
+const PegInfo* SnortMLModule::get_pegs() const
{ return peg_names; }
-PegCount* KaizenModule::get_counts() const
-{ return (PegCount*)&kaizen_stats; }
+PegCount* SnortMLModule::get_counts() const
+{ return (PegCount*)&snort_ml_stats; }
-ProfileStats* KaizenModule::get_profile() const
-{ return &kaizen_prof; }
+ProfileStats* SnortMLModule::get_profile() const
+{ return &snort_ml_prof; }
-void KaizenModule::set_trace(const Trace* trace) const
-{ kaizen_trace = trace; }
+void SnortMLModule::set_trace(const Trace* trace) const
+{ snort_ml_trace = trace; }
-const TraceOption* KaizenModule::get_trace_options() const
+const TraceOption* SnortMLModule::get_trace_options() const
{
#ifndef DEBUG_MSGS
return nullptr;
#else
- return kaizen_trace_options;
+ return snort_ml_trace_options;
#endif
}
//--------------------------------------------------------------------------
-// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved.
+// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved.
//
// This program is free software; you can redistribute it and/or modify it
// under the terms of the GNU General Public License Version 2 as published
// with this program; if not, write to the Free Software Foundation, Inc.,
// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
//--------------------------------------------------------------------------
-// kaizen_module.h author Brandon Stultz <brastult@cisco.com>
+// snort_ml_module.h author Brandon Stultz <brastult@cisco.com>
-#ifndef KAIZEN_MODULE_H
-#define KAIZEN_MODULE_H
+#ifndef SNORT_ML_MODULE_H
+#define SNORT_ML_MODULE_H
#include "framework/module.h"
#include "main/thread.h"
#include "profiler/profiler.h"
#include "trace/trace_api.h"
-#define KZ_GID 411
-#define KZ_SID 1
+#define SNORT_ML_GID 411
+#define SNORT_ML_SID 1
-#define KZ_NAME "snort_ml"
-#define KZ_HELP "machine learning based exploit detector"
+#define SNORT_ML_NAME "snort_ml"
+#define SNORT_ML_HELP "machine learning based exploit detector"
enum { TRACE_CLASSIFIER };
-struct KaizenStats
+struct SnortMLStats
{
PegCount uri_alerts;
PegCount client_body_alerts;
PegCount libml_calls;
};
-extern THREAD_LOCAL KaizenStats kaizen_stats;
-extern THREAD_LOCAL snort::ProfileStats kaizen_prof;
-extern THREAD_LOCAL const snort::Trace* kaizen_trace;
+extern THREAD_LOCAL SnortMLStats snort_ml_stats;
+extern THREAD_LOCAL snort::ProfileStats snort_ml_prof;
+extern THREAD_LOCAL const snort::Trace* snort_ml_trace;
-struct KaizenConfig
+struct SnortMLConfig
{
std::string http_param_model_path;
double http_param_threshold;
int32_t client_body_depth;
};
-class KaizenModule : public snort::Module
+class SnortMLModule : public snort::Module
{
public:
- KaizenModule();
+ SnortMLModule();
bool set(const char*, snort::Value&, snort::SnortConfig*) override;
bool end(const char*, int, snort::SnortConfig*) override;
- const KaizenConfig& get_conf() const
+ const SnortMLConfig& get_conf() const
{ return conf; }
unsigned get_gid() const override
- { return KZ_GID; }
+ { return SNORT_ML_GID; }
const snort::RuleMap* get_rules() const override;
const snort::TraceOption* get_trace_options() const override;
private:
- KaizenConfig conf = {};
+ SnortMLConfig conf = {};
};
#endif
#ifdef _WIN32
#define ISREG(m) (((m) & _S_IFMT) == _S_IFREG)
+#define ISDIR(m) (((m) & _S_IFMT) == _S_IFDIR)
#else
#define ISREG(m) S_ISREG(m)
+#define ISDIR(m) S_ISDIR(m)
#endif
using namespace snort;
return true;
}
+bool is_directory_path(const std::string& path)
+{
+ struct STAT sb;
+
+ if (STAT(path.c_str(), &sb))
+ return false;
+
+ return ISDIR(sb.st_mode);
+}
+
namespace snort
{
const char* get_error(int errnum)
unsigned int get_random_seed();
bool get_file_size(const std::string&, size_t&);
+bool is_directory_path(const std::string&);
namespace
{