]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Use luaL_ref for URL rewriter callback instead of global function name
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 11 Oct 2025 14:18:02 +0000 (15:18 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 11 Oct 2025 14:18:02 +0000 (15:18 +0100)
src/libserver/html/html_url_rewrite.cxx
src/libserver/html/html_url_rewrite.hxx
src/libserver/html/html_url_rewrite_c.cxx
src/libserver/html/html_url_rewrite_c.h
src/lua/lua_task.c

index b958ceeeac79e66f7f16c5d1b40f2564b1c42c54..387490206eb5e1693c538601828abfd1b6a34cea 100644 (file)
@@ -36,19 +36,15 @@ namespace rspamd::html {
 /**
  * Call Lua url_rewriter function to get replacement URL
  * @param task Rspamd task
- * @param func_name Lua function name (e.g., "url_rewriter")
+ * @param L Lua state
+ * @param func_ref Lua function reference from luaL_ref
  * @param url Original URL string
  * @return Replacement URL or empty optional if no replacement
  */
-static auto call_lua_url_rewriter(struct rspamd_task *task, const char *func_name, const std::string &url)
+static auto call_lua_url_rewriter(struct rspamd_task *task, ::lua_State *L, int func_ref, const std::string &url)
        -> std::optional<std::string>
 {
-       if (!func_name || !task || !task->cfg) {
-               return std::nullopt;
-       }
-
-       auto *L = RSPAMD_LUA_CFG_STATE(task->cfg);
-       if (!L) {
+       if (!L || func_ref == LUA_NOREF || func_ref == LUA_REFNIL || !task) {
                return std::nullopt;
        }
 
@@ -56,9 +52,11 @@ static auto call_lua_url_rewriter(struct rspamd_task *task, const char *func_nam
        lua_pushcfunction(L, &rspamd_lua_traceback);
        auto err_idx = lua_gettop(L);
 
-       // Get the function
-       if (!rspamd_lua_require_function(L, func_name, nullptr)) {
-               msg_debug_html_rewrite("cannot require function %s", func_name);
+       // Get the function from registry
+       lua_rawgeti(L, LUA_REGISTRYINDEX, func_ref);
+
+       if (!lua_isfunction(L, -1)) {
+               msg_debug_html_rewrite("func_ref is not a function");
                lua_settop(L, err_idx - 1);
                return std::nullopt;
        }
@@ -73,7 +71,7 @@ static auto call_lua_url_rewriter(struct rspamd_task *task, const char *func_nam
 
        // Call function with 2 args, 1 result
        if (lua_pcall(L, 2, 1, err_idx) != 0) {
-               msg_warn_task("call to %s failed: %s", func_name, lua_tostring(L, -1));
+               msg_warn_task("call to url rewriter failed: %s", lua_tostring(L, -1));
                lua_settop(L, err_idx - 1);
                return std::nullopt;
        }
@@ -89,7 +87,7 @@ static auto call_lua_url_rewriter(struct rspamd_task *task, const char *func_nam
                }
        }
        else if (!lua_isnil(L, -1)) {
-               msg_warn_task("%s returned non-string value", func_name);
+               msg_warn_task("url rewriter returned non-string value");
        }
 
        lua_settop(L, err_idx - 1);
@@ -206,13 +204,14 @@ auto apply_patches(std::string_view original, const std::vector<rewrite_patch> &
 }
 
 auto process_html_url_rewrite(struct rspamd_task *task,
+                                                         ::lua_State *L,
                                                          const html_content *hc,
-                                                         const char *func_name,
+                                                         int func_ref,
                                                          int part_id,
                                                          std::string_view original_html)
        -> std::optional<std::string>
 {
-       if (!task || !hc || !func_name) {
+       if (!task || !hc || !L || func_ref == LUA_NOREF || func_ref == LUA_REFNIL) {
                return std::nullopt;
        }
 
@@ -231,7 +230,7 @@ auto process_html_url_rewrite(struct rspamd_task *task,
 
        for (const auto &candidate: candidates) {
                // Call Lua callback
-               auto replacement = call_lua_url_rewriter(task, func_name, candidate.absolute_url);
+               auto replacement = call_lua_url_rewriter(task, L, func_ref, candidate.absolute_url);
                if (!replacement) {
                        continue;// Skip if Lua returned nil
                }
index 970f492bec23c8896efd5da7677396b4fa219255..7de79acc26161753550887813268c7a4cc173fa6 100644 (file)
@@ -25,6 +25,7 @@
 #include <optional>
 
 struct rspamd_task;
+struct lua_State;
 
 namespace rspamd::html {
 
@@ -94,15 +95,17 @@ auto apply_patches(std::string_view original, const std::vector<rewrite_patch> &
  * Process HTML URL rewriting for a task
  * Enumerates candidates, calls Lua callback, applies patches, and returns rewritten HTML
  * @param task Rspamd task
+ * @param L Lua state
  * @param hc HTML content
- * @param func_name Lua function name for URL rewriting
+ * @param func_ref Lua function reference from luaL_ref
  * @param part_id MIME part ID
  * @param original_html Original HTML content (decoded)
  * @return Rewritten HTML or nullopt if no changes
  */
 auto process_html_url_rewrite(struct rspamd_task *task,
+                                                         ::lua_State *L,
                                                          const html_content *hc,
-                                                         const char *func_name,
+                                                         int func_ref,
                                                          int part_id,
                                                          std::string_view original_html)
        -> std::optional<std::string>;
index 1a3606f59bbb0bfdf35c96564ecee8f25f35ae29..f3e66672fe8ce4036b01b62cd469e5e5f5ede0a0 100644 (file)
 extern "C" {
 
 int rspamd_html_url_rewrite(struct rspamd_task *task,
+                                                       struct lua_State *L,
                                                        void *html_content,
-                                                       const char *func_name,
+                                                       int func_ref,
                                                        int part_id,
                                                        const char *original_html,
                                                        gsize html_len,
                                                        char **output_html,
                                                        gsize *output_len)
 {
-       if (!task || !html_content || !func_name || !original_html) {
+       if (!task || !L || !html_content || !original_html) {
                return -1;
        }
 
@@ -38,7 +39,7 @@ int rspamd_html_url_rewrite(struct rspamd_task *task,
        std::string_view original{original_html, html_len};
 
        auto result = rspamd::html::process_html_url_rewrite(
-               task, hc, func_name, part_id, original);
+               task, L, hc, func_ref, part_id, original);
 
        if (!result) {
                return -1;
index 371b50346a33e3d1401f34d17e877eadeeedcbbd..798c8b39872f2f43517daa26e8c615a4051d49a1 100644 (file)
@@ -25,11 +25,14 @@ extern "C" {
 
 struct rspamd_task;
 
+struct lua_State;
+
 /**
  * C wrapper for HTML URL rewriting
  * @param task Rspamd task
+ * @param L Lua state
  * @param html_content HTML content pointer (void* cast of html_content*)
- * @param func_name Lua function name for rewriting
+ * @param func_ref Lua function reference (from luaL_ref)
  * @param part_id MIME part ID
  * @param original_html Original HTML content
  * @param html_len Length of original HTML
@@ -38,8 +41,9 @@ struct rspamd_task;
  * @return 0 on success, -1 on error/no rewrite
  */
 int rspamd_html_url_rewrite(struct rspamd_task *task,
+                                                       struct lua_State *L,
                                                        void *html_content,
-                                                       const char *func_name,
+                                                       int func_ref,
                                                        int part_id,
                                                        const char *original_html,
                                                        gsize html_len,
index 23b890f9237fb8e8180d4a0ea2482133ec8ddefb..ca72dc7f702001dc1928ec90717078bbdf2ca956 100644 (file)
@@ -1277,10 +1277,10 @@ LUA_FUNCTION_DEF(task, get_dns_req);
 LUA_FUNCTION_DEF(task, add_timer);
 
 /***
- * @method task:rewrite_html_urls(func_name)
+ * @method task:rewrite_html_urls(callback)
  * Rewrites URLs in HTML parts using the specified Lua callback function.
  * The callback receives (task, url) and should return the replacement URL or nil.
- * @param {string} func_name name of Lua function to call for each URL
+ * @param {function} callback Lua function to call for each URL
  * @return {table|nil} table of rewritten HTML parts indexed by part number, or nil on error
  */
 LUA_FUNCTION_DEF(task, rewrite_html_urls);
@@ -7798,10 +7798,9 @@ static int
 lua_task_rewrite_html_urls(lua_State *L)
 {
        struct rspamd_task *task = lua_check_task(L, 1);
-       const char *func_name = luaL_checkstring(L, 2);
 
-       if (!func_name) {
-               return luaL_error(L, "invalid arguments: function name expected");
+       if (!lua_isfunction(L, 2)) {
+               return luaL_error(L, "invalid arguments: function expected");
        }
 
        if (!task || !MESSAGE_FIELD_CHECK(task, text_parts)) {
@@ -7809,6 +7808,10 @@ lua_task_rewrite_html_urls(lua_State *L)
                return 1;
        }
 
+       /* Create function reference */
+       lua_pushvalue(L, 2);
+       int func_ref = luaL_ref(L, LUA_REGISTRYINDEX);
+
        /* Create result table */
        lua_newtable(L);
        int results = 0;
@@ -7831,8 +7834,9 @@ lua_task_rewrite_html_urls(lua_State *L)
                /* Process URL rewriting using C wrapper */
                int ret = rspamd_html_url_rewrite(
                        task,
+                       L,
                        text_part->html,
-                       func_name,
+                       func_ref,
                        text_part->mime_part->part_number,
                        (const char *) text_part->parsed.begin,
                        text_part->parsed.len,
@@ -7855,6 +7859,9 @@ lua_task_rewrite_html_urls(lua_State *L)
                }
        }
 
+       /* Unreference the function */
+       luaL_unref(L, LUA_REGISTRYINDEX, func_ref);
+
        if (results == 0) {
                lua_pop(L, 1);
                lua_pushnil(L);