]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add task:get_cta_urls() API for proper CTA domain extraction
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 6 Nov 2025 10:54:58 +0000 (10:54 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 6 Nov 2025 10:54:58 +0000 (10:54 +0000)
- C code (message.c): collect top CTA URLs per HTML part by button weight,
  store in task mempool variable "html_cta_urls"
- Lua API (lua_task.c): add task:get_cta_urls([max_urls]) method
- llm_search_context: use new API instead of reimplementing CTA logic in Lua
- Benefits: single source of truth for CTA logic, uses C knowledge of HTML
  structure and button weights, cleaner Lua code

This provides proper architecture where C code handles HTML structure analysis
and Lua adds domain filtering (blacklists, infrastructure domains, etc.)

lualib/llm_search_context.lua
src/libmime/message.c
src/lua/lua_task.c

index 12c2cd364ae571843eb7a5de45f9c509f2fc2150..e75199b4fee79fcef63c28efa649a95163a04419 100644 (file)
@@ -61,36 +61,84 @@ local DEFAULTS = {
   disable_expression = nil,
 }
 
--- Extract unique domains from task URLs
-local function extract_domains(task, max_domains)
+-- Extract unique domains from task URLs, prioritizing CTA (call-to-action) links
+local function extract_domains(task, max_domains, debug_module)
+  local Np = debug_module or N
   local domains = {}
   local seen = {}
 
-  -- Get URLs from the task using extract_specific_urls
-  local urls = lua_util.extract_specific_urls({
-    task = task,
-    limit = max_domains * 3, -- Get more to filter
-    esld_limit = max_domains,
-  }) or {}
+  -- Skip common domains that won't provide useful context
+  local skip_domains = {
+    ['localhost'] = true,
+    ['127.0.0.1'] = true,
+    ['example.com'] = true,
+    ['example.org'] = true,
+  }
 
-  for _, url in ipairs(urls) do
-    if #domains >= max_domains then
-      break
-    end
+  -- First, try to get CTA URLs from HTML (most relevant for spam detection)
+  -- Uses button weight and HTML structure analysis from C code
+  local cta_urls = task:get_cta_urls(max_domains * 2) or {}
+  lua_util.debugm(Np, task, "CTA analysis found %d URLs", #cta_urls)
+
+  for _, url in ipairs(cta_urls) do
+    if #domains >= max_domains then break end
 
     local host = url:get_host()
-    if host and not seen[host] then
-      -- Skip common domains that won't provide useful context
-      local skip_domains = {
-        ['localhost'] = true,
-        ['127.0.0.1'] = true,
-        ['example.com'] = true,
-        ['example.org'] = true,
-      }
-
-      if not skip_domains[host:lower()] then
+    if host and not skip_domains[host:lower()] and not seen[host] then
+      seen[host] = true
+      table.insert(domains, host)
+      lua_util.debugm(Np, task, "added CTA domain: %s", host)
+    end
+  end
+
+  -- If we don't have enough domains from CTA, get more from content URLs
+  if #domains < max_domains then
+    lua_util.debugm(Np, task, "need more domains (%d/%d), extracting from content URLs",
+      #domains, max_domains)
+
+    local urls = lua_util.extract_specific_urls({
+      task = task,
+      limit = max_domains * 3,
+      esld_limit = max_domains,
+      need_content = true,      -- Content URLs (buttons, links in text)
+      need_images = false,
+    }) or {}
+
+    lua_util.debugm(Np, task, "extracted %d content URLs", #urls)
+
+    for _, url in ipairs(urls) do
+      if #domains >= max_domains then break end
+
+      local host = url:get_host()
+      if host and not seen[host] and not skip_domains[host:lower()] then
+        seen[host] = true
+        table.insert(domains, host)
+        lua_util.debugm(Np, task, "added content domain: %s", host)
+      end
+    end
+  end
+
+  -- Still need more? Get from any URLs
+  if #domains < max_domains then
+    lua_util.debugm(Np, task, "still need more domains (%d/%d), extracting from all URLs",
+      #domains, max_domains)
+
+    local urls = lua_util.extract_specific_urls({
+      task = task,
+      limit = max_domains * 3,
+      esld_limit = max_domains,
+    }) or {}
+
+    lua_util.debugm(Np, task, "extracted %d all URLs", #urls)
+
+    for _, url in ipairs(urls) do
+      if #domains >= max_domains then break end
+
+      local host = url:get_host()
+      if host and not seen[host] and not skip_domains[host:lower()] then
         seen[host] = true
         table.insert(domains, host)
+        lua_util.debugm(Np, task, "added general domain: %s", host)
       end
     end
   end
@@ -285,7 +333,7 @@ function M.fetch_and_format(task, redis_params, opts, callback, debug_module)
   end
 
   -- Extract domains from task
-  local domains = extract_domains(task, opts.max_domains)
+  local domains = extract_domains(task, opts.max_domains, Np)
 
   if #domains == 0 then
     lua_util.debugm(Np, task, "no domains to search")
@@ -293,7 +341,7 @@ function M.fetch_and_format(task, redis_params, opts, callback, debug_module)
     return
   end
 
-  lua_util.debugm(Np, task, "extracted %s domain(s) for search: %s",
+  lua_util.debugm(Np, task, "final domain list (%d domains) for search: %s",
     #domains, table.concat(domains, ", "))
 
   -- Create cache context
index 9eb9df3309506a3bd40be1dfab6818f1b0c314ec..84ea13711371aaea38023248c6b433ccc0fda6a3 100644 (file)
@@ -944,6 +944,55 @@ rspamd_message_process_html_text_part(struct rspamd_task *task,
 
                        lua_settop(L, old_top);
                }
+
+               /* Store top CTA URLs for LLM and other use cases */
+               if (text_part->html && text_part->mime_part && text_part->mime_part->urls) {
+                       /* Simple approach: just store URLs sorted by button weight */
+                       /* Use task-wide array to aggregate across all HTML parts */
+                       GPtrArray *cta_urls = rspamd_mempool_get_variable(task->task_pool, "html_cta_urls");
+                       if (!cta_urls) {
+                               cta_urls = g_ptr_array_new();
+                               rspamd_mempool_add_destructor(task->task_pool,
+                                                                                         (rspamd_mempool_destruct_t) rspamd_ptr_array_free_hard,
+                                                                                         cta_urls);
+                               rspamd_mempool_set_variable(task->task_pool, "html_cta_urls", cta_urls, NULL);
+                       }
+
+                       /* Find best URLs by button weight in this HTML part */
+                       float best_weights[5] = {0.0, 0.0, 0.0, 0.0, 0.0};
+                       struct rspamd_url *best_urls[5] = {NULL, NULL, NULL, NULL, NULL};
+                       unsigned int max_cta_per_part = 5;
+
+                       for (unsigned int i = 0; i < text_part->mime_part->urls->len; i++) {
+                               struct rspamd_url *u = g_ptr_array_index(text_part->mime_part->urls, i);
+                               if (!u) continue;
+                               if (!(u->protocol == PROTOCOL_HTTP || u->protocol == PROTOCOL_HTTPS)) continue;
+                               if (u->flags & RSPAMD_URL_FLAG_INVISIBLE) continue;
+
+                               float weight = rspamd_html_url_button_weight(text_part->html, u);
+
+                               /* Insert into best list if weight is high enough */
+                               for (unsigned int j = 0; j < max_cta_per_part; j++) {
+                                       if (weight > best_weights[j]) {
+                                               /* Shift lower entries down */
+                                               for (unsigned int k = max_cta_per_part - 1; k > j; k--) {
+                                                       best_weights[k] = best_weights[k - 1];
+                                                       best_urls[k] = best_urls[k - 1];
+                                               }
+                                               best_weights[j] = weight;
+                                               best_urls[j] = u;
+                                               break;
+                                       }
+                               }
+                       }
+
+                       /* Add to task-wide array */
+                       for (unsigned int i = 0; i < max_cta_per_part; i++) {
+                               if (best_urls[i] && best_weights[i] > 0.0) {
+                                       g_ptr_array_add(cta_urls, best_urls[i]);
+                               }
+                       }
+               }
        }
        rspamd_html_get_parsed_content(text_part->html, &text_part->utf_content);
 
index e10c7e089ba6c8a2c696f6d8707e8ff658da74e5..b7e42530c3e1198b21a8cc2524456df2547664a8 100644 (file)
@@ -278,6 +278,14 @@ LUA_FUNCTION_DEF(task, get_urls);
  * @return {table rspamd_url} list of urls matching conditions
  */
 LUA_FUNCTION_DEF(task, get_urls_filtered);
+/***
+ * @method task:get_cta_urls([max_urls])
+ * Get call-to-action URLs from HTML content, prioritized by button weight
+ * These are URLs that users are likely to click (buttons, prominent links, etc.)
+ * @param {number} max_urls maximum number of URLs to return (default: all)
+ * @return {table rspamd_url} list of CTA urls sorted by importance
+ */
+LUA_FUNCTION_DEF(task, get_cta_urls);
 /***
  * @method task:has_urls([need_emails])
  * Returns 'true' if a task has urls listed
@@ -1325,6 +1333,7 @@ static const struct luaL_reg tasklib_m[] = {
        LUA_INTERFACE_DEF(task, has_urls),
        LUA_INTERFACE_DEF(task, get_urls),
        LUA_INTERFACE_DEF(task, get_urls_filtered),
+       LUA_INTERFACE_DEF(task, get_cta_urls),
        LUA_INTERFACE_DEF(task, inject_url),
        LUA_INTERFACE_DEF(task, get_content),
        LUA_INTERFACE_DEF(task, get_filename),
@@ -2724,6 +2733,61 @@ lua_task_get_urls_filtered(lua_State *L)
        return 1;
 }
 
+static int
+lua_task_get_cta_urls(lua_State *L)
+{
+       LUA_TRACE_POINT;
+       struct rspamd_task *task = lua_check_task(L, 1);
+       GPtrArray *cta_urls;
+       unsigned int max_urls = 0;
+       unsigned int nret = 0;
+
+       if (task == NULL) {
+               return luaL_error(L, "invalid arguments, no task");
+       }
+
+       if (task->message == NULL) {
+               lua_newtable(L);
+               return 1;
+       }
+
+       /* Get optional max_urls parameter */
+       if (lua_gettop(L) >= 2 && lua_isnumber(L, 2)) {
+               max_urls = lua_tointeger(L, 2);
+       }
+
+       /* Retrieve CTA URLs from mempool */
+       cta_urls = rspamd_mempool_get_variable(task->task_pool, "html_cta_urls");
+
+       if (cta_urls == NULL || cta_urls->len == 0) {
+               lua_newtable(L);
+               return 1;
+       }
+
+       /* Create result table */
+       unsigned int result_size = max_urls > 0 ? MIN(max_urls, cta_urls->len) : cta_urls->len;
+       lua_createtable(L, result_size, 0);
+
+       /* Add URLs to result */
+       for (unsigned int i = 0; i < cta_urls->len; i++) {
+               struct rspamd_url *u = g_ptr_array_index(cta_urls, i);
+               if (u) {
+                       struct rspamd_lua_url *lua_url;
+
+                       lua_url = lua_newuserdata(L, sizeof(struct rspamd_lua_url));
+                       rspamd_lua_setclass(L, rspamd_url_classname, -1);
+                       lua_url->url = u;
+                       lua_rawseti(L, -2, ++nret);
+
+                       if (max_urls > 0 && nret >= max_urls) {
+                               break;
+                       }
+               }
+       }
+
+       return 1;
+}
+
 static int
 lua_task_has_urls(lua_State *L)
 {