]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Allow additional categories to be defined in GPT 5356/head
authorVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 25 Feb 2025 12:32:14 +0000 (12:32 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 25 Feb 2025 12:32:14 +0000 (12:32 +0000)
src/plugins/lua/gpt.lua

index f605b702a0993bf1e464c5e532c2b964a4dcc4bb..625450fd94018fa89e5708dec97aaf1f9607467c 100644 (file)
@@ -77,6 +77,32 @@ local default_symbols_to_except = {
   BOUNCE = -1,
 }
 
+local default_extra_symbols = {
+  GPT_MARKETING = {
+    score = 0.0,
+    description = 'GPT model detected marketing content',
+    category = 'marketing',
+  },
+  GPT_PHISHING = {
+    score = 3.0,
+    description = 'GPT model detected phishing content',
+    category = 'phishing',
+  },
+  GPT_SCAM = {
+    score = 3.0,
+    description = 'GPT model detected scam content',
+    category = 'scam',
+  },
+  GPT_MALWARE = {
+    score = 3.0,
+    description = 'GPT model detected malware content',
+    category = 'malware',
+  },
+}
+
+-- Should be filled from extra symbols
+local categories_map = {}
+
 local settings = {
   type = 'openai',
   api_key = nil,
@@ -95,6 +121,7 @@ local settings = {
   allow_ham = false,
   json = false,
   redis_cache_expire = 3600 * 24,
+  extra_symbols = nil,
 }
 local redis_params
 
@@ -287,6 +314,14 @@ local function default_openai_json_conversion(task, input)
   return
 end
 
+-- Remove what we don't need
+local function clean_reply_line(line)
+  if not line then
+    return ''
+  end
+  return lua_util.str_trim(line):gsub("^%d%.%s+", "")
+end
+
 -- Assume that we have 3 lines: probability, reason, additional symbols
 local function default_openai_plain_conversion(task, input)
   local parser = ucl.parser()
@@ -313,17 +348,13 @@ local function default_openai_plain_conversion(task, input)
     return
   end
   local lines = lua_util.str_split(first_message, '\n')
-  local first_line = lines[1] or ''
-  local cleaned_line = first_line:gsub("^[%d%p]%s?%f[%d]", "")
-                                 :gsub("[^%d%.]", "")
-                                 :gsub("%.$", "")
-                                 :gsub("%.%..*", "")
-  local spam_score = tonumber(cleaned_line)
-  local reason = lines[2]
-  local symbols = lua_util.str_split(lines[3] or '', ',')
+  local first_line = clean_reply_line(lines[1])
+  local spam_score = tonumber(first_line)
+  local reason = clean_reply_line(lines[2])
+  local categories = lua_util.str_split(clean_reply_line(lines[3]), ',')
 
   if spam_score then
-    return spam_score, reason, symbols
+    return spam_score, reason, categories
   end
 
   rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1])
@@ -355,20 +386,16 @@ local function default_ollama_plain_conversion(task, input)
     return
   end
   local lines = lua_util.str_split(first_message, '\n')
-  local first_line = lines[1] or ''
-  local cleaned_line = first_line:gsub("^[%d%p]%s?%f[%d]", "")
-                                 :gsub("[^%d%.]", "")
-                                 :gsub("%.$", "")
-                                 :gsub("%.%..*", "")
-  local spam_score = tonumber(cleaned_line)
-  local reason = lines[2]
-  local symbols = lua_util.str_split(lines[3] or '', ',')
+  local first_line = clean_reply_line(lines[1])
+  local spam_score = tonumber(first_line)
+  local reason = clean_reply_line(lines[2])
+  local categories = lua_util.str_split(clean_reply_line(lines[3]), ',')
 
   if spam_score then
-    return spam_score, reason, symbols
+    return spam_score, reason, categories
   end
 
-  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1])
+  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s', lines[1])
   return
 end
 
@@ -468,6 +495,15 @@ local function maybe_save_cache(task, result, sel_part)
       'SETEX', { cache_key, tostring(settings.redis_cache_expire), result_json })
 end
 
+local function process_categories(task, categories)
+  for _, category in ipairs(categories) do
+    local sym = categories_map[category:lower()]
+    if sym then
+      task:insert_result(sym.name, 1.0)
+    end
+  end
+end
+
 local function insert_results(task, result, sel_part)
   if not result.probability then
     rspamd_logger.errx(task, 'no probability in result')
@@ -478,6 +514,10 @@ local function insert_results(task, result, sel_part)
     if settings.autolearn then
       task:set_flag("learn_spam")
     end
+
+    if result.categories then
+      process_categories(task, result.categories)
+    end
   else
     if result.reason and settings.reason_header then
       lua_mime.modify_headers(task,
@@ -487,6 +527,9 @@ local function insert_results(task, result, sel_part)
     if settings.autolearn then
       task:set_flag("learn_ham")
     end
+    if result.categories then
+      process_categories(task, result.categories)
+    end
   end
   maybe_save_cache(task, result, sel_part)
 end
@@ -517,7 +560,7 @@ local function check_consensus_and_insert_results(task, results, sel_part)
       end
 
       if result.reason then
-        table.insert(reasons, result.reason)
+        table.insert(reasons, result)
       end
     end
   end
@@ -528,13 +571,15 @@ local function check_consensus_and_insert_results(task, results, sel_part)
   if nspam > nham and max_spam_prob > 0.75 then
     insert_results(task, {
       probability = max_spam_prob,
-      reason = reason,
+      reason = reason.reason,
+      categories = reason.categories,
     },
         sel_part)
   elseif nham > nspam and max_ham_prob < 0.25 then
     insert_results(task, {
       probability = max_ham_prob,
-      reason = reason,
+      reason = reason.reason,
+      categories = reason.categories,
     },
         sel_part)
   else
@@ -620,7 +665,7 @@ local function openai_check(task, content, sel_part)
         return
       end
 
-      local reply, reason, _symbols = settings.reply_conversion(task, body)
+      local reply, reason, categories = settings.reply_conversion(task, body)
 
       results[idx].model = model
 
@@ -628,6 +673,10 @@ local function openai_check(task, content, sel_part)
         results[idx].success = true
         results[idx].probability = reply
         results[idx].reason = reason
+
+        if categories then
+          results[idx].categories = categories
+        end
       end
 
       check_consensus_and_insert_results(task, results, sel_part)
@@ -853,18 +902,14 @@ if opts then
         })
   end
 
-  if not settings.prompt then
-    settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
-        "FROM and url domains. Evaluate spam probability (0-1). " ..
-        "Output ONLY 2 lines:\n" ..
-        "1. Numeric score (0.00-1.00)\n" ..
-        "2. One-sentence reason citing strongest red flag"
-  end
-
   if not settings.symbols_to_except then
     settings.symbols_to_except = default_symbols_to_except
   end
 
+  if not settings.extra_symbols then
+    settings.extra_symbols = default_extra_symbols
+  end
+
   local llm_type = types_map[settings.type]
   if not llm_type then
     rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type)
@@ -906,7 +951,7 @@ if opts then
     name = 'GPT_SPAM',
     type = 'virtual',
     parent = id,
-    score = 5.0,
+    score = 3.0,
   })
   rspamd_config:register_symbol({
     name = 'GPT_HAM',
@@ -914,4 +959,35 @@ if opts then
     parent = id,
     score = -2.0,
   })
+
+  if settings.extra_symbols then
+    for sym, data in pairs(settings.extra_symbols) do
+      rspamd_config:register_symbol({
+        name = sym,
+        type = 'virtual',
+        parent = id,
+        score = data.score,
+        description = data.description,
+      })
+      data.name = sym
+      categories_map[data.category] = data
+    end
+  end
+
+  if not settings.prompt then
+    if settings.extra_symbols then
+      settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
+          "FROM and url domains. Evaluate spam probability (0-1). " ..
+          "Output ONLY 3 lines:\n" ..
+          "1. Numeric score (0.00-1.00)\n" ..
+          "2. One-sentence reason citing strongest red flag\n" ..
+          "3. Primary concern category if found from the list: " .. table.concat(lua_util.keys(categories_map), ', ')
+    else
+      settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
+          "FROM and url domains. Evaluate spam probability (0-1). " ..
+          "Output ONLY 2 lines:\n" ..
+          "1. Numeric score (0.00-1.00)\n" ..
+          "2. One-sentence reason citing strongest red flag\n"
+    end
+  end
 end