]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Refactor llm_search_context to use lua_cache module
authorVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 5 Nov 2025 14:25:19 +0000 (14:25 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 5 Nov 2025 14:25:19 +0000 (14:25 +0000)
- Replace manual Redis operations with lua_cache API for better consistency
- Use messagepack serialization and automatic key hashing
- Fix Leta Mullvad API URL to /search/__data.json endpoint
- Add search_engine parameter support
- Remove redundant 'or DEFAULTS.xxx' patterns (opts already has defaults merged)
- Add proper debug_module propagation throughout call chain
- Improve JSON parsing to handle Leta Mullvad's nested pointer structure

conf/modules.d/gpt.conf
lualib/llm_search_context.lua
src/plugins/lua/gpt.lua

index 3e3d5ea6c2f5a7e5ae6b9fa5ee2dd95d45fc393f..ce1ae9e6485cda645af0ee97509a032fce0d3fe2 100644 (file)
@@ -59,7 +59,8 @@ gpt {
   # Extracts domains from email URLs and queries a search API for context
   search_context = {
     enabled = false; # Enable web search context
-    #search_url = "https://leta.mullvad.net/api/search"; # Search API endpoint
+    #search_url = "https://leta.mullvad.net/search/__data.json"; # Search API endpoint
+    #search_engine = "brave"; # Search engine (brave, google, etc.)
     #max_domains = 3; # Maximum domains to search
     #max_results_per_query = 3; # Maximum results per domain
     #timeout = 5; # HTTP timeout in seconds
index 036db6c3873abba4cfe603fbb11193db833e0d29..12c2cd364ae571843eb7a5de45f9c509f2fc2150 100644 (file)
@@ -43,14 +43,14 @@ local M = {}
 
 local rspamd_http = require "rspamd_http"
 local rspamd_logger = require "rspamd_logger"
-local rspamd_util = require "rspamd_util"
 local lua_util = require "lua_util"
-local lua_redis = require "lua_redis"
+local lua_cache = require "lua_cache"
 local ucl = require "ucl"
 
 local DEFAULTS = {
   enabled = false,
-  search_url = "https://leta.mullvad.net/api/search",
+  search_url = "https://leta.mullvad.net/search/__data.json",
+  search_engine = "brave", -- Search engine to use (brave, google, etc.)
   max_domains = 3,
   max_results_per_query = 3,
   timeout = 5,
@@ -98,24 +98,14 @@ local function extract_domains(task, max_domains)
   return domains
 end
 
--- Generate cache key for a domain
-local function get_cache_key(domain, opts)
-  local key_prefix = opts.cache_key_prefix or DEFAULTS.cache_key_prefix
-  local hash = rspamd_util.hash_create()
-  hash:update(domain)
-  return string.format("%s:%s", key_prefix, hash:hex())
-end
-
 -- Query search API for a single domain
-local function query_search_api(task, domain, opts, callback)
-  local url = opts.search_url or DEFAULTS.search_url
-  local timeout = opts.timeout or DEFAULTS.timeout
-  local max_results = opts.max_results_per_query or DEFAULTS.max_results_per_query
+local function query_search_api(task, domain, opts, callback, debug_module)
+  local Np = debug_module or N
 
-  -- Prepare search query
+  -- Prepare search query for Leta Mullvad API
   local query_params = {
     q = domain,
-    limit = tostring(max_results),
+    engine = opts.search_engine,
   }
 
   -- Build query string
@@ -124,43 +114,117 @@ local function query_search_api(task, domain, opts, callback)
     if query_string ~= "" then
       query_string = query_string .. "&"
     end
-    query_string = query_string .. k .. "=" .. rspamd_util.url_encode(v)
+    query_string = query_string .. k .. "=" .. lua_util.url_encode_string(v)
   end
 
-  local full_url = url .. "?" .. query_string
-
-  lua_util.debugm(N, task, "querying search API: %s", full_url)
+  local full_url = opts.search_url .. "?" .. query_string
 
   local function http_callback(err, code, body, _)
     if err then
-      lua_util.debugm(N, task, "search API error for %s: %s", domain, err)
+      lua_util.debugm(Np, task, "search API error for domain '%s': %s", domain, err)
       callback(nil, domain, err)
       return
     end
 
     if code ~= 200 then
-      lua_util.debugm(N, task, "search API returned code %s for %s", code, domain)
+      rspamd_logger.infox(task, "search API returned code %s for domain '%s', url: %s, body: %s",
+        code, domain, full_url, body and body:sub(1, 200) or 'nil')
       callback(nil, domain, string.format("HTTP %s", code))
       return
     end
 
-    -- Parse JSON response
+    lua_util.debugm(Np, task, "search API success for domain '%s', url: %s", domain, full_url)
+
+    -- Parse Leta Mullvad JSON response
     local parser = ucl.parser()
     local ok, parse_err = parser:parse_string(body)
     if not ok then
       rspamd_logger.errx(task, "%s: failed to parse search API response for %s: %s",
-        N, domain, parse_err)
+        Np, domain, parse_err)
       callback(nil, domain, parse_err)
       return
     end
 
-    local results = parser:get_object()
-    callback(results, domain, nil)
+    local data = parser:get_object()
+
+    -- Extract search results from Leta Mullvad's nested structure
+    -- Structure: data.nodes[3].data is a flat array with indices as pointers
+    -- data[1] = metadata with pointers, data[5] = items array (Lua 1-indexed)
+    local search_results = { results = {} }
+
+    if data and data.nodes and type(data.nodes) == 'table' and #data.nodes >= 3 then
+      local search_node = data.nodes[3]  -- Third node contains search data (Lua 1-indexed)
+
+      if search_node and search_node.data and type(search_node.data) == 'table' then
+        local flat_data = search_node.data
+        local metadata = flat_data[1]
+
+        lua_util.debugm(Np, task, "parsing domain '%s': flat_data has %d elements, metadata type: %s",
+          domain, #flat_data, type(metadata))
+
+        if metadata and metadata.items and type(metadata.items) == 'number' then
+          -- metadata.items is a 0-indexed pointer, add 1 for Lua
+          local items_idx = metadata.items + 1
+          local items = flat_data[items_idx]
+
+          if items and type(items) == 'table' then
+            lua_util.debugm(Np, task, "found %d item indices for domain '%s', items_idx=%d",
+              #items, domain, items_idx)
+
+            local count = 0
+
+            for _, result_idx in ipairs(items) do
+              if count >= opts.max_results_per_query then
+                break
+              end
+
+              -- result_idx is 0-indexed, add 1 for Lua
+              local result_template_idx = result_idx + 1
+              local result_template = flat_data[result_template_idx]
+
+              if result_template and type(result_template) == 'table' then
+                -- Extract values using the template's pointers (also 0-indexed)
+                local link = result_template.link and flat_data[result_template.link + 1]
+                local snippet = result_template.snippet and flat_data[result_template.snippet + 1]
+                local title = result_template.title and flat_data[result_template.title + 1]
+
+                lua_util.debugm(Np, task, "result %d template: link_idx=%s, snippet_idx=%s, title_idx=%s",
+                  count + 1, tostring(result_template.link), tostring(result_template.snippet),
+                  tostring(result_template.title))
+
+                if link or title or snippet then
+                  table.insert(search_results.results, {
+                    title = title or "",
+                    snippet = snippet or "",
+                    url = link or ""
+                  })
+                  count = count + 1
+                  lua_util.debugm(Np, task, "extracted result %d: title='%s', snippet_len=%d",
+                    count, title or "nil", snippet and #snippet or 0)
+                end
+              else
+                lua_util.debugm(Np, task, "result_template at idx %d is not a table: %s",
+                  result_template_idx, type(result_template))
+              end
+            end
+          else
+            lua_util.debugm(Np, task, "items is not a table for domain '%s', type: %s",
+              domain, type(items))
+          end
+        else
+          lua_util.debugm(Np, task, "no valid metadata.items for domain '%s'", domain)
+        end
+      end
+    end
+
+    lua_util.debugm(Np, task, "extracted %d search results for domain '%s'",
+      #search_results.results, domain)
+    callback(search_results, domain, nil)
   end
 
   rspamd_http.request({
     url = full_url,
-    timeout = timeout,
+    timeout = opts.timeout,
     callback = http_callback,
     task = task,
     log_obj = task,
@@ -185,7 +249,7 @@ local function format_search_results(all_results, opts)
       table.insert(context_lines, string.format("\nDomain: %s", domain))
 
       for i, result in ipairs(results.results) do
-        if i > (opts.max_results_per_query or DEFAULTS.max_results_per_query) then
+        if i > opts.max_results_per_query then
           break
         end
 
@@ -207,63 +271,6 @@ local function format_search_results(all_results, opts)
   return table.concat(context_lines, "\n")
 end
 
--- Check Redis cache for domain search results
-local function check_cache(task, redis_params, domain, opts, callback)
-  local cache_key = get_cache_key(domain, opts)
-
-  local function redis_callback(err, data)
-    if err then
-      lua_util.debugm(N, task, "Redis error for cache key %s: %s", cache_key, err)
-      callback(nil, domain)
-      return
-    end
-
-    if data and type(data) == 'string' then
-      -- Parse cached data
-      local parser = ucl.parser()
-      local ok, parse_err = parser:parse_string(data)
-      if ok then
-        lua_util.debugm(N, task, "cache hit for domain %s", domain)
-        callback(parser:get_object(), domain)
-      else
-        rspamd_logger.warnx(task, "%s: failed to parse cached data for %s: %s",
-          N, domain, parse_err)
-        callback(nil, domain)
-      end
-    else
-      lua_util.debugm(N, task, "cache miss for domain %s", domain)
-      callback(nil, domain)
-    end
-  end
-
-  lua_redis.redis_make_request(task, redis_params, cache_key, false,
-    redis_callback, 'GET', { cache_key })
-end
-
--- Store search results in Redis cache
-local function store_cache(task, redis_params, domain, results, opts)
-  local cache_key = get_cache_key(domain, opts)
-  local ttl = opts.cache_ttl or DEFAULTS.cache_ttl
-
-  if not results then
-    return
-  end
-
-  local data = ucl.to_format(results, 'json-compact')
-
-  local function redis_callback(err, _)
-    if err then
-      rspamd_logger.warnx(task, "%s: failed to cache results for %s: %s",
-        N, domain, err)
-    else
-      lua_util.debugm(N, task, "cached results for domain %s (TTL: %ss)", domain, ttl)
-    end
-  end
-
-  lua_redis.redis_make_request(task, redis_params, cache_key, true,
-    redis_callback, 'SETEX', { cache_key, tostring(ttl), data })
-end
-
 -- Main function to fetch and format search context
 function M.fetch_and_format(task, redis_params, opts, callback, debug_module)
   local Np = debug_module or N
@@ -289,11 +296,23 @@ function M.fetch_and_format(task, redis_params, opts, callback, debug_module)
   lua_util.debugm(Np, task, "extracted %s domain(s) for search: %s",
     #domains, table.concat(domains, ", "))
 
+  -- Create cache context
+  local cache_ctx = nil
+  if redis_params then
+    cache_ctx = lua_cache.create_cache_context(redis_params, {
+      cache_prefix = opts.cache_key_prefix,
+      cache_ttl = opts.cache_ttl,
+      cache_format = 'messagepack',
+      cache_hash_len = 16,
+      cache_use_hashing = true,
+    }, Np)
+  end
+
   local pending_queries = #domains
   local all_results = {}
 
-  -- Callback for each domain query
-  local function domain_callback(results, domain, err)
+  -- Callback for each domain query complete
+  local function domain_complete(domain, results)
     pending_queries = pending_queries - 1
 
     if results then
@@ -301,8 +320,6 @@ function M.fetch_and_format(task, redis_params, opts, callback, debug_module)
         domain = domain,
         results = results
       })
-    elseif err then
-      lua_util.debugm(Np, task, "search failed for domain %s: %s", domain, err)
     end
 
     if pending_queries == 0 then
@@ -321,25 +338,42 @@ function M.fetch_and_format(task, redis_params, opts, callback, debug_module)
 
   -- Process each domain
   for _, domain in ipairs(domains) do
-    if redis_params then
-      -- Check cache first
-      check_cache(task, redis_params, domain, opts, function(cached_results, dom)
-        if cached_results then
-          -- Use cached results
-          domain_callback(cached_results, dom, nil)
-        else
-          -- Query API and cache results (no retry, fail gracefully)
-          query_search_api(task, dom, opts, function(api_results, d, api_err)
-            if api_results and redis_params then
-              store_cache(task, redis_params, d, api_results, opts)
+    local cache_key = string.format("search:%s:%s", opts.search_engine, domain)
+
+    if cache_ctx then
+      -- Use lua_cache for caching
+      lua_cache.cache_get(task, cache_key, cache_ctx, opts.timeout,
+        function()
+          -- Cache miss - query API
+          query_search_api(task, domain, opts, function(api_results, d, api_err)
+            if api_results then
+              lua_cache.cache_set(task, cache_key, api_results, cache_ctx)
+              domain_complete(d, api_results)
+            else
+              lua_util.debugm(Np, task, "search failed for domain %s: %s", d, api_err)
+              domain_complete(d, nil)
             end
-            domain_callback(api_results, d, api_err)
-          end)
-        end
-      end)
+          end, Np)
+        end,
+        function(_, err, data)
+          -- Cache hit or after miss callback
+          if data and type(data) == 'table' then
+            lua_util.debugm(Np, task, "cache hit for domain %s", domain)
+            domain_complete(domain, data)
+          -- If no data and no error, the miss callback was already invoked
+          elseif err then
+            lua_util.debugm(Np, task, "cache error for domain %s: %s", domain, err)
+            domain_complete(domain, nil)
+          end
+        end)
     else
-      -- No Redis, query directly (no retry, fail gracefully)
-      query_search_api(task, domain, opts, domain_callback)
+      -- No Redis, query directly
+      query_search_api(task, domain, opts, function(api_results, d, api_err)
+        if not api_results then
+          lua_util.debugm(Np, task, "search failed for domain %s: %s", d, api_err)
+        end
+        domain_complete(d, api_results)
+      end, Np)
     end
   end
 end
index ff90d9e1897d7e02f3cb6be977e545853b171ca7..8255cffbc6d0e85b048e738bc1373e5a31111d32 100644 (file)
@@ -93,7 +93,8 @@ if confighelp then
   # Optional web search context (extract domains from URLs and search for context)
   search_context = {
     enabled = false; # fetch web search context for domains in email
-    search_url = "https://leta.mullvad.net/api/search"; # Search API endpoint
+    search_url = "https://leta.mullvad.net/search/__data.json"; # Search API endpoint
+    search_engine = "brave"; # Search engine (brave, google, etc.)
     max_domains = 3; # Maximum domains to search
     max_results_per_query = 3; # Maximum results per domain
     timeout = 5; # HTTP timeout in seconds
@@ -236,7 +237,8 @@ local settings = {
   -- Web search context options (for extracting and searching domains from URLs)
   search_context = {
     enabled = false,
-    search_url = 'https://leta.mullvad.net/api/search', -- Search API endpoint
+    search_url = 'https://leta.mullvad.net/search/__data.json', -- Search API endpoint
+    search_engine = 'brave',                            -- Search engine (brave, google, etc.)
     max_domains = 3,                                    -- Maximum domains to search
     max_results_per_query = 3,                          -- Maximum results per domain
     timeout = 5,                                        -- HTTP timeout in seconds