]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Move common stuff to a separate function
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 18 Aug 2025 22:07:39 +0000 (23:07 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 18 Aug 2025 22:07:39 +0000 (23:07 +0100)
lualib/llm_common.lua [new file with mode: 0644]
lualib/plugins/neural/providers/llm.lua
src/plugins/lua/gpt.lua

diff --git a/lualib/llm_common.lua b/lualib/llm_common.lua
new file mode 100644 (file)
index 0000000..92d9a70
--- /dev/null
@@ -0,0 +1,72 @@
+--[[
+Common helpers for building LLM input content from a task
+]] --
+
+local lua_util = require "lua_util"
+local lua_mime = require "lua_mime"
+local fun = require "fun"
+
+local M = {}
+
+local function get_meta_llm_content(task)
+  local url_content = "Url domains: no urls found"
+  if task:has_urls() then
+    local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 }
+    url_content = "Url domains: " .. table.concat(fun.totable(fun.map(function(u)
+      return u:get_tld() or ''
+    end, urls or {})), ', ')
+  end
+
+  local from_or_empty = ((task:get_from('mime') or {})[1] or {})
+  local from_name = from_or_empty.name or ''
+  local from_addr = from_or_empty.addr or ''
+  local from_content = string.format('From: %s <%s>', from_name, from_addr)
+
+  return url_content, from_content
+end
+
+-- Build a single text payload suitable for LLM embeddings
+function M.build_llm_input(task, opts)
+  opts = opts or {}
+  local subject = task:get_subject() or ''
+  local url_content, from_content = get_meta_llm_content(task)
+
+  local sel_part = lua_mime.get_displayed_text_part(task)
+  if not sel_part then
+    return nil, nil
+  end
+
+  local nwords = sel_part:get_words_count() or 0
+  if nwords < 5 then
+    return nil, sel_part
+  end
+
+  local max_tokens = tonumber(opts.max_tokens) or 1024
+  local text_line
+  if nwords > max_tokens then
+    local words = sel_part:get_words('norm') or {}
+    if #words > max_tokens then
+      text_line = table.concat(words, ' ', 1, max_tokens)
+    else
+      text_line = table.concat(words, ' ')
+    end
+  else
+    text_line = sel_part:get_content_oneline() or ''
+  end
+
+  local content = table.concat({
+    'Subject: ' .. subject,
+    from_content,
+    url_content,
+    text_line,
+  }, '\n')
+
+  return content, sel_part
+end
+
+-- Backwards-compat alias
+M.build_embedding_input = M.build_llm_input
+
+M.get_meta_llm_content = get_meta_llm_content
+
+return M
index 7ef14228b2132dfd35c8965a5767d442fba88899..33301e9084cdf5fee6a18d90a01ed0bf98a5c0c5 100644 (file)
@@ -7,33 +7,15 @@ Supports minimal OpenAI- and Ollama-compatible embedding endpoints.
 local rspamd_http = require "rspamd_http"
 local rspamd_logger = require "rspamd_logger"
 local ucl = require "ucl"
-local lua_mime = require "lua_mime"
 local neural_common = require "plugins/neural"
 local lua_cache = require "lua_cache"
+local llm_common = require "llm_common"
 
 local N = "neural.llm"
 
-local function select_text(task, cfg)
-  local part = lua_mime.get_displayed_text_part(task)
-  if part then
-    local tp = part:get_text()
-    if tp then
-      -- Prefer UTF text content
-      local content = tp:get_content('raw_utf') or tp:get_content('raw')
-      if content and #content > 0 then
-        return content
-      end
-    end
-    -- Fallback to raw content
-    local rc = part:get_raw_content()
-    if type(rc) == 'userdata' then
-      rc = tostring(rc)
-    end
-    return rc
-  end
-
-  -- Fallback to subject if no text part
-  return task:get_subject() or ''
+local function select_text(task)
+  local content = llm_common.build_llm_input(task)
+  return content
 end
 
 local function compose_llm_settings(pcfg)
@@ -90,7 +72,7 @@ neural_common.register_provider('llm', {
       return nil
     end
 
-    local content = select_text(task, pcfg)
+    local content = select_text(task)
     if not content or #content == 0 then
       rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip')
       return nil
@@ -209,7 +191,7 @@ neural_common.register_provider('llm', {
     if not llm.model then
       return cont(nil)
     end
-    local content = select_text(task, pcfg)
+    local content = select_text(task)
     if not content or #content == 0 then
       return cont(nil)
     end
index a2e7dde3d06d3a46455462c5111261d20aa20879..8c533ec647f1b731f86bbe0cf7e13a72e3ea04fd 100644 (file)
@@ -71,9 +71,10 @@ local lua_util = require "lua_util"
 local rspamd_http = require "rspamd_http"
 local rspamd_logger = require "rspamd_logger"
 local lua_mime = require "lua_mime"
+local llm_common = require "llm_common"
 local lua_redis = require "lua_redis"
 local ucl = require "ucl"
-local fun = require "fun"
+-- local fun = require "fun" -- no longer needed after llm_common usage
 local lua_cache = require "lua_cache"
 
 -- Exclude checks if one of those is found
@@ -116,8 +117,8 @@ local categories_map = {}
 local settings = {
   type = 'openai',
   api_key = nil,
-       model = 'gpt-5-mini', -- or parallel model requests: [ 'gpt-5-mini', 'gpt-4o-mini' ],
-       model_parameters = {
+  model = 'gpt-5-mini', -- or parallel model requests: [ 'gpt-5-mini', 'gpt-4o-mini' ],
+  model_parameters = {
     ["gpt-5-mini"] = {
       max_completion_tokens = 1000,
     },
@@ -209,29 +210,19 @@ local function default_condition(task)
     end
   end
 
-  -- Check if we have text at all
-  local sel_part = lua_mime.get_displayed_text_part(task)
-
+  -- Unified LLM input building (subject/from/urls/body one-line)
+  local content, sel_part = llm_common.build_llm_input(task, { max_tokens = settings.max_tokens })
   if not sel_part then
     return false, 'no text part found'
   end
-
-  -- Check limits and size sanity
-  local nwords = sel_part:get_words_count()
-
-  if nwords < 5 then
-    return false, 'less than 5 words'
-  end
-
-  if nwords > settings.max_tokens then
-    -- We need to truncate words (sometimes get_words_count returns a different number comparing to `get_words`)
-    local words = sel_part:get_words('norm')
-    nwords = #words
-    if nwords > settings.max_tokens then
-      return true, table.concat(words, ' ', 1, settings.max_tokens), sel_part
+  if not content or #content == 0 then
+    local nwords = sel_part:get_words_count() or 0
+    if nwords < 5 then
+      return false, 'less than 5 words'
     end
+    return false, 'no content to send'
   end
-  return true, sel_part:get_content_oneline(), sel_part
+  return true, content, sel_part
 end
 
 local function maybe_extract_json(str)
@@ -617,22 +608,7 @@ local function check_consensus_and_insert_results(task, results, sel_part)
   end
 end
 
-local function get_meta_llm_content(task)
-  local url_content = "Url domains: no urls found"
-  if task:has_urls() then
-    local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 }
-    url_content = "Url domains: " .. table.concat(fun.totable(fun.map(function(u)
-      return u:get_tld() or ''
-    end, urls or {})), ', ')
-  end
-
-  local from_or_empty = ((task:get_from('mime') or E)[1] or E)
-  local from_content = string.format('From: %s <%s>', from_or_empty.name, from_or_empty.addr)
-  lua_util.debugm(N, task, "gpt urls: %s", url_content)
-  lua_util.debugm(N, task, "gpt from: %s", from_content)
-
-  return url_content, from_content
-end
+-- get_meta_llm_content moved to llm_common
 
 local function check_llm_uncached(task, content, sel_part)
   return settings.specific_check(task, content, sel_part)
@@ -700,27 +676,12 @@ local function openai_check(task, content, sel_part)
     end
   end
 
-  local from_content, url_content = get_meta_llm_content(task)
-
-
   local body_base = {
     messages = {
       {
         role = 'system',
         content = settings.prompt
       },
-      {
-        role = 'user',
-        content = 'Subject: ' .. (task:get_subject() or ''),
-      },
-      {
-        role = 'user',
-        content = from_content,
-      },
-      {
-        role = 'user',
-        content = url_content,
-      },
       {
         role = 'user',
         content = content
@@ -741,13 +702,13 @@ local function openai_check(task, content, sel_part)
     -- Fresh body for each model
     local body = lua_util.deepcopy(body_base)
 
-               -- Merge model-specific parameters into body
-               local params = settings.model_parameters[model]
-               if params then
-                       for k, v in pairs(params) do
-                               body[k] = v
-                       end
-               end
+    -- Merge model-specific parameters into body
+    local params = settings.model_parameters[model]
+    if params then
+      for k, v in pairs(params) do
+        body[k] = v
+      end
+    end
 
     -- Conditionally add response_format
     if settings.include_response_format then
@@ -815,8 +776,6 @@ local function ollama_check(task, content, sel_part)
     end
   end
 
-  local from_content, url_content = get_meta_llm_content(task)
-
   if type(settings.model) == 'string' then
     settings.model = { settings.model }
   end
@@ -831,18 +790,6 @@ local function ollama_check(task, content, sel_part)
         role = 'system',
         content = settings.prompt
       },
-      {
-        role = 'user',
-        content = 'Subject: ' .. task:get_subject() or '',
-      },
-      {
-        role = 'user',
-        content = from_content,
-      },
-      {
-        role = 'user',
-        content = url_content,
-      },
       {
         role = 'user',
         content = content