]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Improve prompt and use plaintext instead of JSON
authorVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 25 Feb 2025 11:02:23 +0000 (11:02 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 25 Feb 2025 11:02:23 +0000 (11:02 +0000)
src/plugins/lua/gpt.lua

index 0fb6123e1d62d1128c3b478ca58a518777cbf545..270d0fdfc5d45ecb0c6c044c9221db356d79da52 100644 (file)
@@ -50,6 +50,8 @@ gpt {
   allow_ham = false;
   # Add header with reason (null to disable)
   reason_header = "X-GPT-Reason";
+  # Use JSON format for response
+  json = false;
 }
   ]])
   return
@@ -89,6 +91,7 @@ local settings = {
   symbols_to_trigger = nil, -- Exclude/include logic
   allow_passthrough = false,
   allow_ham = false,
+  json = false,
 }
 
 local function default_condition(task)
@@ -217,7 +220,7 @@ local function maybe_extract_json(str)
   return nil
 end
 
-local function default_conversion(task, input)
+local function default_openai_json_conversion(task, input)
   local parser = ucl.parser()
   local res, err = parser:parse_string(input)
   if not res then
@@ -273,14 +276,99 @@ local function default_conversion(task, input)
       rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
     end
 
-    return spam_score, reply.reason
+    return spam_score, reply.reason, {}
   end
 
   rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
   return
 end
 
-local function ollama_conversion(task, input)
+-- Assume that we have 3 lines: probability, reason, additional symbols
+local function default_openai_plain_conversion(task, input)
+  local parser = ucl.parser()
+  local res, err = parser:parse_string(input)
+  if not res then
+    rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+    return
+  end
+  local reply = parser:get_object()
+  if not reply then
+    rspamd_logger.errx(task, 'cannot get object from reply')
+    return
+  end
+
+  if type(reply.choices) ~= 'table' or type(reply.choices[1]) ~= 'table' then
+    rspamd_logger.errx(task, 'no choices in reply')
+    return
+  end
+
+  local first_message = reply.choices[1].message.content
+
+  if not first_message then
+    rspamd_logger.errx(task, 'no content in the first message')
+    return
+  end
+  local lines = lua_util.str_split(first_message, '\n')
+  local first_line = lines[1] or ''
+  local cleaned_line = first_line:gsub("^[%d%p]%s?%f[%d]", "")
+                                 :gsub("[^%d%.]", "")
+                                 :gsub("%.$", "")
+                                 :gsub("%.%..*", "")
+  local spam_score = tonumber(cleaned_line)
+  local reason = lines[2]
+  local symbols = lua_util.str_split(lines[3] or '', ',')
+
+  if spam_score then
+    return spam_score, reason, symbols
+  end
+
+  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1])
+  return
+end
+
+local function default_ollama_plain_conversion(task, input)
+  local parser = ucl.parser()
+  local res, err = parser:parse_string(input)
+  if not res then
+    rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+    return
+  end
+  local reply = parser:get_object()
+  if not reply then
+    rspamd_logger.errx(task, 'cannot get object from reply')
+    return
+  end
+
+  if type(reply.message) ~= 'table' then
+    rspamd_logger.errx(task, 'bad message in reply')
+    return
+  end
+
+  local first_message = reply.message.content
+
+  if not first_message then
+    rspamd_logger.errx(task, 'no content in the first message')
+    return
+  end
+  local lines = lua_util.str_split(first_message, '\n')
+  local first_line = lines[1] or ''
+  local cleaned_line = first_line:gsub("^[%d%p]%s?%f[%d]", "")
+                                 :gsub("[^%d%.]", "")
+                                 :gsub("%.$", "")
+                                 :gsub("%.%..*", "")
+  local spam_score = tonumber(cleaned_line)
+  local reason = lines[2]
+  local symbols = lua_util.str_split(lines[3] or '', ',')
+
+  if spam_score then
+    return spam_score, reason, symbols
+  end
+
+  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1])
+  return
+end
+
+local function default_ollama_json_conversion(task, input)
   local parser = ucl.parser()
   local res, err = parser:parse_string(input)
   if not res then
@@ -456,7 +544,7 @@ local function default_llm_check(task)
         return
       end
 
-      local reply, reason = settings.reply_conversion(task, body)
+      local reply, reason, _symbols = settings.reply_conversion(task, body)
 
       results[idx].model = model
 
@@ -657,13 +745,17 @@ local types_map = {
   openai = {
     check = default_llm_check,
     condition = default_condition,
-    conversion = default_conversion,
+    conversion = function(is_json)
+      return is_json and default_openai_json_conversion or default_openai_plain_conversion
+    end,
     require_passkey = true,
   },
   ollama = {
     check = ollama_check,
     condition = default_condition,
-    conversion = ollama_conversion,
+    conversion = function(is_json)
+      return is_json and default_ollama_json_conversion or default_ollama_plain_conversion
+    end,
     require_passkey = false,
   },
 }
@@ -673,10 +765,11 @@ if opts then
   settings = lua_util.override_defaults(settings, opts)
 
   if not settings.prompt then
-    settings.prompt = "You will be provided with the email message, subject, from and url domains, " ..
-        "and your task is to evaluate the probability to be spam as number from 0 to 1, " ..
-        "output result as JSON with 'probability' field and " ..
-        "add 'reason' field with 1 sentence description why you have made that decision."
+    settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
+        "FROM and url domains. Evaluate spam probability (0-1). " ..
+        "Output ONLY 2 lines:\n" ..
+        "1. Numeric score (0.00-1.00)\n" ..
+        "2. One-sentence reason citing strongest red flag"
   end
 
   if not settings.symbols_to_except then
@@ -700,7 +793,7 @@ if opts then
   if settings.reply_conversion then
     settings.reply_conversion = load(settings.reply_conversion)()
   else
-    settings.reply_conversion = llm_type.conversion
+    settings.reply_conversion = llm_type.conversion(settings.json)
   end
 
   if not settings.api_key and llm_type.require_passkey then