]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] GPT: Fix occasional damage
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 28 Aug 2025 11:30:54 +0000 (12:30 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 28 Aug 2025 11:30:54 +0000 (12:30 +0100)
lualib/llm_common.lua
src/plugins/lua/gpt.lua

index 92d9a70d5501c560b00e5cfa61aa1335d0df4c69..a89aafa438ae65aa6c66a90dd6f266957b57b3a3 100644 (file)
@@ -25,7 +25,8 @@ local function get_meta_llm_content(task)
   return url_content, from_content
 end
 
--- Build a single text payload suitable for LLM embeddings
+-- Build structured payload suitable for LLM embeddings and chat
+-- Returns: table { subject = <string>, from = <string>, url_domains = <string>, text = <rspamd_text|string> }, part
 function M.build_llm_input(task, opts)
   opts = opts or {}
   local subject = task:get_subject() or ''
@@ -42,26 +43,25 @@ function M.build_llm_input(task, opts)
   end
 
   local max_tokens = tonumber(opts.max_tokens) or 1024
-  local text_line
+  local text
   if nwords > max_tokens then
     local words = sel_part:get_words('norm') or {}
     if #words > max_tokens then
-      text_line = table.concat(words, ' ', 1, max_tokens)
+      text = table.concat(words, ' ', 1, max_tokens)
     else
-      text_line = table.concat(words, ' ')
+      text = table.concat(words, ' ')
     end
   else
-    text_line = sel_part:get_content_oneline() or ''
+    -- Keep rspamd_text (userdata) intact; consumers (http/ucl) can use it directly
+    text = sel_part:get_content_oneline() or ''
   end
 
-  local content = table.concat({
-    'Subject: ' .. subject,
-    from_content,
-    url_content,
-    text_line,
-  }, '\n')
-
-  return content, sel_part
+  return {
+    subject = subject,
+    from = from_content,
+    url_domains = url_content,
+    text = text,
+  }, sel_part
 end
 
 -- Backwards-compat alias
index 8c533ec647f1b731f86bbe0cf7e13a72e3ea04fd..1790e5e8d9c7598fa1866188cf3f4578fe4352f2 100644 (file)
@@ -211,18 +211,18 @@ local function default_condition(task)
   end
 
   -- Unified LLM input building (subject/from/urls/body one-line)
-  local content, sel_part = llm_common.build_llm_input(task, { max_tokens = settings.max_tokens })
+  local input_tbl, sel_part = llm_common.build_llm_input(task, { max_tokens = settings.max_tokens })
   if not sel_part then
     return false, 'no text part found'
   end
-  if not content or #content == 0 then
+  if not input_tbl then
     local nwords = sel_part:get_words_count() or 0
     if nwords < 5 then
       return false, 'less than 5 words'
     end
     return false, 'no content to send'
   end
-  return true, content, sel_part
+  return true, input_tbl, sel_part
 end
 
 local function maybe_extract_json(str)
@@ -638,12 +638,11 @@ local function openai_check(task, content, sel_part)
   lua_util.debugm(N, task, "sending content to gpt: %s", content)
 
   local upstream
-
   local results = {}
 
-  local function gen_reply_closure(model, idx)
+  local function gen_reply_closure(model, i)
     return function(err, code, body)
-      results[idx].checked = true
+      results[i].checked = true
       if err then
         rspamd_logger.errx(task, '%s: request failed: %s', model, err)
         upstream:fail()
@@ -658,34 +657,46 @@ local function openai_check(task, content, sel_part)
         return
       end
 
-      local reply, reason, categories = settings.reply_conversion(task, body)
+      local reply, reason = settings.reply_conversion(task, body)
 
-      results[idx].model = model
+      results[i].model = model
 
       if reply then
-        results[idx].success = true
-        results[idx].probability = reply
-        results[idx].reason = reason
-
-        if categories then
-          results[idx].categories = categories
-        end
+        results[i].success = true
+        results[i].probability = reply
+        results[i].reason = reason
       end
 
       check_consensus_and_insert_results(task, results, sel_part)
     end
   end
 
+  -- Build messages exactly as in the original code if structured table provided
+  local user_messages
+  if type(content) == 'table' then
+    local subject_line = 'Subject: ' .. (content.subject or '')
+    user_messages = {
+      { role = 'user', content = subject_line },
+      { role = 'user', content = content.from or '' },
+      { role = 'user', content = content.url_domains or '' },
+      { role = 'user', content = content.text or '' },
+    }
+  else
+    user_messages = {
+      { role = 'user', content = content }
+    }
+  end
+
   local body_base = {
+    stream = false,
+    max_tokens = settings.max_tokens,
+    temperature = settings.temperature,
     messages = {
       {
         role = 'system',
         content = settings.prompt
       },
-      {
-        role = 'user',
-        content = content
-      }
+      lua_util.unpack(user_messages)
     }
   }
 
@@ -776,6 +787,21 @@ local function ollama_check(task, content, sel_part)
     end
   end
 
+  local user_messages
+  if type(content) == 'table' then
+    local subject_line = 'Subject: ' .. (content.subject or '')
+    user_messages = {
+      { role = 'user', content = subject_line },
+      { role = 'user', content = content.from or '' },
+      { role = 'user', content = content.url_domains or '' },
+      { role = 'user', content = content.text or '' },
+    }
+  else
+    user_messages = {
+      { role = 'user', content = content }
+    }
+  end
+
   if type(settings.model) == 'string' then
     settings.model = { settings.model }
   end
@@ -790,10 +816,7 @@ local function ollama_check(task, content, sel_part)
         role = 'system',
         content = settings.prompt
       },
-      {
-        role = 'user',
-        content = content
-      }
+      table.unpack(user_messages)
     }
   }