From: Vsevolod Stakhov
Date: Wed, 5 Nov 2025 14:25:19 +0000 (+0000)
Subject: [Fix] Refactor llm_search_context to use lua_cache module
X-Git-Tag: 3.14.0~12^2~7
X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=97082fbe01fd43f842a8a3a071d81007e44d6bd4;p=thirdparty%2Frspamd.git

[Fix] Refactor llm_search_context to use lua_cache module

- Replace manual Redis operations with lua_cache API for better consistency
- Use messagepack serialization and automatic key hashing
- Fix Leta Mullvad API URL to /search/__data.json endpoint
- Add search_engine parameter support
- Remove redundant 'or DEFAULTS.xxx' patterns (opts already has defaults merged)
- Add proper debug_module propagation throughout call chain
- Improve JSON parsing to handle Leta Mullvad's nested pointer structure
---

diff --git a/conf/modules.d/gpt.conf b/conf/modules.d/gpt.conf
index 3e3d5ea6c2..ce1ae9e648 100644
--- a/conf/modules.d/gpt.conf
+++ b/conf/modules.d/gpt.conf
@@ -59,7 +59,8 @@ gpt {
   # Extracts domains from email URLs and queries a search API for context
   search_context = {
     enabled = false; # Enable web search context
-    #search_url = "https://leta.mullvad.net/api/search"; # Search API endpoint
+    #search_url = "https://leta.mullvad.net/search/__data.json"; # Search API endpoint
+    #search_engine = "brave"; # Search engine (brave, google, etc.)
     #max_domains = 3; # Maximum domains to search
     #max_results_per_query = 3; # Maximum results per domain
     #timeout = 5; # HTTP timeout in seconds
diff --git a/lualib/llm_search_context.lua b/lualib/llm_search_context.lua
index 036db6c387..12c2cd364a 100644
--- a/lualib/llm_search_context.lua
+++ b/lualib/llm_search_context.lua
@@ -43,14 +43,14 @@ local M = {}
 
 local rspamd_http = require "rspamd_http"
 local rspamd_logger = require "rspamd_logger"
-local rspamd_util = require "rspamd_util"
 local lua_util = require "lua_util"
-local lua_redis = require "lua_redis"
+local lua_cache = require "lua_cache"
 local ucl = require "ucl"
 
 local DEFAULTS = {
   enabled = false,
-  search_url = "https://leta.mullvad.net/api/search",
+  search_url = "https://leta.mullvad.net/search/__data.json",
+  search_engine = "brave", -- Search engine to use (brave, google, etc.)
   max_domains = 3,
   max_results_per_query = 3,
   timeout = 5,
@@ -98,24 +98,14 @@ local function extract_domains(task, max_domains)
   return domains
 end
 
--- Generate cache key for a domain
-local function get_cache_key(domain, opts)
-  local key_prefix = opts.cache_key_prefix or DEFAULTS.cache_key_prefix
-  local hash = rspamd_util.hash_create()
-  hash:update(domain)
-  return string.format("%s:%s", key_prefix, hash:hex())
-end
-
 -- Query search API for a single domain
-local function query_search_api(task, domain, opts, callback)
-  local url = opts.search_url or DEFAULTS.search_url
-  local timeout = opts.timeout or DEFAULTS.timeout
-  local max_results = opts.max_results_per_query or DEFAULTS.max_results_per_query
+local function query_search_api(task, domain, opts, callback, debug_module)
+  local Np = debug_module or N
 
-  -- Prepare search query
+  -- Prepare search query for Leta Mullvad API
   local query_params = {
     q = domain,
-    limit = tostring(max_results),
+    engine = opts.search_engine,
   }
 
   -- Build query string
@@ -124,43 +114,117 @@ local function query_search_api(task, domain, opts, callback)
     if query_string ~= "" then
       query_string = query_string .. "&"
     end
-    query_string = query_string .. k .. "=" .. rspamd_util.url_encode(v)
+    query_string = query_string .. k .. "=" .. lua_util.url_encode_string(v)
   end
 
-  local full_url = url .. "?" .. query_string
-
-  lua_util.debugm(N, task, "querying search API: %s", full_url)
+  local full_url = opts.search_url .. "?" .. query_string
 
   local function http_callback(err, code, body, _)
     if err then
-      lua_util.debugm(N, task, "search API error for %s: %s", domain, err)
+      lua_util.debugm(Np, task, "search API error for domain '%s': %s", domain, err)
       callback(nil, domain, err)
       return
     end
 
     if code ~= 200 then
-      lua_util.debugm(N, task, "search API returned code %s for %s", code, domain)
+      rspamd_logger.infox(task, "search API returned code %s for domain '%s', url: %s, body: %s",
+        code, domain, full_url, body and body:sub(1, 200) or 'nil')
      callback(nil, domain, string.format("HTTP %s", code))
       return
     end
 
-    -- Parse JSON response
+    lua_util.debugm(Np, task, "search API success for domain '%s', url: %s", domain, full_url)
+
+    -- Parse Leta Mullvad JSON response
     local parser = ucl.parser()
     local ok, parse_err = parser:parse_string(body)
     if not ok then
       rspamd_logger.errx(task, "%s: failed to parse search API response for %s: %s",
-        N, domain, parse_err)
+        Np, domain, parse_err)
       callback(nil, domain, parse_err)
       return
     end
 
-    local results = parser:get_object()
-    callback(results, domain, nil)
+    local data = parser:get_object()
+
+    -- Extract search results from Leta Mullvad's nested structure
+    -- Structure: data.nodes[3].data is a flat array with indices as pointers
+    -- data[1] = metadata with pointers, data[5] = items array (Lua 1-indexed)
+    local search_results = { results = {} }
+
+    if data and data.nodes and type(data.nodes) == 'table' and #data.nodes >= 3 then
+      local search_node = data.nodes[3] -- Third node contains search data (Lua 1-indexed)
+
+      if search_node and search_node.data and type(search_node.data) == 'table' then
+        local flat_data = search_node.data
+        local metadata = flat_data[1]
+
+        lua_util.debugm(Np, task, "parsing domain '%s': flat_data has %d elements, metadata type: %s",
+          domain, #flat_data, type(metadata))
+
+        if metadata and metadata.items and type(metadata.items) == 'number' then
+          -- metadata.items is a 0-indexed pointer, add 1 for Lua
+          local items_idx = metadata.items + 1
+          local items = flat_data[items_idx]
+
+          if items and type(items) == 'table' then
+            lua_util.debugm(Np, task, "found %d item indices for domain '%s', items_idx=%d",
+              #items, domain, items_idx)
+
+            local count = 0
+
+            for _, result_idx in ipairs(items) do
+              if count >= opts.max_results_per_query then
+                break
+              end
+
+              -- result_idx is 0-indexed, add 1 for Lua
+              local result_template_idx = result_idx + 1
+              local result_template = flat_data[result_template_idx]
+
+              if result_template and type(result_template) == 'table' then
+                -- Extract values using the template's pointers (also 0-indexed)
+                local link = result_template.link and flat_data[result_template.link + 1]
+                local snippet = result_template.snippet and flat_data[result_template.snippet + 1]
+                local title = result_template.title and flat_data[result_template.title + 1]
+
+                lua_util.debugm(Np, task, "result %d template: link_idx=%s, snippet_idx=%s, title_idx=%s",
+                  count + 1, tostring(result_template.link), tostring(result_template.snippet),
+                  tostring(result_template.title))
+
+                if link or title or snippet then
+                  table.insert(search_results.results, {
+                    title = title or "",
+                    snippet = snippet or "",
+                    url = link or ""
+                  })
+                  count = count + 1
+                  lua_util.debugm(Np, task, "extracted result %d: title='%s', snippet_len=%d",
+                    count, title or "nil", snippet and #snippet or 0)
+                end
+              else
+                lua_util.debugm(Np, task, "result_template at idx %d is not a table: %s",
+                  result_template_idx, type(result_template))
+              end
+            end
+          else
+            lua_util.debugm(Np, task, "items is not a table for domain '%s', type: %s",
+              domain, type(items))
+          end
+        else
+          lua_util.debugm(Np, task, "no valid metadata.items for domain '%s'", domain)
+        end
+      end
+    end
+
+    lua_util.debugm(Np, task, "extracted %d search results for domain '%s'",
+      #search_results.results, domain)
+    callback(search_results, domain, nil)
   end
 
   rspamd_http.request({
     url = full_url,
-    timeout = timeout,
+    timeout = opts.timeout,
     callback = http_callback,
     task = task,
     log_obj = task,
@@ -185,7 +249,7 @@ local function format_search_results(all_results, opts)
     table.insert(context_lines, string.format("\nDomain: %s", domain))
 
     for i, result in ipairs(results.results) do
-      if i > (opts.max_results_per_query or DEFAULTS.max_results_per_query) then
+      if i > opts.max_results_per_query then
         break
       end
 
@@ -207,63 +271,6 @@
   return table.concat(context_lines, "\n")
 end
 
--- Check Redis cache for domain search results
-local function check_cache(task, redis_params, domain, opts, callback)
-  local cache_key = get_cache_key(domain, opts)
-
-  local function redis_callback(err, data)
-    if err then
-      lua_util.debugm(N, task, "Redis error for cache key %s: %s", cache_key, err)
-      callback(nil, domain)
-      return
-    end
-
-    if data and type(data) == 'string' then
-      -- Parse cached data
-      local parser = ucl.parser()
-      local ok, parse_err = parser:parse_string(data)
-      if ok then
-        lua_util.debugm(N, task, "cache hit for domain %s", domain)
-        callback(parser:get_object(), domain)
-      else
-        rspamd_logger.warnx(task, "%s: failed to parse cached data for %s: %s",
-          N, domain, parse_err)
-        callback(nil, domain)
-      end
-    else
-      lua_util.debugm(N, task, "cache miss for domain %s", domain)
-      callback(nil, domain)
-    end
-  end
-
-  lua_redis.redis_make_request(task, redis_params, cache_key, false,
-    redis_callback, 'GET', { cache_key })
-end
-
--- Store search results in Redis cache
-local function store_cache(task, redis_params, domain, results, opts)
-  local cache_key = get_cache_key(domain, opts)
-  local ttl = opts.cache_ttl or DEFAULTS.cache_ttl
-
-  if not results then
-    return
-  end
-
-  local data = ucl.to_format(results, 'json-compact')
-
-  local function redis_callback(err, _)
-    if err then
-      rspamd_logger.warnx(task, "%s: failed to cache results for %s: %s",
-        N, domain, err)
-    else
-      lua_util.debugm(N, task, "cached results for domain %s (TTL: %ss)", domain, ttl)
-    end
-  end
-
-  lua_redis.redis_make_request(task, redis_params, cache_key, true,
-    redis_callback, 'SETEX', { cache_key, tostring(ttl), data })
-end
-
 -- Main function to fetch and format search context
 function M.fetch_and_format(task, redis_params, opts, callback, debug_module)
   local Np = debug_module or N
@@ -289,11 +296,23 @@ function M.fetch_and_format(task, redis_params, opts, callback, debug_module)
   lua_util.debugm(Np, task, "extracted %s domain(s) for search: %s",
     #domains, table.concat(domains, ", "))
 
+  -- Create cache context
+  local cache_ctx = nil
+  if redis_params then
+    cache_ctx = lua_cache.create_cache_context(redis_params, {
+      cache_prefix = opts.cache_key_prefix,
+      cache_ttl = opts.cache_ttl,
+      cache_format = 'messagepack',
+      cache_hash_len = 16,
+      cache_use_hashing = true,
+    }, Np)
+  end
+
   local pending_queries = #domains
   local all_results = {}
 
-  -- Callback for each domain query
-  local function domain_callback(results, domain, err)
+  -- Callback for each domain query complete
+  local function domain_complete(domain, results)
     pending_queries = pending_queries - 1
 
     if results then
@@ -301,8 +320,6 @@
         domain = domain,
         results = results
       })
-    elseif err then
-      lua_util.debugm(Np, task, "search failed for domain %s: %s", domain, err)
     end
 
     if pending_queries == 0 then
@@ -321,25 +338,42 @@
 
   -- Process each domain
   for _, domain in ipairs(domains) do
-    if redis_params then
-      -- Check cache first
-      check_cache(task, redis_params, domain, opts, function(cached_results, dom)
-        if cached_results then
-          -- Use cached results
-          domain_callback(cached_results, dom, nil)
-        else
-          -- Query API and cache results (no retry, fail gracefully)
-          query_search_api(task, dom, opts, function(api_results, d, api_err)
-            if api_results and redis_params then
-              store_cache(task, redis_params, d, api_results, opts)
+    local cache_key = string.format("search:%s:%s", opts.search_engine, domain)
+
+    if cache_ctx then
+      -- Use lua_cache for caching
+      lua_cache.cache_get(task, cache_key, cache_ctx, opts.timeout,
+        function()
+          -- Cache miss - query API
+          query_search_api(task, domain, opts, function(api_results, d, api_err)
+            if api_results then
+              lua_cache.cache_set(task, cache_key, api_results, cache_ctx)
+              domain_complete(d, api_results)
+            else
+              lua_util.debugm(Np, task, "search failed for domain %s: %s", d, api_err)
+              domain_complete(d, nil)
             end
-            domain_callback(api_results, d, api_err)
-          end)
-        end
-      end)
+          end, Np)
+        end,
+        function(_, err, data)
+          -- Cache hit or after miss callback
+          if data and type(data) == 'table' then
+            lua_util.debugm(Np, task, "cache hit for domain %s", domain)
+            domain_complete(domain, data)
+          -- If no data and no error, the miss callback was already invoked
+          elseif err then
+            lua_util.debugm(Np, task, "cache error for domain %s: %s", domain, err)
+            domain_complete(domain, nil)
+          end
+        end)
     else
-      -- No Redis, query directly (no retry, fail gracefully)
-      query_search_api(task, domain, opts, domain_callback)
+      -- No Redis, query directly
+      query_search_api(task, domain, opts, function(api_results, d, api_err)
+        if not api_results then
+          lua_util.debugm(Np, task, "search failed for domain %s: %s", d, api_err)
+        end
+        domain_complete(d, api_results)
+      end, Np)
     end
   end
 end
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index ff90d9e189..8255cffbc6 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -93,7 +93,8 @@ if confighelp then
   # Optional web search context (extract domains from URLs and search for context)
   search_context = {
     enabled = false; # fetch web search context for domains in email
-    search_url = "https://leta.mullvad.net/api/search"; # Search API endpoint
+    search_url = "https://leta.mullvad.net/search/__data.json"; # Search API endpoint
+    search_engine = "brave"; # Search engine (brave, google, etc.)
     max_domains = 3; # Maximum domains to search
     max_results_per_query = 3; # Maximum results per domain
     timeout = 5; # HTTP timeout in seconds
@@ -236,7 +237,8 @@ local settings = {
   -- Web search context options (for extracting and searching domains from URLs)
   search_context = {
     enabled = false,
-    search_url = 'https://leta.mullvad.net/api/search', -- Search API endpoint
+    search_url = 'https://leta.mullvad.net/search/__data.json', -- Search API endpoint
+    search_engine = 'brave', -- Search engine (brave, google, etc.)
     max_domains = 3, -- Maximum domains to search
     max_results_per_query = 3, -- Maximum results per domain
     timeout = 5, -- HTTP timeout in seconds
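
For reference, the lua_cache flow introduced above boils down to the pattern
below. This is a minimal sketch, not part of the commit: `redis_params` is
assumed to come from lua_redis.parse_redis_server(), and fetch_from_api() is
a hypothetical upstream fetch that yields a Lua table.

    local lua_cache = require "lua_cache"

    -- One context per logical cache; options mirror those used in
    -- M.fetch_and_format above (prefix and TTL values are illustrative)
    local cache_ctx = lua_cache.create_cache_context(redis_params, {
      cache_prefix = 'llm_search',
      cache_ttl = 86400,
      cache_format = 'messagepack', -- values are (de)serialized transparently
      cache_hash_len = 16,
      cache_use_hashing = true,     -- keys are hashed before reaching Redis
    }, 'my_module')

    local key = 'search:brave:example.com'
    lua_cache.cache_get(task, key, cache_ctx, 5.0,
      function()
        -- Miss: fetch upstream, then populate the cache for later callers
        fetch_from_api(key, function(results)
          lua_cache.cache_set(task, key, results, cache_ctx)
        end)
      end,
      function(_, err, data)
        -- Hit: `data` is the deserialized table; `err` is set on Redis errors
      end)

The pointer chasing in the new response parser is easier to follow on a toy
payload. The values below are illustrative only; the real __data.json carries
a flat array in which objects hold 0-based indices of their field values:

    -- What query_search_api() would see after UCL decoding (toy data)
    local data = {
      nodes = {
        {}, {}, -- nodes 1 and 2 carry no search data here
        { data = {
          { items = 1 },                        -- [1] metadata -> flat[2]
          { 2 },                                -- [2] result indices -> flat[3]
          { link = 3, snippet = 4, title = 5 }, -- [3] result template
          "https://example.com/",               -- [4] link value
          "An example snippet",                 -- [5] snippet value
          "Example Domain",                     -- [6] title value
        } },
      },
    }
    -- Walk: flat[1].items + 1 = 2 -> { 2 }; 2 + 1 = 3 -> template;
    -- template.link + 1 = 4 -> URL, .snippet + 1 = 5, .title + 1 = 6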