]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add bidirectional context support for LLM
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 2 Oct 2025 13:30:20 +0000 (14:30 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 2 Oct 2025 13:30:20 +0000 (14:30 +0100)
* Unify context for incoming and outgoing mail
* Same identity used for authenticated/local sender and recipient
* Follows replies module pattern for direction detection
* Make llm_context.lua module-agnostic with debug_module parameter
* Improve userdata handling (use :sub instead of string.sub)
* Add nil-safety to all debug logging calls
* Add cache expiration timestamps to context logs

lualib/llm_context.lua

index 4e79c632a224241b336d2d1e1c688822ccaa104d..c01850b1a58342da17763c01e425991e90cf573e 100644 (file)
@@ -2,8 +2,8 @@
 Context management for LLM-based spam detection
 
 Provides:
-  - fetch(task, redis_params, opts, callback): load context JSON from Redis and format prompt snippet
-  - update_after_classification(task, redis_params, opts, result, sel_part): update context after LLM result
+  - fetch(task, redis_params, opts, callback, debug_module): load context JSON from Redis and format prompt snippet
+  - update_after_classification(task, redis_params, opts, result, sel_part, debug_module): update context after LLM result
 
 Opts (all optional, safe defaults applied):
   enabled: boolean
@@ -17,6 +17,8 @@ Opts (all optional, safe defaults applied):
   summary_max_chars: number (truncate stored text)
   flagged_phrases: array of strings (case-insensitive match)
   last_labels_count: number
+
+debug_module: optional string, module name for debug logging (default: 'llm_context')
 ]]
 
 local M = {}
@@ -56,49 +58,86 @@ local function to_seconds(v)
   return tonumber(v) or 0
 end
 
-local function get_principal_recipient(task)
-  return task:get_principal_recipient()
-end
-
 local function get_domain_from_addr(addr)
   if not addr then return nil end
   return string.match(addr, '.*@(.+)')
 end
 
-local function compute_identity(task, opts)
-  local scope = opts.level or DEFAULTS.level
+-- Determine our user/domain - same identity for both incoming and outgoing mail
+local function get_our_identity(task, scope)
+  -- For outgoing mail: authenticated user or sender from local network
+  -- For incoming mail: principal recipient
+  local user = task:get_user()
+  local ip = task:get_ip()
+  local is_outgoing = user or (ip and ip:is_local())
+
   local identity
   if scope == 'user' then
-    identity = task:get_user() or get_principal_recipient(task)
-    if not identity then
-      local from = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr']
-      identity = from
+    if is_outgoing then
+      -- Outgoing: use sender (authenticated user or from address)
+      identity = user or task:get_reply_sender()
+      if not identity then
+        local from = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr']
+        identity = from
+      end
+    else
+      -- Incoming: use recipient
+      identity = task:get_principal_recipient()
     end
   elseif scope == 'domain' then
-    local rcpt = get_principal_recipient(task)
-    identity = get_domain_from_addr(rcpt)
-    if not identity then
-      identity = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain']
+    if is_outgoing then
+      -- Outgoing: domain of sender
+      if user then
+        identity = get_domain_from_addr(user)
+      end
+      if not identity then
+        identity = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain']
+      end
+    else
+      -- Incoming: domain of recipient
+      local rcpt = task:get_principal_recipient()
+      identity = get_domain_from_addr(rcpt)
     end
   elseif scope == 'esld' then
-    local rcpt = get_principal_recipient(task)
-    local d = get_domain_from_addr(rcpt)
-    if d then
-      identity = rspamd_util.get_tld(d)
-    end
-    if not identity then
-      local fd = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain']
-      if fd then identity = rspamd_util.get_tld(fd) end
+    if is_outgoing then
+      -- Outgoing: eSLD of sender domain
+      local d
+      if user then
+        d = get_domain_from_addr(user)
+      end
+      if not d then
+        d = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain']
+      end
+      if d then identity = rspamd_util.get_tld(d) end
+    else
+      -- Incoming: eSLD of recipient domain
+      local rcpt = task:get_principal_recipient()
+      local d = get_domain_from_addr(rcpt)
+      if d then
+        identity = rspamd_util.get_tld(d)
+      end
     end
-  else
-    scope = 'user'
-    identity = task:get_user() or get_principal_recipient(task)
   end
 
+  return identity
+end
+
+local function compute_identity(task, opts, debug_module)
+  local N = debug_module or 'llm_context'
+  local scope = opts.level or DEFAULTS.level
+  local identity = get_our_identity(task, scope)
+
   if not identity or identity == '' then
     return nil
   end
 
+  -- Log direction for debugging
+  local user = task:get_user()
+  local ip = task:get_ip()
+  local is_outgoing = user or (ip and ip:is_local())
+  lua_util.debugm(N, task, 'computed identity for %s (%s): %s',
+    scope, is_outgoing and 'outgoing' or 'incoming', tostring(identity))
+
   local key_prefix = opts.key_prefix or DEFAULTS.key_prefix
   local key_suffix = opts.key_suffix or DEFAULTS.key_suffix
   local key = string.format('%s:%s:%s', key_prefix, identity, key_suffix)
@@ -110,10 +149,17 @@ local function compute_identity(task, opts)
   }
 end
 
-local function parse_json(str)
-  if not str or str == '' then return nil end
+local function parse_json(data)
+  if not data then return nil end
+  -- Redis can return userdata nil or empty string
+  if type(data) == 'userdata' then
+    data = tostring(data)
+  end
+  if type(data) ~= 'string' or data == '' then
+    return nil
+  end
   local parser = ucl.parser()
-  local ok, err = parser:parse_string(str)
+  local ok, err = parser:parse_text(data)
   if not ok then return nil, err end
   return parser:get_object()
 end
@@ -129,7 +175,7 @@ end
 local function truncate_text(txt, limit)
   if not txt then return '' end
   if #txt <= limit then return txt end
-  return string.sub(txt, 1, limit)
+  return txt:sub(1, limit)
 end
 
 local function has_flag(flags, flag_name)
@@ -314,7 +360,8 @@ local function format_context_prompt(ctx)
   return table.concat(parts, '\n')
 end
 
-function M.fetch(task, redis_params, opts, callback)
+function M.fetch(task, redis_params, opts, callback, debug_module)
+  local N = debug_module or 'llm_context'
   opts = lua_util.override_defaults(DEFAULTS, opts or {})
   if not opts.enabled then
     callback(nil, nil, nil)
@@ -325,22 +372,28 @@ function M.fetch(task, redis_params, opts, callback)
     return
   end
 
-  local ident = compute_identity(task, opts)
+  local ident = compute_identity(task, opts, N)
   if not ident then
+    lua_util.debugm(N, task, 'no identity computed, skipping context')
     callback('no identity', nil, nil)
     return
   end
 
+  lua_util.debugm(N, task, 'fetching context for %s: %s',
+    tostring(ident.scope), tostring(ident.identity))
+
   local function on_get(err, data)
     if err then
-      rspamd_logger.errx(task, 'llm_context: get failed: %s', err)
+      rspamd_logger.errx(task, 'llm_context: get failed: %s', tostring(err))
       callback(err, nil, nil)
       return
     end
     local ctx
     if data then
+      lua_util.debugm(N, task, 'got context data from redis, parsing')
       ctx = ensure_defaults(select(1, parse_json(data)) or {})
     else
+      lua_util.debugm(N, task, 'no context data in redis, using empty')
       ctx = ensure_defaults({})
     end
 
@@ -348,12 +401,14 @@ function M.fetch(task, redis_params, opts, callback)
     local min_msgs = opts.min_messages or DEFAULTS.min_messages
     local msg_count = #(ctx.recent_messages or {})
     if msg_count < min_msgs then
-      lua_util.debugm('llm_context', task, 'context has only %s messages (min: %s), not injecting into prompt',
-        msg_count, min_msgs)
+      lua_util.debugm(N, task, 'context has only %s messages (min: %s), not injecting into prompt',
+        tostring(msg_count), tostring(min_msgs))
       callback(nil, ctx, nil) -- return ctx but no prompt snippet
       return
     end
 
+    lua_util.debugm(N, task, 'context warm-up OK: %s messages, generating snippet',
+      tostring(msg_count))
     local prompt_snippet = format_context_prompt(ctx)
     callback(nil, ctx, prompt_snippet)
   end
@@ -364,19 +419,22 @@ function M.fetch(task, redis_params, opts, callback)
   end
 end
 
-function M.update_after_classification(task, redis_params, opts, result, sel_part)
+function M.update_after_classification(task, redis_params, opts, result, sel_part, debug_module)
+  local N = debug_module or 'llm_context'
   opts = lua_util.override_defaults(DEFAULTS, opts or {})
   if not opts.enabled then return end
   if not redis_params then return end
 
-  local ident = compute_identity(task, opts)
+  local ident = compute_identity(task, opts, N)
   if not ident then return end
 
   local function on_get(err, data)
     if err then
-      rspamd_logger.errx(task, 'llm_context: get for update failed: %s', err)
+      rspamd_logger.errx(task, 'llm_context: get for update failed: %s', tostring(err))
       return
     end
+    lua_util.debugm(N, task, 'updating context for %s: %s',
+      tostring(ident.scope), tostring(ident.identity))
     local ctx = ensure_defaults(select(1, parse_json(data)) or {})
 
     local msg = build_message_summary(task, sel_part, opts)
@@ -413,9 +471,33 @@ function M.update_after_classification(task, redis_params, opts, result, sel_par
 
     local payload = encode_json(ctx)
     local ttl = to_seconds(opts.ttl)
+    local expire_at = now() + ttl
+
+    -- Log what we're storing in context
+    lua_util.debugm(N, task,
+      'storing context for %s: %s messages, labels=%s, top_senders=%s, flagged=%s, payload_size=%s bytes, expiring at %s',
+      tostring(ident.identity or '(none)'),
+      tostring(#ctx.recent_messages),
+      table.concat(ctx.last_spam_labels or {}, ','),
+      table.concat(ctx.top_senders or {}, ','),
+      table.concat(ctx.flagged_phrases or {}, ','),
+      tostring(#payload),
+      os.date('%Y-%m-%d %H:%M:%S', expire_at))
+
+    if msg then
+      lua_util.debugm(N, task,
+        'added message: from=%s, subject=%s, keywords=%s',
+        tostring(msg.from or '(none)'),
+        tostring(msg.subject or '(none)'),
+        table.concat(msg.keywords or {}, ','))
+    end
+
     local function on_set(set_err)
       if set_err then
-        rspamd_logger.errx(task, 'llm_context: set failed: %s', set_err)
+        rspamd_logger.errx(task, 'llm_context: set failed: %s', tostring(set_err))
+      else
+        lua_util.debugm(N, task, 'context saved to redis: key=%s, ttl=%s, expiring at %s',
+          tostring(ident.key), tostring(ttl), os.date('%Y-%m-%d %H:%M:%S', expire_at))
       end
     end
     local ok = lua_redis.redis_make_request(task, redis_params, ident.key, true, on_set, 'SETEX',