]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Improve GPT module with uncertain caching and server timeout
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 2 Oct 2025 13:32:22 +0000 (14:32 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 2 Oct 2025 13:32:22 +0000 (14:32 +0100)
* Add GPT_UNCERTAIN symbol for caching uncertain classifications
  - Cache results even when no consensus is reached
  - Avoid repeated expensive LLM queries for borderline cases
  - Set X-GPT-Reason header with detailed vote statistics
* Add server-side timeout support for OpenAI API requests
  - New request_timeout parameter (optional, multiplied by 0.95)
  - Only sent if explicitly configured (not all APIs support this)
  - Accounts for connection setup and data transfer overhead
* Fix max_ham_prob initialization (was 0, now correctly 1.0)
* Add pcall protection for fold_header_with_encoding with raw fallback
* Improve error messages for token limit exceeded
* Add detailed logging for context snippets and consensus decisions
* Pass debug_module parameter to llm_context functions

src/plugins/lua/gpt.lua

index 394923ed7cc8e0fb05e98556d4eb0fbad5ef4730..10526361f2603238ac3d4ab3edaffc877dd744ca 100644 (file)
@@ -62,6 +62,13 @@ if confighelp then
   reason_header = "X-GPT-Reason";
   # Use JSON format for response
   json = false;
+  # Optional: pass request timeout to the server (in seconds)
+  # WARNING: Not all API implementations support this parameter (e.g., standard OpenAI API doesn't)
+  # Only enable if your API endpoint/proxy specifically supports max_completion_time parameter
+  # If not set, this parameter will not be sent to the server
+  # Note: the actual value sent to server is multiplied by 0.95 to account for
+  # connection setup, SSL handshake, and data transfer overhead
+  # request_timeout = 8;
 
   # Optional user/domain context in Redis
   context = {
@@ -133,6 +140,11 @@ local default_extra_symbols = {
     description = 'GPT model detected malware content',
     category = 'malware',
   },
+  GPT_UNCERTAIN = {
+    score = 0.0,
+    description = 'GPT model was uncertain about classification',
+    category = 'uncertain',
+  },
 }
 
 -- Should be filled from extra symbols
@@ -172,6 +184,7 @@ local settings = {
   json = false,
   extra_symbols = nil,
   cache_prefix = REDIS_PREFIX,
+  request_timeout = nil, -- Optional: pass request timeout to server (in seconds)
   -- user/domain context options (nested table forwarded to llm_context)
   context = {
     enabled = false,
@@ -432,12 +445,12 @@ local function default_openai_json_conversion(task, input)
       elseif reply.probability == "low" then
         spam_score = 0.1
       else
-        rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability)
+        lua_util.debugm(N, task, "cannot convert to spam probability: %s", reply.probability)
       end
     end
 
     if type(reply.usage) == 'table' then
-      rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+      lua_util.debugm(N, task, 'usage: %s tokens', reply.usage.total_tokens)
     end
 
     return spam_score, reply.reason, {}
@@ -475,9 +488,22 @@ local function default_openai_plain_conversion(task, input)
   end
 
   local first_message = reply.choices[1].message.content
+  local finish_reason = reply.choices[1].finish_reason or 'unknown'
 
   if not first_message or first_message == "" then
-    rspamd_logger.errx(task, 'no content in the first message')
+    if finish_reason == 'length' then
+      -- Token limit exceeded - provide helpful error message
+      local usage = reply.usage or {}
+      local completion_tokens = usage.completion_tokens or 0
+      local reasoning_tokens = usage.completion_tokens_details and usage.completion_tokens_details.reasoning_tokens or 0
+      rspamd_logger.errx(task, 'LLM response truncated: token limit exceeded. ' ..
+        'Used %s completion tokens (including %s reasoning tokens). ' ..
+        'Increase max_completion_tokens in model_parameters config for this model.',
+        completion_tokens, reasoning_tokens)
+    else
+      rspamd_logger.errx(task, 'no content in the first message (finish_reason: %s, usage: %s)',
+        finish_reason, reply.usage and ucl.to_format(reply.usage, 'json-compact') or 'none')
+    end
     return
   end
 
@@ -491,7 +517,7 @@ local function default_openai_plain_conversion(task, input)
   local categories = lua_util.str_split(clean_reply_line(lines[3]), ',')
 
   if type(reply.usage) == 'table' then
-    rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+    lua_util.debugm(N, task, 'usage: %s tokens', reply.usage.total_tokens)
   end
 
   if spam_score then
@@ -592,12 +618,12 @@ local function default_ollama_json_conversion(task, input)
       elseif reply.probability == "low" then
         spam_score = 0.1
       else
-        rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability)
+        lua_util.debugm(N, task, "cannot convert to spam probability: %s", reply.probability)
       end
     end
 
     if type(reply.usage) == 'table' then
-      rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+      lua_util.debugm(N, task, 'usage: %s tokens', reply.usage.total_tokens)
     end
 
     return spam_score, reply.reason, {}
@@ -647,7 +673,7 @@ local function insert_results(task, result, sel_part)
     if result.categories then
       process_categories(task, result.categories)
     end
-  else
+  elseif result.probability < 0.5 then
     task:insert_result('GPT_HAM', (0.5 - result.probability) * 2, tostring(result.probability))
     if settings.autolearn then
       task:set_flag("learn_ham")
@@ -655,12 +681,27 @@ local function insert_results(task, result, sel_part)
     if result.categories then
       process_categories(task, result.categories)
     end
+  else
+    -- probability == 0.5, uncertain result, don't set GPT_SPAM/GPT_HAM
+    if result.categories then
+      process_categories(task, result.categories)
+    end
   end
   if result.reason and settings.reason_header then
-    local v = lua_util.fold_header_with_encoding(task, settings.reason_header,
-      tostring(result.reason), { encode = 'auto' })
-    lua_mime.modify_headers(task,
-      { add = { [settings.reason_header] = { value = v, order = 1 } } })
+    if type(settings.reason_header) == 'string' and #result.reason > 0 then
+      local ok, v = pcall(lua_util.fold_header_with_encoding, task, settings.reason_header,
+        result.reason, { encode = false, structured = false })
+      if ok and v then
+        lua_mime.modify_headers(task,
+          { add = { [settings.reason_header] = { value = v, order = 1 } } })
+      else
+        rspamd_logger.warnx(task, 'cannot fold header %s: %s; using raw value', settings.reason_header,
+          v)
+        -- Fallback: use raw value without encoding
+        lua_mime.modify_headers(task,
+          { add = { [settings.reason_header] = { value = result.reason, order = 1 } } })
+      end
+    end
   end
 
   if cache_context then
@@ -669,7 +710,7 @@ local function insert_results(task, result, sel_part)
 
   -- Update long-term user/domain context after classification
   if redis_params and settings.context then
-    llm_context.update_after_classification(task, redis_params, settings.context, result, sel_part)
+    llm_context.update_after_classification(task, redis_params, settings.context, result, sel_part, N)
   end
 end
 
@@ -681,21 +722,21 @@ local function check_consensus_and_insert_results(task, results, sel_part)
   end
 
   local nspam, nham = 0, 0
-  local max_spam_prob, max_ham_prob = 0, 0
+  local max_spam_prob, max_ham_prob = 0, 1.0
   local reasons = {}
 
   for _, result in ipairs(results) do
-    if result.success then
+    if result.success and result.probability then
       if result.probability > 0.5 then
         nspam = nspam + 1
         max_spam_prob = math.max(max_spam_prob, result.probability)
         lua_util.debugm(N, task, "model: %s; spam: %s; reason: '%s'",
-          result.model, result.probability, result.reason)
+          result.model or 'unknown', result.probability, result.reason or 'no reason')
       else
         nham = nham + 1
         max_ham_prob = math.min(max_ham_prob, result.probability)
         lua_util.debugm(N, task, "model: %s; ham: %s; reason: '%s'",
-          result.model, result.probability, result.reason)
+          result.model or 'unknown', result.probability, result.reason or 'no reason')
       end
 
       if result.reason then
@@ -724,8 +765,20 @@ local function check_consensus_and_insert_results(task, results, sel_part)
       },
       sel_part)
   else
-    -- No consensus
-    lua_util.debugm(N, task, "no consensus")
+    -- No consensus - still cache and set uncertain symbol to avoid re-querying LLM
+    lua_util.debugm(N, task, "no consensus: nspam=%s, nham=%s, max_spam_prob=%s, max_ham_prob=%s",
+      nspam, nham, max_spam_prob, max_ham_prob)
+    -- Use 0.5 (neutral) probability with uncertain marker
+    local uncertain_reason = reason_text or string.format(
+      "Uncertain classification: spam votes=%d (max %.2f), ham votes=%d (min %.2f)",
+      nspam, max_spam_prob, nham, max_ham_prob)
+    insert_results(task, {
+        probability = 0.5,
+        reason = uncertain_reason,
+        categories = { 'uncertain' },
+      },
+      sel_part)
+    task:insert_result('GPT_UNCERTAIN', 1.0)
   end
 end
 
@@ -747,7 +800,7 @@ local function check_llm_cached(task, content, sel_part, context_snippet)
     end
 
     if data then
-      rspamd_logger.infox(task, 'found cached response %s', cache_key)
+      lua_util.debugm(N, task, 'found cached response %s', cache_key)
       insert_results(task, data, sel_part)
     else
       check_llm_uncached(task, content, sel_part, context_snippet)
@@ -757,6 +810,11 @@ end
 
 local function openai_check(task, content, sel_part, context_snippet)
   lua_util.debugm(N, task, "sending content to gpt: %s", content)
+  if context_snippet then
+    lua_util.debugm(N, task, "with context snippet (%s chars): %s", #context_snippet, context_snippet)
+  else
+    lua_util.debugm(N, task, "no context snippet")
+  end
 
   local upstream
   local results = {}
@@ -851,6 +909,13 @@ local function openai_check(task, content, sel_part, context_snippet)
       body.response_format = { type = "json_object" }
     end
 
+    -- Optionally add request timeout for server-side timeout control
+    -- Only pass if explicitly configured (not all API implementations support this)
+    -- Multiply by 0.95 to account for connection setup, SSL handshake, and data transfer time
+    if settings.request_timeout then
+      body.max_completion_time = settings.request_timeout * 0.95
+    end
+
     body.model = model
 
     upstream = settings.upstreams:get_upstream_round_robin()
@@ -883,6 +948,11 @@ end
 
 local function ollama_check(task, content, sel_part, context_snippet)
   lua_util.debugm(N, task, "sending content to gpt: %s", content)
+  if context_snippet then
+    lua_util.debugm(N, task, "with context snippet (%s chars): %s", #context_snippet, context_snippet)
+  else
+    lua_util.debugm(N, task, "no context snippet")
+  end
 
   local upstream
   local results = {}
@@ -975,6 +1045,13 @@ local function ollama_check(task, content, sel_part, context_snippet)
       body.response_format = { type = "json_object" }
     end
 
+    -- Optionally add request timeout for server-side timeout control
+    -- Only pass if explicitly configured (not all API implementations support this)
+    -- Multiply by 0.95 to account for connection setup, SSL handshake, and data transfer time
+    if settings.request_timeout then
+      body.max_completion_time = settings.request_timeout * 0.95
+    end
+
     body.model = model
 
     upstream = settings.upstreams:get_upstream_round_robin()
@@ -1024,14 +1101,14 @@ local function gpt_check(task)
           inferred_result = { probability = 0.1, reason = 'ham by filters', categories = {} }
         end
       end
-      llm_context.update_after_classification(task, redis_params, settings.context, inferred_result, sel_part)
+      llm_context.update_after_classification(task, redis_params, settings.context, inferred_result, sel_part, N)
     end
-    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s; context updated", content)
+    lua_util.debugm(N, task, "skip checking gpt as the condition is not met: %s; context updated", content)
     return
   end
 
   if not ret then
-    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
+    lua_util.debugm(N, task, "skip checking gpt as the condition is not met: %s", content)
     return
   end
 
@@ -1052,7 +1129,7 @@ local function gpt_check(task)
   if context_enabled then
     llm_context.fetch(task, redis_params, settings.context, function(_, _, snippet)
       proceed(snippet)
-    end)
+    end, N)
   else
     proceed(nil)
   end