]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add context_augment hook to GPT module
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 14 Mar 2026 22:13:42 +0000 (22:13 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 14 Mar 2026 22:13:42 +0000 (22:13 +0000)
Add a new `context_augment` configuration option that accepts Lua code
returning an async function(task, content, callback).  The callback
receives a string that gets injected as additional context into the
LLM prompt alongside existing user/domain and search contexts.

This enables external Lua code to enrich LLM requests with arbitrary
context — e.g., Telegram channel topic and recent messages for
community spam detection.

The augment function runs in parallel with other context fetchers
and supports async operations (Redis, HTTP).

src/plugins/lua/gpt.lua

index aa1d3c0ce69aff85faffc893800abca07c8c9b54..310faf23b799365458f3497af3351285ce4b3f34 100644 (file)
@@ -52,6 +52,10 @@ if confighelp then
   autolearn = true;
   # Reply conversion (lua code)
   reply_conversion = "xxx";
+  # Custom context augmentation (lua code returning function(task, content, callback))
+  # The callback receives a string to inject as additional context into the LLM prompt.
+  # Supports async operations (Redis, HTTP) — call callback(string) when ready.
+  # context_augment = "return function(task, content, cb) cb('Channel topic: programming') end";
   # URL for the API
   url = "https://api.openai.com/v1/chat/completions";
   # Check messages with passthrough result
@@ -202,6 +206,7 @@ local settings = {
   json = false,
   extra_symbols = nil,
   cache_prefix = REDIS_PREFIX,
+  context_augment = nil, -- Lua code returning function(task, content, callback) for custom context
   request_timeout = nil, -- Optional: pass request timeout to server (in seconds)
   -- user/domain context options (nested table forwarded to llm_context)
   context = {
@@ -1195,23 +1200,28 @@ local function gpt_check(task)
 
   -- Check if we need to fetch search context
   local search_context_enabled = is_search_context_enabled_for_task(task)
+  local has_augment = settings.context_augment ~= nil
 
-  if context_enabled or search_context_enabled then
+  if context_enabled or search_context_enabled or has_augment then
     local pending_fetches = 0
     local user_context_snippet = nil
     local search_context_snippet = nil
+    local augment_context_snippet = nil
 
     local function maybe_proceed()
       if pending_fetches == 0 then
-        -- Combine contexts
-        local combined_context = nil
-        if user_context_snippet and search_context_snippet then
-          combined_context = user_context_snippet .. "\n\n" .. search_context_snippet
-        elseif user_context_snippet then
-          combined_context = user_context_snippet
-        elseif search_context_snippet then
-          combined_context = search_context_snippet
+        -- Combine all context snippets
+        local parts = {}
+        if user_context_snippet then
+          parts[#parts + 1] = user_context_snippet
         end
+        if search_context_snippet then
+          parts[#parts + 1] = search_context_snippet
+        end
+        if augment_context_snippet then
+          parts[#parts + 1] = augment_context_snippet
+        end
+        local combined_context = #parts > 0 and table.concat(parts, "\n\n") or nil
         proceed(combined_context)
       end
     end
@@ -1234,6 +1244,20 @@ local function gpt_check(task)
       end, N)
     end
 
+    if has_augment then
+      pending_fetches = pending_fetches + 1
+      local ok, err = pcall(settings.context_augment, task, content, function(snippet)
+        augment_context_snippet = snippet
+        pending_fetches = pending_fetches - 1
+        maybe_proceed()
+      end)
+      if not ok then
+        rspamd_logger.errx(task, 'context_augment callback failed: %s', err)
+        pending_fetches = pending_fetches - 1
+        maybe_proceed()
+      end
+    end
+
     -- If no fetches were initiated, proceed immediately
     if pending_fetches == 0 then
       proceed(nil)
@@ -1299,6 +1323,16 @@ if opts then
     settings.reply_conversion = llm_type.conversion(settings.json)
   end
 
+  if settings.context_augment then
+    local augment_fn, augment_err = load(settings.context_augment)
+    if augment_fn then
+      settings.context_augment = augment_fn()
+    else
+      rspamd_logger.warnx(rspamd_config, 'failed to compile context_augment: %s', augment_err)
+      settings.context_augment = nil
+    end
+  end
+
   if not settings.api_key and llm_type.require_passkey then
     rspamd_logger.warnx(rspamd_config, 'no api_key is specified for LLM type %s, disabling module', settings.type)
     lua_util.disable_module(N, "config")