]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4895: appid: add setUserDetectorDataItem lua detector API
authorOleksandr Stepanov -X (ostepano - SOFTSERVE INC at Cisco) <ostepano@cisco.com>
Thu, 18 Sep 2025 19:27:45 +0000 (19:27 +0000)
committerChris Sherwin (chsherwi) <chsherwi@cisco.com>
Thu, 18 Sep 2025 19:27:45 +0000 (19:27 +0000)
Merge in SNORT/snort3 from ~OSTEPANO/snort3:user_data_lua to master

Squashed commit of the following:

commit 37c1d2245679348f43b571307d9bb50a4ae96e91
Author: Oleksandr Stepanov <ostepano@cisco.com>
Date:   Thu Sep 4 10:34:36 2025 -0400

    appid: add setUserDetectorDataItem lua detector API

src/network_inspectors/appid/app_info_table.cc
src/network_inspectors/appid/lua_detector_api.cc
src/network_inspectors/appid/lua_detector_module.cc
src/network_inspectors/appid/test/CMakeLists.txt
src/network_inspectors/appid/test/user_data_map_test.cc [new file with mode: 0644]
src/network_inspectors/appid/user_data_map.cc
src/network_inspectors/appid/user_data_map.h

index 10e83554b844f9cc79aa396e8772268d485d1f70..15977471a9b41242367e71258325ee0a31a7e801 100644 (file)
@@ -721,7 +721,7 @@ void AppInfoManager::load_odp_config(OdpContext& odp_ctxt, const char* path)
                 const std::string user_table(conf_val);
                 const std::string user_key(token);
                 const std::string user_value(token2);
-                odp_ctxt.get_user_data_map().add_user_data(user_table, user_key, user_value);
+                odp_ctxt.get_user_data_map().add_user_data(user_table, user_key, user_value, false);
             }
             else
                 ParseWarning(WARN_CONF, "appid: unsupported configuration: %s\n", conf_key);
index bc28f9302a0feccf16d27f1c3e9b37e18215fb2f..017631e349511c15f3368e5d62f76d1edbf23eb5 100644 (file)
@@ -3344,6 +3344,37 @@ static int get_user_detector_data_item(lua_State *L)
     return 1;
 }
 
+static int set_user_detector_data_item(lua_State *L)
+{
+    auto& ud = *UserData<LuaObject>::check(L, DETECTOR, 1);
+    const char* table = lua_tostring(L, 2);
+    if (!table)
+    {
+        APPID_LOG(nullptr, TRACE_ERROR_LEVEL, "appid: Invalid detector data table string in %s.\n",
+            ud->get_detector()->get_name().c_str());
+        return 0;
+    }
+    const char* key = lua_tostring(L, 3);
+    if (!key)
+    {
+        APPID_LOG(nullptr, TRACE_ERROR_LEVEL, "appid: Invalid detector data key string in %s.\n",
+            ud->get_detector()->get_name().c_str());
+        return 0;
+    }
+
+    const char* item = lua_tostring(L, 4);
+    if (!item)
+    {
+        APPID_LOG(nullptr, TRACE_ERROR_LEVEL, "appid: Invalid detector data item string in %s.\n",
+            ud->get_detector()->get_name().c_str());
+        return 0;
+    }
+
+    int result = ud->get_odp_ctxt().get_user_data_map().add_user_data(table, key, item, true) ? 1 : 0;
+
+    return result;
+}
+
 static const luaL_Reg detector_methods[] =
 {
     /* Obsolete API names.  No longer use these!  They are here for backward
@@ -3466,6 +3497,7 @@ static const luaL_Reg detector_methods[] =
     { "getHttpTunneledPort",      get_http_tunneled_port },
 
     { "getUserDetectorDataItem",   get_user_detector_data_item },
+    { "setUserDetectorDataItem",   set_user_detector_data_item },
 
      /* CIP registration */
     {"addCipConnectionClass",    detector_add_cip_connection_class},
index 6f3f5c4ecd96c7f50cd9efa935f31474940ea3d5..8d891139217aaeb3dc6aacf50100f3fcb59e52c7 100644 (file)
@@ -227,6 +227,7 @@ LuaDetectorManager::~LuaDetectorManager()
 
 void LuaDetectorManager::initialize(const SnortConfig* sc)
 {
+    ctxt.get_odp_ctxt().get_user_data_map().set_configuration_completed(false);
     activate_lua_detectors(sc);
     
     if (SnortConfig::log_verbose())
@@ -515,6 +516,7 @@ void LuaDetectorManager::activate_lua_detectors(const SnortConfig* sc)
         lua_settop(L, 0);
         ++lo;
     }
+    ctxt.get_odp_ctxt().get_user_data_map().set_configuration_completed(true);
 }
 void ControlLuaDetectorManager::process_detector_file(char* detector_file_path, bool is_custom)
 {
index aac3507935bd11cab122dbc64aa8f75def248f38..77dbe7495fa8e3c21dfd7c8bef242c0ac06389d8 100644 (file)
@@ -48,6 +48,10 @@ add_cpputest( appid_eve_process_event_handler_test
     SOURCES $<TARGET_OBJECTS:appid_cpputest_deps>
 )
 
+add_cpputest( user_data_map_test
+    SOURCES ../user_data_map.cc
+)
+
 add_cpputest( tp_lib_handler_test
     SOURCES
         tp_lib_handler_test.cc
diff --git a/src/network_inspectors/appid/test/user_data_map_test.cc b/src/network_inspectors/appid/test/user_data_map_test.cc
new file mode 100644 (file)
index 0000000..125d8fe
--- /dev/null
@@ -0,0 +1,169 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2018-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+
+// user_data_map_test.cc author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "main/thread.h"
+
+#include "../user_data_map.h"
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+
+static uint appid_log_call_count = 0;
+void appid_log(const snort::Packet*, unsigned char, char const*, ...) { appid_log_call_count++; }
+
+SThreadType test_thread_type = SThreadType::STHREAD_TYPE_MAIN;
+
+namespace snort
+{
+    SThreadType get_thread_type()
+    {
+        return test_thread_type;
+    }
+}
+
+TEST_GROUP(user_data_map_test)
+{
+    UserDataMap* test_user_data_map;
+
+    void setup() override
+    {
+        test_user_data_map = new UserDataMap();
+        test_user_data_map->set_configuration_completed(false);
+    }
+
+    void teardown() override
+    {
+        test_user_data_map->set_configuration_completed(false);
+        delete test_user_data_map;
+    }
+};
+
+TEST(user_data_map_test, add_and_get_user_data)
+{
+    const std::string table = "test_table";
+    const std::string key = "test_key";
+    const std::string value = "test_value";
+
+    bool added = test_user_data_map->add_user_data(table, key, value);
+    CHECK_TRUE(added);
+
+    const char* retrieved_value = test_user_data_map->get_user_data_value_str(table, key);
+    STRCMP_EQUAL(value.c_str(), retrieved_value);
+}
+
+TEST(user_data_map_test, add_duplicate_key_without_override)
+{
+    appid_log_call_count = 0;
+    const std::string table = "test_table";
+    const std::string key = "test_key";
+    const std::string value1 = "test_value1";
+    const std::string value2 = "test_value2";
+
+    bool first_add = test_user_data_map->add_user_data(table, key, value1);
+    CHECK_TRUE(first_add);
+
+    bool second_add = test_user_data_map->add_user_data(table, key, value2);
+    CHECK_EQUAL(1, appid_log_call_count);
+    CHECK_FALSE(second_add);
+
+    auto get_value = test_user_data_map->get_user_data_value_str(table, key);
+    STRCMP_EQUAL(get_value, value1.c_str());
+}
+
+TEST(user_data_map_test, add_duplicate_key_with_override)
+{
+    const std::string table = "test_table";
+    const std::string key = "test_key";
+    const std::string value1 = "test_value1";
+    const std::string value2 = "test_value2";
+
+    bool first_add = test_user_data_map->add_user_data(table, key, value1);
+    CHECK_TRUE(first_add);
+
+    bool second_add = test_user_data_map->add_user_data(table, key, value2, true);
+    const char* retrieved_value = test_user_data_map->get_user_data_value_str(table, key);
+    CHECK_TRUE(second_add);
+    STRCMP_EQUAL(value2.c_str(), retrieved_value);
+}
+
+TEST(user_data_map_test, get_nonexistent_key)
+{
+    const std::string table = "test_table";
+    const std::string key = "test_key";
+    const std::string value = "test_value";
+
+    test_user_data_map->add_user_data(table, key, value);
+
+    const char* retrieved_value = test_user_data_map->get_user_data_value_str(table, "some_other_key");
+    CHECK_TRUE(retrieved_value == nullptr);
+}
+
+TEST(user_data_map_test, get_from_nonexistent_table)
+{
+    const std::string table = "nonexistent_table";
+    const std::string key = "test_key";
+
+    const char* retrieved_value = test_user_data_map->get_user_data_value_str(table, key);
+    CHECK_TRUE(retrieved_value == nullptr);
+}
+
+TEST(user_data_map_test, add_user_data_from_non_main_thread_before_configuration_completed)
+{
+    test_thread_type = SThreadType::STHREAD_TYPE_PACKET;
+    appid_log_call_count = 0;
+
+    const std::string table = "test_table";
+    const std::string key = "test_key";
+    const std::string value = "test_value";
+
+    bool added = test_user_data_map->add_user_data(table, key, value);
+    CHECK_FALSE(added);
+    CHECK_EQUAL(0, appid_log_call_count);
+
+    test_thread_type = SThreadType::STHREAD_TYPE_MAIN;
+}
+
+TEST(user_data_map_test, add_user_data_from_non_main_thread_after_configuration_completed)
+{
+    test_thread_type = SThreadType::STHREAD_TYPE_PACKET;
+    appid_log_call_count = 0;
+
+    const std::string table = "test_table";
+    const std::string key = "test_key";
+    const std::string value = "test_value";
+
+    test_user_data_map->set_configuration_completed(true);
+    bool added = test_user_data_map->add_user_data(table, key, value);
+    CHECK_FALSE(added);
+    CHECK_EQUAL(1, appid_log_call_count);
+
+    test_thread_type = SThreadType::STHREAD_TYPE_MAIN;
+}
+
+int main(int argc, char** argv)
+{
+    int rc = CommandLineTestRunner::RunAllTests(argc, argv);
+
+    return rc;
+}
index 0ef923446ac8b5bfd7258d40eab4b6271ac14256..c185ff862546287a5c0b3d3e5a03d327c699b819 100644 (file)
 
 #include "user_data_map.h"
 
+#include "main/thread.h"
+
+static THREAD_LOCAL bool configuration_completed;
+
 UserDataMap::~UserDataMap()
 {
     user_data_maps.clear();
 }
 
-void UserDataMap::add_user_data(const std::string& table, const std::string& key,
-    const std::string& item)
+bool UserDataMap::add_user_data(const std::string &table, const std::string &key,
+                                const std::string &item, bool override_existing)
 {
-    if (user_data_maps.find(table) != user_data_maps.end())
+
+    if (snort::get_thread_type() != SThreadType::STHREAD_TYPE_MAIN)
     {
-        if (user_data_maps[table].find(key) != user_data_maps[table].end())
-        {
-            APPID_LOG(nullptr, TRACE_WARNING_LEVEL,"ignoring duplicate key %s in table %s",
+        if (configuration_completed)
+            APPID_LOG(nullptr, TRACE_WARNING_LEVEL, "AppId: ignoring user data with key %s in table %s from non-main thread\n",
                 key.c_str(), table.c_str());
-            return;
+        return false;
+    }
+
+    auto table_it = user_data_maps.find(table);
+    if (table_it != user_data_maps.end())
+    {
+        if (override_existing)
+        {
+            table_it->second[key] = item;
+        }
+        else
+        {
+            auto insert_result = table_it->second.try_emplace(key, item);
+            if (insert_result.second == false)
+            {
+                APPID_LOG(nullptr, TRACE_WARNING_LEVEL, "AppId: ignoring duplicate key %s in table %s\n",
+                    key.c_str(), table.c_str());
+                return false;
+            }
         }
-        user_data_maps[table][key] = item;
     }
     else
     {
@@ -48,16 +69,27 @@ void UserDataMap::add_user_data(const std::string& table, const std::string& key
         user_map[key] = item;
         user_data_maps[table] = user_map;
     }
+
+    return true;
 }
 
 const char* UserDataMap::get_user_data_value_str(const std::string& table,
     const std::string& key)
 {
-    if (user_data_maps.find(table) != user_data_maps.end() and
-        user_data_maps[table].find(key) != user_data_maps[table].end())
+    auto table_it = user_data_maps.find(table);
+    if (table_it != user_data_maps.end())
     {
-        return user_data_maps[table][key].c_str();
+        auto key_it = table_it->second.find(key);
+        if (key_it != table_it->second.end())
+        {
+            return key_it->second.c_str();
+        }
     }
-    else
-        return nullptr;
+    
+    return nullptr;
+}
+
+void UserDataMap::set_configuration_completed(bool completed)
+{
+    configuration_completed = completed;
 }
index 7f4ef20fa3dddfd46a51e1206fe9bfdfc6d6110d..d2b2074622d02240c8224b5f1beb8ca5e37bb3a0 100644 (file)
@@ -23,9 +23,9 @@
 
 /* User Data Map uses an unordered map to store arbitrary user-defined key value pairs
  * used in lua detectors. Mappings are loaded from appid.conf or userappid.conf using a
- * key that is hardcoded in the detector. The user supplies the value. At runtime, if the lua
- * detector's conditions are met during validation, the lua detector can use its key to
- * retrieve the customer data.
+ * key that is hardcoded in the detector or loaded from lua detectors that utilize setUserDetectorDataItem API.
+ * The user supplies the value. At runtime, if the lua detector's conditions are met during validation,
+ * the lua detector can use its key to retrieve the customer data.
  */
 
 #include <string>
@@ -42,9 +42,11 @@ class UserDataMap
 {
 public:
     ~UserDataMap();
-    void add_user_data(const std::string& table, const std::string& key,
-        const std::string& item);
+    bool add_user_data(const std::string& table, const std::string& key,
+        const std::string& item, bool override_existing = false);
     const char* get_user_data_value_str(const std::string& table, const std::string& key);
+
+    void set_configuration_completed(bool completed);
 private:
     UserDataMaps user_data_maps;
 };