]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Add checks to decide if we need a GPT check
authorVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 28 Jun 2024 10:18:40 +0000 (11:18 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 28 Jun 2024 10:18:40 +0000 (11:18 +0100)
src/plugins/lua/gpt.lua

index 01d1c94b43abe2e9f8991a1096d80cf987b61f55..7f4d55b173ac4d5c2e78d925c3f55858509bb591 100644 (file)
@@ -28,7 +28,7 @@ gpt {
   # Model name
   model = "gpt-3.5-turbo";
   # Maximum tokens to generate
-  max_tokens = 100;
+  max_tokens = 1000;
   # Temperature for sampling
   temperature = 0.7;
   # Top p for sampling
@@ -44,7 +44,7 @@ gpt {
   # Reply conversion (lua code)
   reply_conversion = "xxx";
   # URL for the API
-  url = "https://api.openai.com/v1/chat/completions";ß
+  url = "https://api.openai.com/v1/chat/completions";
 }
   ]])
   return
@@ -59,7 +59,7 @@ local settings = {
   type = 'openai',
   api_key = nil,
   model = 'gpt-3.5-turbo',
-  max_tokens = 100,
+  max_tokens = 1000,
   temperature = 0.7,
   top_p = 0.9,
   timeout = 10,
@@ -69,8 +69,80 @@ local settings = {
   url = 'https://api.openai.com/v1/chat/completions',
 }
 
+-- Exclude checks if one of those is found
+local symbols_to_except = {
+  'BAYES_SPAM', -- We already know that it is a spam, so we can safely skip it, but no same logic for HAM!
+  'WHITELIST_SPF',
+  'WHITELIST_DKIM',
+  'WHITELIST_DMARC',
+  'FUZZY_DENIED',
+}
+
 local function default_condition(task)
-  return true
+  -- Check result
+  -- 1) Skip passthrough
+  -- 2) Skip already decided as spam
+  -- 3) Skip already decided as ham
+  local result = task:get_metric_result()
+  if result then
+    if result.passthrough then
+      return false, 'passthrough'
+    end
+    local score = result.score
+    local action = result.action
+
+    if action == 'reject' and result.npositive > 1 then
+      return true, 'already decided as spam'
+    end
+
+    if action == 'no action' and score < 0 then
+      return true, 'negative score, already decided as ham'
+    end
+  end
+  -- We also exclude some symbols
+  for _, s in ipairs(symbols_to_except) do
+    if task:has_symbol(s) then
+      return false, 'skip as "' .. s .. '" is found'
+    end
+  end
+
+  -- Check if we have text at all
+  local mp = task:get_parts() or {}
+  local sel_part
+  for _, mime_part in ipairs(mp) do
+    if mime_part:is_text() then
+      local part = mime_part:get_text()
+      if part:is_html() then
+        -- We prefer html content
+        sel_part = part
+      elseif not sel_part then
+        sel_part = part
+      end
+    end
+  end
+
+  if not sel_part then
+    return false, 'no text part found'
+  end
+
+  -- Check limits and size sanity
+  local nwords = sel_part:get_words_count()
+
+  if nwords < 5 then
+    return false, 'less than 5 words'
+  end
+
+  if nwords > settings.max_tokens then
+    -- We need to truncate words
+    local words = sel_part:get_words('norm')
+    -- Trim something that does not fit
+    for i = nwords, settings.max_tokens, -1 do
+      rawset(words, i, nil)
+    end
+    return true, table.concat(words, ' ')
+  else
+    return true, sel_part:get_content_oneline()
+  end
 end
 
 local function default_conversion(task, input)
@@ -112,10 +184,20 @@ local function default_conversion(task, input)
 end
 
 local function openai_gpt_check(task)
-  if not settings.condition(task) then
-    lua_util.debugm(N, task, "skip checking gpt as the condition is not met")
+  local ret, content = settings.condition(task)
+
+  if not ret then
+    lua_util.info(N, task, "skip checking gpt as the condition is not met: %s", content)
     return
   end
+
+  if not content then
+    lua_util.debugm(N, task, "no content to send to gpt classification")
+    return
+  end
+
+  lua_util.debugm(N, task, "sending content to gpt: %s", content)
+
   local upstream
 
   local function on_reply(err, code, body)
@@ -149,24 +231,6 @@ local function openai_gpt_check(task)
     -- TODO: add autolearn here
   end
 
-  local mp = task:get_parts() or {}
-  local content
-  for _, mime_part in ipairs(mp) do
-    if mime_part:is_text() then
-      local part = mime_part:get_text()
-      if part:is_html() then
-        -- We prefer html content
-        content = part:get_content_oneline()
-      elseif not content then
-        content = part:get_content_oneline()
-      end
-    end
-  end
-
-  if not content then
-    lua_util.debugm(N, task, "no content to send to gpt classification")
-  end
-
   local body = {
     model = settings.model,
     max_tokens = settings.max_tokens,