From: Brandon Stultz (brastult) Date: Tue, 11 Feb 2025 09:28:46 +0000 (+0000) Subject: Pull request #4595: snort_ml: build models into a BinaryClassifierSet X-Git-Tag: 3.7.1.0~29 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a1dda4afc853e167d3bb62a8c382476a5d56a840;p=thirdparty%2Fsnort3.git Pull request #4595: snort_ml: build models into a BinaryClassifierSet Merge in SNORT/snort3 from ~BRASTULT/snort3:snort_ml to master Squashed commit of the following: commit e4f35d63b7bc2fa38176408466afe8576d0f77f0 Author: Brandon Stultz Date: Fri Jan 31 02:43:25 2025 -0500 snort_ml: build models into a BinaryClassifierSet commit 7ac7827b65192d6319893498585b48e0c7809e1b Author: Brandon Stultz Date: Fri Jan 31 01:16:00 2025 -0500 utils: add is_directory_path commit e3897fe6bf08d2fba2406f612b4bf3b31e07cfea Author: Brandon Stultz Date: Thu Jan 30 11:57:53 2025 -0500 network_inspectors: rename kaizen to snort_ml --- diff --git a/cmake/FindML.cmake b/cmake/FindML.cmake index 0282bcbef..c30553c95 100644 --- a/cmake/FindML.cmake +++ b/cmake/FindML.cmake @@ -1,12 +1,22 @@ -find_library(ML_LIBRARIES NAMES ml_static HINTS ${ML_LIBRARIES_DIR_HINT}) -find_path(ML_INCLUDE_DIRS libml.h HINTS ${ML_INCLUDE_DIR_HINT}) +find_package(PkgConfig) +pkg_check_modules(PC_ML libml_static>=2.0.0) + +find_path(ML_INCLUDE_DIRS + libml.h + HINTS ${ML_INCLUDE_DIR_HINT} ${PC_ML_INCLUDEDIR} +) + +find_library(ML_LIBRARIES + NAMES ml_static + HINTS ${ML_LIBRARIES_DIR_HINT} ${PC_ML_LIBDIR} +) include(FindPackageHandleStandardArgs) find_package_handle_standard_args(ML DEFAULT_MSG - ML_LIBRARIES ML_INCLUDE_DIRS + ML_LIBRARIES ) if (ML_FOUND AND NOT USE_LIBML_MOCK) diff --git a/src/main/process.cc b/src/main/process.cc index 2b1f94fe4..f361a5962 100644 --- a/src/main/process.cc +++ b/src/main/process.cc @@ -45,6 +45,10 @@ #include #endif +#ifdef HAVE_LIBML +#include +#endif + #ifdef HAVE_LIBUNWIND #define UNW_LOCAL_ONLY #include @@ -706,6 +710,9 @@ int DisplayBanner() size_t sz = sizeof(jv); mallctl("version", &jv, &sz, NULL, 0); LogMessage(" Using Jemalloc version %s\n", jv); +#endif +#ifdef HAVE_LIBML + LogMessage(" Using LibML version %s\n", libml::version()); #endif LogMessage(" Using %s\n", pcap_lib_version()); LogMessage(" Using LuaJIT version %s\n", ljv); diff --git a/src/main/shell.cc b/src/main/shell.cc index bb7b7134c..21d1b7537 100644 --- a/src/main/shell.cc +++ b/src/main/shell.cc @@ -171,7 +171,7 @@ static void install_dependencies_strings(Shell* sh, lua_State* L) vs.push_back(lzma_version_string()); #endif #ifdef HAVE_LIBML - vs.push_back(libml_version()); + vs.push_back(libml::version()); #endif lua_createtable(L, 0, vs.size()); diff --git a/src/network_inspectors/CMakeLists.txt b/src/network_inspectors/CMakeLists.txt index fd8694b75..81ea889bd 100644 --- a/src/network_inspectors/CMakeLists.txt +++ b/src/network_inspectors/CMakeLists.txt @@ -5,7 +5,7 @@ add_subdirectory(binder) add_subdirectory(extractor) if ( HAVE_LIBML OR USE_LIBML_MOCK ) - add_subdirectory(kaizen) + add_subdirectory(snort_ml) endif() add_subdirectory(normalize) @@ -24,8 +24,8 @@ if(STATIC_INSPECTORS) endif() if ( HAVE_LIBML OR USE_LIBML_MOCK ) - set(KAIZEN_STATIC_OBJ - $ + set(SNORT_ML_STATIC_OBJ + $ ) endif() @@ -33,7 +33,7 @@ set(STATIC_NETWORK_INSPECTOR_PLUGINS $ $ $ - ${KAIZEN_STATIC_OBJ} + ${SNORT_ML_STATIC_OBJ} $ $ $ diff --git a/src/network_inspectors/dev_notes.txt b/src/network_inspectors/dev_notes.txt index fabcd45f8..a3cf640b2 100644 --- a/src/network_inspectors/dev_notes.txt +++ b/src/network_inspectors/dev_notes.txt @@ -20,9 +20,9 @@ normalizations. packet_capture - A tool for dumping the wire packets that Snort receives. -kaizen - Machine learning based exploit detector capable of detecting novel -attacks fitting known vulnerability types. Kaizen uses a neural network -provided by a model file to detect exploit patterns. The Kaizen Snort module +snort_ml - Machine learning based exploit detector capable of detecting novel +attacks fitting known vulnerability types. SnortML uses a neural network +provided by a model file to detect exploit patterns. The SnortML module subscribes to HTTP events published by the HTTP inspector, performs inference on HTTP queries/posts, and generates events if the neural network detects an exploit. diff --git a/src/network_inspectors/kaizen/CMakeLists.txt b/src/network_inspectors/kaizen/CMakeLists.txt deleted file mode 100644 index bc590578d..000000000 --- a/src/network_inspectors/kaizen/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -add_library(kaizen OBJECT - kaizen_engine.cc - kaizen_engine.h - kaizen_inspector.cc - kaizen_inspector.h - kaizen_module.cc - kaizen_module.h -) diff --git a/src/network_inspectors/kaizen/kaizen_engine.cc b/src/network_inspectors/kaizen/kaizen_engine.cc deleted file mode 100644 index 6e60a7762..000000000 --- a/src/network_inspectors/kaizen/kaizen_engine.cc +++ /dev/null @@ -1,250 +0,0 @@ -//-------------------------------------------------------------------------- -// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved. -// -// This program is free software; you can redistribute it and/or modify it -// under the terms of the GNU General Public License Version 2 as published -// by the Free Software Foundation. You may not use, modify or distribute -// this program under any other version of the GNU General Public License. -// -// This program is distributed in the hope that it will be useful, but -// WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -// General Public License for more details. -// -// You should have received a copy of the GNU General Public License along -// with this program; if not, write to the Free Software Foundation, Inc., -// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. -//-------------------------------------------------------------------------- -// kaizen_engine.cc author Vitalii Horbatov - -#ifdef HAVE_CONFIG_H -#include "config.h" -#endif - -#include "kaizen_engine.h" - -#include -#include - -#ifdef HAVE_LIBML -#include -#endif - -#include "framework/decode_data.h" -#include "log/messages.h" -#include "main/reload_tuner.h" -#include "main/snort.h" -#include "main/snort_config.h" -#include "parser/parse_conf.h" -#include "utils/util.h" - -using namespace snort; -using namespace std; - -static THREAD_LOCAL BinaryClassifier* classifier = nullptr; - -static bool build_classifier(const string& model, BinaryClassifier*& dst) -{ - dst = new BinaryClassifier(); - - return dst->build(model); -} - -//-------------------------------------------------------------------------- -// module -//-------------------------------------------------------------------------- - -static const Parameter kaizen_engine_params[] = -{ - { "http_param_model", Parameter::PT_STRING, nullptr, nullptr, "path to the model file" }, - { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } -}; - -KaizenEngineModule::KaizenEngineModule() : Module(KZ_ENGINE_NAME, KZ_ENGINE_HELP, kaizen_engine_params) {} - -bool KaizenEngineModule::set(const char*, Value& v, SnortConfig*) -{ - if (v.is("http_param_model")) - conf.http_param_model_path = v.get_string(); - - return true; -} - -//-------------------------------------------------------------------------- -// reload tuner -//-------------------------------------------------------------------------- - -class KaizenReloadTuner : public snort::ReloadResourceTuner -{ -public: - explicit KaizenReloadTuner(const string& http_param_model) : http_param_model(http_param_model) {} - ~KaizenReloadTuner() override = default; - - const char* name() const override - { return "KaizenReloadTuner"; } - - bool tinit() override - { - delete classifier; - - if (!build_classifier(http_param_model, classifier)) - ErrorMessage("Can't build the classifier model.\n"); - - return false; - } - - bool tune_packet_context() override - { return true; } - - bool tune_idle_context() override - { return true; } - -private: - const string& http_param_model; -}; - -//-------------------------------------------------------------------------- -// inspector -//-------------------------------------------------------------------------- - -KaizenEngine::KaizenEngine(const KaizenEngineConfig& c) : config(c) -{ - http_param_model = read_model(); - - if (!validate_model()) - ParseError("Can't build the classifier model %s.", config.http_param_model_path.c_str()); -} - -void KaizenEngine::show(const SnortConfig*) const -{ ConfigLogger::log_value("http_param_model", config.http_param_model_path.c_str()); } - -string KaizenEngine::read_model() -{ - const char* hint = config.http_param_model_path.c_str(); - string path; - size_t size = 0; - - if (!get_config_file(hint, path) || !get_file_size(path, size)) - { - ParseError("snort_ml_engine: could not read model file: %s", hint); - return {}; - } - - ifstream file(path, ios::binary); - - if (!file.is_open()) - { - ParseError("snort_ml_engine: could not read model file: %s", hint); - return {}; - } - - if (size == 0) - { - ParseError("snort_ml_engine: empty model file: %s", hint); - return {}; - } - - string buffer(size, '\0'); - file.read(&buffer[0], streamsize(size)); - return buffer; -} - -bool KaizenEngine::validate_model() -{ - BinaryClassifier* test_classifier = nullptr; - bool res = build_classifier(http_param_model, test_classifier); - delete test_classifier; - - return res; -} - -void KaizenEngine::tinit() -{ build_classifier(http_param_model, classifier); } - -void KaizenEngine::tterm() -{ - delete classifier; - classifier = nullptr; -} - -void KaizenEngine::install_reload_handler(SnortConfig* sc) -{ sc->register_reload_handler(new KaizenReloadTuner(http_param_model)); } - -BinaryClassifier* KaizenEngine::get_classifier() -{ return classifier; } - -//-------------------------------------------------------------------------- -// api stuff -//-------------------------------------------------------------------------- - -static Module* mod_ctor() -{ return new KaizenEngineModule; } - -static void mod_dtor(Module* m) -{ delete m; } - -static Inspector* kaizen_engine_ctor(Module* m) -{ - KaizenEngineModule* mod = (KaizenEngineModule*)m; - return new KaizenEngine(mod->get_config()); -} - -static void kaizen_engine_dtor(Inspector* p) -{ - assert(p); - delete p; -} - -static const InspectApi kaizen_engine_api = -{ - { - PT_INSPECTOR, - sizeof(InspectApi), - INSAPI_VERSION, - 0, - API_RESERVED, - API_OPTIONS, - KZ_ENGINE_NAME, - KZ_ENGINE_HELP, - mod_ctor, - mod_dtor - }, - IT_PASSIVE, - PROTO_BIT__NONE, // proto_bits; - nullptr, // buffers - nullptr, // service - nullptr, // pinit - nullptr, // pterm - nullptr, // tinit - nullptr, // tterm - kaizen_engine_ctor, - kaizen_engine_dtor, - nullptr, // ssn - nullptr // reset -}; - -#ifdef BUILDING_SO -SO_PUBLIC const BaseApi* snort_plugins[] = -#else -const BaseApi* nin_kaizen_engine[] = -#endif -{ - &kaizen_engine_api.base, - nullptr -}; - -#ifdef UNIT_TEST - -#include "catch/snort_catch.h" - -#include - -TEST_CASE("Kaizen tuner name", "[kaizen_module]") -{ - const string http_param_model("model"); - KaizenReloadTuner tuner(http_param_model); - - REQUIRE(strcmp(tuner.name(), "KaizenReloadTuner") == 0); -} - -#endif diff --git a/src/network_inspectors/network_inspectors.cc b/src/network_inspectors/network_inspectors.cc index 7abdadda5..57808b251 100644 --- a/src/network_inspectors/network_inspectors.cc +++ b/src/network_inspectors/network_inspectors.cc @@ -34,8 +34,8 @@ extern const BaseApi* nin_appid[]; extern const BaseApi* nin_extractor[]; #if defined(HAVE_LIBML) || defined(USE_LIBML_MOCK) -extern const BaseApi* nin_kaizen_engine[]; -extern const BaseApi* nin_kaizen[]; +extern const BaseApi* nin_snort_ml_engine[]; +extern const BaseApi* nin_snort_ml[]; #endif extern const BaseApi* nin_port_scan[]; @@ -62,8 +62,8 @@ void load_network_inspectors() PluginManager::load_plugins(nin_extractor); #if defined(HAVE_LIBML) || defined(USE_LIBML_MOCK) - PluginManager::load_plugins(nin_kaizen_engine); - PluginManager::load_plugins(nin_kaizen); + PluginManager::load_plugins(nin_snort_ml_engine); + PluginManager::load_plugins(nin_snort_ml); #endif PluginManager::load_plugins(nin_port_scan); diff --git a/src/network_inspectors/snort_ml/CMakeLists.txt b/src/network_inspectors/snort_ml/CMakeLists.txt new file mode 100644 index 000000000..cf0e26016 --- /dev/null +++ b/src/network_inspectors/snort_ml/CMakeLists.txt @@ -0,0 +1,8 @@ +add_library(snort_ml OBJECT + snort_ml_engine.cc + snort_ml_engine.h + snort_ml_inspector.cc + snort_ml_inspector.h + snort_ml_module.cc + snort_ml_module.h +) diff --git a/src/network_inspectors/kaizen/dev_notes.txt b/src/network_inspectors/snort_ml/dev_notes.txt similarity index 100% rename from src/network_inspectors/kaizen/dev_notes.txt rename to src/network_inspectors/snort_ml/dev_notes.txt diff --git a/src/network_inspectors/snort_ml/snort_ml_engine.cc b/src/network_inspectors/snort_ml/snort_ml_engine.cc new file mode 100644 index 000000000..6b3c63c96 --- /dev/null +++ b/src/network_inspectors/snort_ml/snort_ml_engine.cc @@ -0,0 +1,283 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// snort_ml_engine.cc author Vitalii Horbatov +// author Brandon Stultz + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "snort_ml_engine.h" + +#include +#include + +#ifdef HAVE_LIBML +#include +#endif + +#include "framework/decode_data.h" +#include "helpers/directory.h" +#include "log/messages.h" +#include "main/reload_tuner.h" +#include "main/snort.h" +#include "main/snort_config.h" +#include "parser/parse_conf.h" +#include "utils/util.h" + +using namespace snort; +using namespace std; + +static THREAD_LOCAL libml::BinaryClassifierSet* classifiers = nullptr; + +static bool build_classifiers(const vector& models, + libml::BinaryClassifierSet*& set) +{ + set = new libml::BinaryClassifierSet(); + + return set->build(models); +} + +//-------------------------------------------------------------------------- +// module +//-------------------------------------------------------------------------- + +static const Parameter snort_ml_engine_params[] = +{ + { "http_param_model", Parameter::PT_STRING, nullptr, nullptr, "path to model file(s)" }, + { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } +}; + +SnortMLEngineModule::SnortMLEngineModule() : Module(SNORT_ML_ENGINE_NAME, SNORT_ML_ENGINE_HELP, snort_ml_engine_params) {} + +bool SnortMLEngineModule::set(const char*, Value& v, SnortConfig*) +{ + if (v.is("http_param_model")) + conf.http_param_model_path = v.get_string(); + + return true; +} + +//-------------------------------------------------------------------------- +// reload tuner +//-------------------------------------------------------------------------- + +class SnortMLReloadTuner : public snort::ReloadResourceTuner +{ +public: + explicit SnortMLReloadTuner(const vector& models) + : http_param_models(models) {} + + ~SnortMLReloadTuner() override = default; + + const char* name() const override + { return "SnortMLReloadTuner"; } + + bool tinit() override + { + delete classifiers; + + if (!build_classifiers(http_param_models, classifiers)) + ErrorMessage("Could not build classifiers.\n"); + + return false; + } + + bool tune_packet_context() override + { return true; } + + bool tune_idle_context() override + { return true; } + +private: + const vector& http_param_models; +}; + +//-------------------------------------------------------------------------- +// inspector +//-------------------------------------------------------------------------- + +SnortMLEngine::SnortMLEngine(const SnortMLEngineConfig& c) : config(c) +{ + if (!read_models() || !validate_models()) + ParseError("Could not build classifiers."); +} + +void SnortMLEngine::show(const SnortConfig*) const +{ ConfigLogger::log_value("http_param_model", config.http_param_model_path.c_str()); } + +bool SnortMLEngine::read_models() +{ + const char* hint = config.http_param_model_path.c_str(); + string path; + + if (!get_config_file(hint, path)) + { + ParseError("snort_ml_engine: could not read model file(s): %s", hint); + return false; + } + + if (!is_directory_path(path)) + { + if (!read_model(path)) + { + ParseError("snort_ml_engine: could not read model file: %s", path.c_str()); + return false; + } + + return true; + } + + Directory model_dir(path.c_str()); + + if (model_dir.error_on_open()) + { + ParseError("snort_ml_engine: could not read model dir: %s", path.c_str()); + return false; + } + + while (const char* f = model_dir.next()) + { + if (!read_model(f)) + { + ParseError("snort_ml_engine: could not read model: %s", f); + return false; + } + } + + return !http_param_models.empty(); +} + +bool SnortMLEngine::read_model(const string& path) +{ + size_t size = 0; + + if (!get_file_size(path, size)) + return false; + + ifstream file(path, ios::binary); + + if (!file.is_open() || size == 0) + return false; + + string buffer(size, '\0'); + file.read(&buffer[0], streamsize(size)); + + http_param_models.push_back(move(buffer)); + return true; +} + +bool SnortMLEngine::validate_models() +{ + libml::BinaryClassifierSet* set = nullptr; + bool res = build_classifiers(http_param_models, set); + delete set; + + return res; +} + +void SnortMLEngine::tinit() +{ build_classifiers(http_param_models, classifiers); } + +void SnortMLEngine::tterm() +{ + delete classifiers; + classifiers = nullptr; +} + +void SnortMLEngine::install_reload_handler(SnortConfig* sc) +{ sc->register_reload_handler(new SnortMLReloadTuner(http_param_models)); } + +libml::BinaryClassifierSet* SnortMLEngine::get_classifiers() +{ return classifiers; } + +//-------------------------------------------------------------------------- +// api stuff +//-------------------------------------------------------------------------- + +static Module* mod_ctor() +{ return new SnortMLEngineModule; } + +static void mod_dtor(Module* m) +{ delete m; } + +static Inspector* snort_ml_engine_ctor(Module* m) +{ + SnortMLEngineModule* mod = (SnortMLEngineModule*)m; + return new SnortMLEngine(mod->get_config()); +} + +static void snort_ml_engine_dtor(Inspector* p) +{ + assert(p); + delete p; +} + +static const InspectApi snort_ml_engine_api = +{ + { + PT_INSPECTOR, + sizeof(InspectApi), + INSAPI_VERSION, + 0, + API_RESERVED, + API_OPTIONS, + SNORT_ML_ENGINE_NAME, + SNORT_ML_ENGINE_HELP, + mod_ctor, + mod_dtor + }, + IT_PASSIVE, + PROTO_BIT__NONE, // proto_bits; + nullptr, // buffers + nullptr, // service + nullptr, // pinit + nullptr, // pterm + nullptr, // tinit + nullptr, // tterm + snort_ml_engine_ctor, + snort_ml_engine_dtor, + nullptr, // ssn + nullptr // reset +}; + +#ifdef BUILDING_SO +SO_PUBLIC const BaseApi* snort_plugins[] = +#else +const BaseApi* nin_snort_ml_engine[] = +#endif +{ + &snort_ml_engine_api.base, + nullptr +}; + +#ifdef UNIT_TEST + +#include "catch/snort_catch.h" + +#include + +TEST_CASE("SnortML tuner name", "[snort_ml_module]") +{ + const vector models = { "model" }; + SnortMLReloadTuner tuner(models); + + REQUIRE(strcmp(tuner.name(), "SnortMLReloadTuner") == 0); +} + +#endif diff --git a/src/network_inspectors/kaizen/kaizen_engine.h b/src/network_inspectors/snort_ml/snort_ml_engine.h similarity index 54% rename from src/network_inspectors/kaizen/kaizen_engine.h rename to src/network_inspectors/snort_ml/snort_ml_engine.h index 7c7381ba5..0d77d1664 100644 --- a/src/network_inspectors/kaizen/kaizen_engine.h +++ b/src/network_inspectors/snort_ml/snort_ml_engine.h @@ -1,5 +1,5 @@ //-------------------------------------------------------------------------- -// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved. +// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved. // // This program is free software; you can redistribute it and/or modify it // under the terms of the GNU General Public License Version 2 as published @@ -15,46 +15,49 @@ // with this program; if not, write to the Free Software Foundation, Inc., // 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. //-------------------------------------------------------------------------- -// kaizen_engine.h author Vitalii Horbatov +// snort_ml_engine.h author Vitalii Horbatov +// author Brandon Stultz -#ifndef KAIZEN_ENGINE_H -#define KAIZEN_ENGINE_H +#ifndef SNORT_ML_ENGINE_H +#define SNORT_ML_ENGINE_H #include "framework/module.h" #include "framework/inspector.h" +#define SNORT_ML_ENGINE_NAME "snort_ml_engine" +#define SNORT_ML_ENGINE_HELP "configure machine learning engine settings" -#define KZ_ENGINE_NAME "snort_ml_engine" -#define KZ_ENGINE_HELP "configure machine learning engine settings" +namespace libml +{ + class BinaryClassifierSet; +} -class BinaryClassifier; -struct KaizenEngineConfig +struct SnortMLEngineConfig { std::string http_param_model_path; }; -class KaizenEngineModule : public snort::Module +class SnortMLEngineModule : public snort::Module { public: - KaizenEngineModule(); + SnortMLEngineModule(); bool set(const char*, snort::Value&, snort::SnortConfig*) override; Usage get_usage() const override { return GLOBAL; } - const KaizenEngineConfig& get_config() + const SnortMLEngineConfig& get_config() { return conf; } private: - KaizenEngineConfig conf; + SnortMLEngineConfig conf; }; - -class KaizenEngine : public snort::Inspector +class SnortMLEngine : public snort::Inspector { public: - KaizenEngine(const KaizenEngineConfig&); + SnortMLEngine(const SnortMLEngineConfig&); void show(const snort::SnortConfig*) const override; void eval(snort::Packet*) override {} @@ -64,40 +67,48 @@ public: void install_reload_handler(snort::SnortConfig*) override; - static BinaryClassifier* get_classifier(); + static libml::BinaryClassifierSet* get_classifiers(); private: - std::string read_model(); - bool validate_model(); + bool read_models(); + bool read_model(const std::string&); - KaizenEngineConfig config; - std::string http_param_model; -}; + bool validate_models(); + SnortMLEngineConfig config; + std::vector http_param_models; +}; -// Mock Classifier for tests if LibML absents. -// However, when REG_TEST is undefined, the entire code below won't be executed. -// Check the plugin type provided in kaizen_engine_api in the cc file +// Mock BinaryClassifierSet for tests if LibML is absent. +// The code below won't be executed if REG_TEST is undefined. +// Check the plugin type provided in the snort_ml_engine.cc file. #ifndef HAVE_LIBML -class BinaryClassifier +namespace libml +{ + +class BinaryClassifierSet { public: - bool build(const std::string& model) + bool build(const std::vector& models) { - pattern = model; + if (!models.empty()) + pattern = models[0]; + return pattern != "error"; } - bool run(const char* ptr, size_t len, float& threshold) + bool run(const char* ptr, size_t len, float& out) { std::string data(ptr, len); - threshold = std::string::npos == data.find(pattern) ? 0.0f : 1.0f; + out = data.find(pattern) == std::string::npos ? 0.0f : 1.0f; return pattern != "fail"; } private: std::string pattern; }; + +} #endif #endif diff --git a/src/network_inspectors/kaizen/kaizen_inspector.cc b/src/network_inspectors/snort_ml/snort_ml_inspector.cc similarity index 64% rename from src/network_inspectors/kaizen/kaizen_inspector.cc rename to src/network_inspectors/snort_ml/snort_ml_inspector.cc index 863eb59f3..9d27a5817 100644 --- a/src/network_inspectors/kaizen/kaizen_inspector.cc +++ b/src/network_inspectors/snort_ml/snort_ml_inspector.cc @@ -1,5 +1,5 @@ //-------------------------------------------------------------------------- -// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved. +// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved. // // This program is free software; you can redistribute it and/or modify it // under the terms of the GNU General Public License Version 2 as published @@ -15,13 +15,13 @@ // with this program; if not, write to the Free Software Foundation, Inc., // 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. //-------------------------------------------------------------------------- -// kaizen_inspector.cc author Brandon Stultz +// snort_ml_inspector.cc author Brandon Stultz #ifdef HAVE_CONFIG_H #include "config.h" #endif -#include "kaizen_inspector.h" +#include "snort_ml_inspector.h" #include @@ -36,13 +36,13 @@ #include "pub_sub/http_request_body_event.h" #include "utils/util.h" -#include "kaizen_engine.h" +#include "snort_ml_engine.h" using namespace snort; using namespace std; -THREAD_LOCAL KaizenStats kaizen_stats; -THREAD_LOCAL ProfileStats kaizen_prof; +THREAD_LOCAL SnortMLStats snort_ml_stats; +THREAD_LOCAL ProfileStats snort_ml_prof; //-------------------------------------------------------------------------- // HTTP body event handler @@ -51,22 +51,22 @@ THREAD_LOCAL ProfileStats kaizen_prof; class HttpBodyHandler : public DataHandler { public: - HttpBodyHandler(Kaizen& kz) - : DataHandler(KZ_NAME), inspector(kz) {} + HttpBodyHandler(SnortML& ml) + : DataHandler(SNORT_ML_NAME), inspector(ml) {} void handle(DataEvent& de, Flow*) override; private: - Kaizen& inspector; + SnortML& inspector; }; void HttpBodyHandler::handle(DataEvent& de, Flow*) { // cppcheck-suppress unreadVariable - Profile profile(kaizen_prof); + Profile profile(snort_ml_prof); - BinaryClassifier* classifier = KaizenEngine::get_classifier(); - KaizenConfig config = inspector.get_config(); + libml::BinaryClassifierSet* classifiers = SnortMLEngine::get_classifiers(); + SnortMLConfig config = inspector.get_config(); HttpRequestBodyEvent* he = (HttpRequestBodyEvent*)&de; if (he->is_mime()) @@ -80,25 +80,25 @@ void HttpBodyHandler::handle(DataEvent& de, Flow*) const size_t len = std::min((size_t)config.client_body_depth, (size_t)body_len); - assert(classifier); + assert(classifiers); float output = 0.0; - kaizen_stats.libml_calls++; + snort_ml_stats.libml_calls++; - if (!classifier->run(body, len, output)) + if (!classifiers->run(body, len, output)) return; - kaizen_stats.client_body_bytes += len; + snort_ml_stats.client_body_bytes += len; - debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "input (body): %.*s\n", (int)len, body); - debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast(output)); + debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "input (body): %.*s\n", (int)len, body); + debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast(output)); if ((double)output > config.http_param_threshold) { - kaizen_stats.client_body_alerts++; - debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "\n"); - DetectionEngine::queue_event(KZ_GID, KZ_SID); + snort_ml_stats.client_body_alerts++; + debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "\n"); + DetectionEngine::queue_event(SNORT_ML_GID, SNORT_ML_SID); } } @@ -109,22 +109,22 @@ void HttpBodyHandler::handle(DataEvent& de, Flow*) class HttpUriHandler : public DataHandler { public: - HttpUriHandler(Kaizen& kz) - : DataHandler(KZ_NAME), inspector(kz) {} + HttpUriHandler(SnortML& ml) + : DataHandler(SNORT_ML_NAME), inspector(ml) {} void handle(DataEvent&, Flow*) override; private: - Kaizen& inspector; + SnortML& inspector; }; void HttpUriHandler::handle(DataEvent& de, Flow*) { // cppcheck-suppress unreadVariable - Profile profile(kaizen_prof); + Profile profile(snort_ml_prof); - BinaryClassifier* classifier = KaizenEngine::get_classifier(); - KaizenConfig config = inspector.get_config(); + libml::BinaryClassifierSet* classifiers = SnortMLEngine::get_classifiers(); + SnortMLConfig config = inspector.get_config(); HttpEvent* he = (HttpEvent*)&de; int32_t query_len = 0; @@ -135,25 +135,25 @@ void HttpUriHandler::handle(DataEvent& de, Flow*) const size_t len = std::min((size_t)config.uri_depth, (size_t)query_len); - assert(classifier); + assert(classifiers); float output = 0.0; - kaizen_stats.libml_calls++; + snort_ml_stats.libml_calls++; - if (!classifier->run(query, (size_t)len, output)) + if (!classifiers->run(query, len, output)) return; - kaizen_stats.uri_bytes += len; + snort_ml_stats.uri_bytes += len; - debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "input (query): %.*s\n", (int)len, query); - debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast(output)); + debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "input (query): %.*s\n", (int)len, query); + debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "output: %f\n", static_cast(output)); if ((double)output > config.http_param_threshold) { - kaizen_stats.uri_alerts++; - debug_logf(kaizen_trace, TRACE_CLASSIFIER, nullptr, "\n"); - DetectionEngine::queue_event(KZ_GID, KZ_SID); + snort_ml_stats.uri_alerts++; + debug_logf(snort_ml_trace, TRACE_CLASSIFIER, nullptr, "\n"); + DetectionEngine::queue_event(SNORT_ML_GID, SNORT_ML_SID); } } @@ -161,14 +161,14 @@ void HttpUriHandler::handle(DataEvent& de, Flow*) // inspector //-------------------------------------------------------------------------- -void Kaizen::show(const SnortConfig*) const +void SnortML::show(const SnortConfig*) const { ConfigLogger::log_limit("uri_depth", config.uri_depth, -1); ConfigLogger::log_limit("client_body_depth", config.client_body_depth, -1); ConfigLogger::log_value("http_param_threshold", config.http_param_threshold); } -bool Kaizen::configure(SnortConfig* sc) +bool SnortML::configure(SnortConfig* sc) { if (config.uri_depth != 0) DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_HEADER, new HttpUriHandler(*this)); @@ -176,9 +176,9 @@ bool Kaizen::configure(SnortConfig* sc) if (config.client_body_depth != 0) DataBus::subscribe(http_pub_key, HttpEventIds::REQUEST_BODY, new HttpBodyHandler(*this)); - if(!InspectorManager::get_inspector(KZ_ENGINE_NAME, true, sc)) + if(!InspectorManager::get_inspector(SNORT_ML_ENGINE_NAME, true, sc)) { - ParseError("snort_ml requires %s to be configured in the global policy.", KZ_ENGINE_NAME); + ParseError("snort_ml requires %s to be configured in the global policy.", SNORT_ML_ENGINE_NAME); return false; } @@ -190,24 +190,24 @@ bool Kaizen::configure(SnortConfig* sc) //-------------------------------------------------------------------------- static Module* mod_ctor() -{ return new KaizenModule; } +{ return new SnortMLModule; } static void mod_dtor(Module* m) { delete m; } -static Inspector* kaizen_ctor(Module* m) +static Inspector* snort_ml_ctor(Module* m) { - KaizenModule* km = (KaizenModule*)m; - return new Kaizen(km->get_conf()); + SnortMLModule* km = (SnortMLModule*)m; + return new SnortML(km->get_conf()); } -static void kaizen_dtor(Inspector* p) +static void snort_ml_dtor(Inspector* p) { assert(p); delete p; } -static const InspectApi kaizen_api = +static const InspectApi snort_ml_api = { { PT_INSPECTOR, @@ -216,8 +216,8 @@ static const InspectApi kaizen_api = 0, API_RESERVED, API_OPTIONS, - KZ_NAME, - KZ_HELP, + SNORT_ML_NAME, + SNORT_ML_HELP, mod_ctor, mod_dtor }, @@ -229,8 +229,8 @@ static const InspectApi kaizen_api = nullptr, // pterm nullptr, // tinit nullptr, // tterm - kaizen_ctor, - kaizen_dtor, + snort_ml_ctor, + snort_ml_dtor, nullptr, // ssn nullptr // reset }; @@ -238,9 +238,9 @@ static const InspectApi kaizen_api = #ifdef BUILDING_SO SO_PUBLIC const BaseApi* snort_plugins[] = #else -const BaseApi* nin_kaizen[] = +const BaseApi* nin_snort_ml[] = #endif { - &kaizen_api.base, + &snort_ml_api.base, nullptr }; diff --git a/src/network_inspectors/kaizen/kaizen_inspector.h b/src/network_inspectors/snort_ml/snort_ml_inspector.h similarity index 76% rename from src/network_inspectors/kaizen/kaizen_inspector.h rename to src/network_inspectors/snort_ml/snort_ml_inspector.h index d54d72a62..f7f97fca9 100644 --- a/src/network_inspectors/kaizen/kaizen_inspector.h +++ b/src/network_inspectors/snort_ml/snort_ml_inspector.h @@ -1,5 +1,5 @@ //-------------------------------------------------------------------------- -// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved. +// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved. // // This program is free software; you can redistribute it and/or modify it // under the terms of the GNU General Public License Version 2 as published @@ -15,31 +15,31 @@ // with this program; if not, write to the Free Software Foundation, Inc., // 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. //-------------------------------------------------------------------------- -// kaizen_inspector.h author Brandon Stultz +// snort_ml_inspector.h author Brandon Stultz -#ifndef KAIZEN_INSPECTOR_H -#define KAIZEN_INSPECTOR_H +#ifndef SNORT_ML_INSPECTOR_H +#define SNORT_ML_INSPECTOR_H #include #include #include "framework/inspector.h" -#include "kaizen_module.h" +#include "snort_ml_module.h" -class Kaizen : public snort::Inspector +class SnortML : public snort::Inspector { public: - Kaizen(const KaizenConfig& c) : config(c) { } + SnortML(const SnortMLConfig& c) : config(c) { } void show(const snort::SnortConfig*) const override; void eval(snort::Packet*) override {} bool configure(snort::SnortConfig*) override; - const KaizenConfig& get_config() + const SnortMLConfig& get_config() { return config; } private: - KaizenConfig config; + SnortMLConfig config; }; #endif diff --git a/src/network_inspectors/kaizen/kaizen_module.cc b/src/network_inspectors/snort_ml/snort_ml_module.cc similarity index 74% rename from src/network_inspectors/kaizen/kaizen_module.cc rename to src/network_inspectors/snort_ml/snort_ml_module.cc index 554b8ed41..d0c4cd291 100644 --- a/src/network_inspectors/kaizen/kaizen_module.cc +++ b/src/network_inspectors/snort_ml/snort_ml_module.cc @@ -1,5 +1,5 @@ //-------------------------------------------------------------------------- -// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved. +// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved. // // This program is free software; you can redistribute it and/or modify it // under the terms of the GNU General Public License Version 2 as published @@ -15,22 +15,22 @@ // with this program; if not, write to the Free Software Foundation, Inc., // 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. //-------------------------------------------------------------------------- -// kaizen_module.cc author Brandon Stultz +// snort_ml_module.cc author Brandon Stultz #ifdef HAVE_CONFIG_H #include "config.h" #endif -#include "kaizen_module.h" +#include "snort_ml_module.h" #include "log/messages.h" #include "service_inspectors/http_inspect/http_field.h" using namespace snort; -THREAD_LOCAL const Trace* kaizen_trace = nullptr; +THREAD_LOCAL const Trace* snort_ml_trace = nullptr; -static const Parameter kaizen_params[] = +static const Parameter snort_ml_params[] = { { "uri_depth", Parameter::PT_INT, "-1:max31", "-1", "number of input HTTP URI bytes to scan (-1 unlimited)" }, @@ -44,9 +44,9 @@ static const Parameter kaizen_params[] = { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } }; -static const RuleMap kaizen_rules[] = +static const RuleMap snort_ml_rules[] = { - { KZ_SID, "potential threat found in HTTP parameters via Neural Network Based Exploit Detection" }, + { SNORT_ML_SID, "potential threat found in HTTP parameters via Neural Network Based Exploit Detection" }, { 0, nullptr } }; @@ -61,7 +61,7 @@ static const PegInfo peg_names[] = }; #ifdef DEBUG_MSGS -static const TraceOption kaizen_trace_options[] = +static const TraceOption snort_ml_trace_options[] = { { "classifier", TRACE_CLASSIFIER, "enable Snort ML classifier trace logging" }, { nullptr, 0, nullptr } @@ -72,9 +72,9 @@ static const TraceOption kaizen_trace_options[] = // module //-------------------------------------------------------------------------- -KaizenModule::KaizenModule() : Module(KZ_NAME, KZ_HELP, kaizen_params) {} +SnortMLModule::SnortMLModule() : Module(SNORT_ML_NAME, SNORT_ML_HELP, snort_ml_params) {} -bool KaizenModule::set(const char*, Value& v, SnortConfig*) +bool SnortMLModule::set(const char*, Value& v, SnortConfig*) { static_assert(std::is_same::value, "Field::length maximum value should not exceed uri_depth type range"); @@ -91,7 +91,7 @@ bool KaizenModule::set(const char*, Value& v, SnortConfig*) return true; } -bool KaizenModule::end(const char*, int, snort::SnortConfig*) +bool SnortMLModule::end(const char*, int, snort::SnortConfig*) { if (!conf.uri_depth && !conf.client_body_depth) ParseWarning(WARN_CONF, @@ -100,26 +100,26 @@ bool KaizenModule::end(const char*, int, snort::SnortConfig*) return true; } -const RuleMap* KaizenModule::get_rules() const -{ return kaizen_rules; } +const RuleMap* SnortMLModule::get_rules() const +{ return snort_ml_rules; } -const PegInfo* KaizenModule::get_pegs() const +const PegInfo* SnortMLModule::get_pegs() const { return peg_names; } -PegCount* KaizenModule::get_counts() const -{ return (PegCount*)&kaizen_stats; } +PegCount* SnortMLModule::get_counts() const +{ return (PegCount*)&snort_ml_stats; } -ProfileStats* KaizenModule::get_profile() const -{ return &kaizen_prof; } +ProfileStats* SnortMLModule::get_profile() const +{ return &snort_ml_prof; } -void KaizenModule::set_trace(const Trace* trace) const -{ kaizen_trace = trace; } +void SnortMLModule::set_trace(const Trace* trace) const +{ snort_ml_trace = trace; } -const TraceOption* KaizenModule::get_trace_options() const +const TraceOption* SnortMLModule::get_trace_options() const { #ifndef DEBUG_MSGS return nullptr; #else - return kaizen_trace_options; + return snort_ml_trace_options; #endif } diff --git a/src/network_inspectors/kaizen/kaizen_module.h b/src/network_inspectors/snort_ml/snort_ml_module.h similarity index 74% rename from src/network_inspectors/kaizen/kaizen_module.h rename to src/network_inspectors/snort_ml/snort_ml_module.h index 59624bddf..b842a06ec 100644 --- a/src/network_inspectors/kaizen/kaizen_module.h +++ b/src/network_inspectors/snort_ml/snort_ml_module.h @@ -1,5 +1,5 @@ //-------------------------------------------------------------------------- -// Copyright (C) 2023-2024 Cisco and/or its affiliates. All rights reserved. +// Copyright (C) 2023-2025 Cisco and/or its affiliates. All rights reserved. // // This program is free software; you can redistribute it and/or modify it // under the terms of the GNU General Public License Version 2 as published @@ -15,25 +15,25 @@ // with this program; if not, write to the Free Software Foundation, Inc., // 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. //-------------------------------------------------------------------------- -// kaizen_module.h author Brandon Stultz +// snort_ml_module.h author Brandon Stultz -#ifndef KAIZEN_MODULE_H -#define KAIZEN_MODULE_H +#ifndef SNORT_ML_MODULE_H +#define SNORT_ML_MODULE_H #include "framework/module.h" #include "main/thread.h" #include "profiler/profiler.h" #include "trace/trace_api.h" -#define KZ_GID 411 -#define KZ_SID 1 +#define SNORT_ML_GID 411 +#define SNORT_ML_SID 1 -#define KZ_NAME "snort_ml" -#define KZ_HELP "machine learning based exploit detector" +#define SNORT_ML_NAME "snort_ml" +#define SNORT_ML_HELP "machine learning based exploit detector" enum { TRACE_CLASSIFIER }; -struct KaizenStats +struct SnortMLStats { PegCount uri_alerts; PegCount client_body_alerts; @@ -42,11 +42,11 @@ struct KaizenStats PegCount libml_calls; }; -extern THREAD_LOCAL KaizenStats kaizen_stats; -extern THREAD_LOCAL snort::ProfileStats kaizen_prof; -extern THREAD_LOCAL const snort::Trace* kaizen_trace; +extern THREAD_LOCAL SnortMLStats snort_ml_stats; +extern THREAD_LOCAL snort::ProfileStats snort_ml_prof; +extern THREAD_LOCAL const snort::Trace* snort_ml_trace; -struct KaizenConfig +struct SnortMLConfig { std::string http_param_model_path; double http_param_threshold; @@ -54,19 +54,19 @@ struct KaizenConfig int32_t client_body_depth; }; -class KaizenModule : public snort::Module +class SnortMLModule : public snort::Module { public: - KaizenModule(); + SnortMLModule(); bool set(const char*, snort::Value&, snort::SnortConfig*) override; bool end(const char*, int, snort::SnortConfig*) override; - const KaizenConfig& get_conf() const + const SnortMLConfig& get_conf() const { return conf; } unsigned get_gid() const override - { return KZ_GID; } + { return SNORT_ML_GID; } const snort::RuleMap* get_rules() const override; @@ -82,7 +82,7 @@ public: const snort::TraceOption* get_trace_options() const override; private: - KaizenConfig conf = {}; + SnortMLConfig conf = {}; }; #endif diff --git a/src/utils/util.cc b/src/utils/util.cc index 5175ebb0c..ad6c2d27e 100644 --- a/src/utils/util.cc +++ b/src/utils/util.cc @@ -47,8 +47,10 @@ #ifdef _WIN32 #define ISREG(m) (((m) & _S_IFMT) == _S_IFREG) +#define ISDIR(m) (((m) & _S_IFMT) == _S_IFDIR) #else #define ISREG(m) S_ISREG(m) +#define ISDIR(m) S_ISDIR(m) #endif using namespace snort; @@ -83,6 +85,16 @@ bool get_file_size(const std::string& path, size_t& size) return true; } +bool is_directory_path(const std::string& path) +{ + struct STAT sb; + + if (STAT(path.c_str(), &sb)) + return false; + + return ISDIR(sb.st_mode); +} + namespace snort { const char* get_error(int errnum) diff --git a/src/utils/util.h b/src/utils/util.h index c6916d6c0..9f91be245 100644 --- a/src/utils/util.h +++ b/src/utils/util.h @@ -65,6 +65,7 @@ unsigned int get_random_seed(); bool get_file_size(const std::string&, size_t&); +bool is_directory_path(const std::string&); namespace {