From 38c48e5a62645348ceb198d3eeb7a659866e76c4 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Thu, 2 Oct 2025 14:30:20 +0100 Subject: [PATCH] [Feature] Add bidirectional context support for LLM * Unify context for incoming and outgoing mail * Same identity used for authenticated/local sender and recipient * Follows replies module pattern for direction detection * Make llm_context.lua module-agnostic with debug_module parameter * Improve userdata handling (use :sub instead of string.sub) * Add nil-safety to all debug logging calls * Add cache expiration timestamps to context logs --- lualib/llm_context.lua | 162 +++++++++++++++++++++++++++++++---------- 1 file changed, 122 insertions(+), 40 deletions(-) diff --git a/lualib/llm_context.lua b/lualib/llm_context.lua index 4e79c632a2..c01850b1a5 100644 --- a/lualib/llm_context.lua +++ b/lualib/llm_context.lua @@ -2,8 +2,8 @@ Context management for LLM-based spam detection Provides: - - fetch(task, redis_params, opts, callback): load context JSON from Redis and format prompt snippet - - update_after_classification(task, redis_params, opts, result, sel_part): update context after LLM result + - fetch(task, redis_params, opts, callback, debug_module): load context JSON from Redis and format prompt snippet + - update_after_classification(task, redis_params, opts, result, sel_part, debug_module): update context after LLM result Opts (all optional, safe defaults applied): enabled: boolean @@ -17,6 +17,8 @@ Opts (all optional, safe defaults applied): summary_max_chars: number (truncate stored text) flagged_phrases: array of strings (case-insensitive match) last_labels_count: number + +debug_module: optional string, module name for debug logging (default: 'llm_context') ]] local M = {} @@ -56,49 +58,86 @@ local function to_seconds(v) return tonumber(v) or 0 end -local function get_principal_recipient(task) - return task:get_principal_recipient() -end - local function get_domain_from_addr(addr) if not addr then return nil end return string.match(addr, '.*@(.+)') end -local function compute_identity(task, opts) - local scope = opts.level or DEFAULTS.level +-- Determine our user/domain - same identity for both incoming and outgoing mail +local function get_our_identity(task, scope) + -- For outgoing mail: authenticated user or sender from local network + -- For incoming mail: principal recipient + local user = task:get_user() + local ip = task:get_ip() + local is_outgoing = user or (ip and ip:is_local()) + local identity if scope == 'user' then - identity = task:get_user() or get_principal_recipient(task) - if not identity then - local from = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr'] - identity = from + if is_outgoing then + -- Outgoing: use sender (authenticated user or from address) + identity = user or task:get_reply_sender() + if not identity then + local from = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr'] + identity = from + end + else + -- Incoming: use recipient + identity = task:get_principal_recipient() end elseif scope == 'domain' then - local rcpt = get_principal_recipient(task) - identity = get_domain_from_addr(rcpt) - if not identity then - identity = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain'] + if is_outgoing then + -- Outgoing: domain of sender + if user then + identity = get_domain_from_addr(user) + end + if not identity then + identity = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain'] + end + else + -- Incoming: domain of recipient + local rcpt = task:get_principal_recipient() + identity = get_domain_from_addr(rcpt) end elseif scope == 'esld' then - local rcpt = get_principal_recipient(task) - local d = get_domain_from_addr(rcpt) - if d then - identity = rspamd_util.get_tld(d) - end - if not identity then - local fd = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain'] - if fd then identity = rspamd_util.get_tld(fd) end + if is_outgoing then + -- Outgoing: eSLD of sender domain + local d + if user then + d = get_domain_from_addr(user) + end + if not d then + d = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain'] + end + if d then identity = rspamd_util.get_tld(d) end + else + -- Incoming: eSLD of recipient domain + local rcpt = task:get_principal_recipient() + local d = get_domain_from_addr(rcpt) + if d then + identity = rspamd_util.get_tld(d) + end end - else - scope = 'user' - identity = task:get_user() or get_principal_recipient(task) end + return identity +end + +local function compute_identity(task, opts, debug_module) + local N = debug_module or 'llm_context' + local scope = opts.level or DEFAULTS.level + local identity = get_our_identity(task, scope) + if not identity or identity == '' then return nil end + -- Log direction for debugging + local user = task:get_user() + local ip = task:get_ip() + local is_outgoing = user or (ip and ip:is_local()) + lua_util.debugm(N, task, 'computed identity for %s (%s): %s', + scope, is_outgoing and 'outgoing' or 'incoming', tostring(identity)) + local key_prefix = opts.key_prefix or DEFAULTS.key_prefix local key_suffix = opts.key_suffix or DEFAULTS.key_suffix local key = string.format('%s:%s:%s', key_prefix, identity, key_suffix) @@ -110,10 +149,17 @@ local function compute_identity(task, opts) } end -local function parse_json(str) - if not str or str == '' then return nil end +local function parse_json(data) + if not data then return nil end + -- Redis can return userdata nil or empty string + if type(data) == 'userdata' then + data = tostring(data) + end + if type(data) ~= 'string' or data == '' then + return nil + end local parser = ucl.parser() - local ok, err = parser:parse_string(str) + local ok, err = parser:parse_text(data) if not ok then return nil, err end return parser:get_object() end @@ -129,7 +175,7 @@ end local function truncate_text(txt, limit) if not txt then return '' end if #txt <= limit then return txt end - return string.sub(txt, 1, limit) + return txt:sub(1, limit) end local function has_flag(flags, flag_name) @@ -314,7 +360,8 @@ local function format_context_prompt(ctx) return table.concat(parts, '\n') end -function M.fetch(task, redis_params, opts, callback) +function M.fetch(task, redis_params, opts, callback, debug_module) + local N = debug_module or 'llm_context' opts = lua_util.override_defaults(DEFAULTS, opts or {}) if not opts.enabled then callback(nil, nil, nil) @@ -325,22 +372,28 @@ function M.fetch(task, redis_params, opts, callback) return end - local ident = compute_identity(task, opts) + local ident = compute_identity(task, opts, N) if not ident then + lua_util.debugm(N, task, 'no identity computed, skipping context') callback('no identity', nil, nil) return end + lua_util.debugm(N, task, 'fetching context for %s: %s', + tostring(ident.scope), tostring(ident.identity)) + local function on_get(err, data) if err then - rspamd_logger.errx(task, 'llm_context: get failed: %s', err) + rspamd_logger.errx(task, 'llm_context: get failed: %s', tostring(err)) callback(err, nil, nil) return end local ctx if data then + lua_util.debugm(N, task, 'got context data from redis, parsing') ctx = ensure_defaults(select(1, parse_json(data)) or {}) else + lua_util.debugm(N, task, 'no context data in redis, using empty') ctx = ensure_defaults({}) end @@ -348,12 +401,14 @@ function M.fetch(task, redis_params, opts, callback) local min_msgs = opts.min_messages or DEFAULTS.min_messages local msg_count = #(ctx.recent_messages or {}) if msg_count < min_msgs then - lua_util.debugm('llm_context', task, 'context has only %s messages (min: %s), not injecting into prompt', - msg_count, min_msgs) + lua_util.debugm(N, task, 'context has only %s messages (min: %s), not injecting into prompt', + tostring(msg_count), tostring(min_msgs)) callback(nil, ctx, nil) -- return ctx but no prompt snippet return end + lua_util.debugm(N, task, 'context warm-up OK: %s messages, generating snippet', + tostring(msg_count)) local prompt_snippet = format_context_prompt(ctx) callback(nil, ctx, prompt_snippet) end @@ -364,19 +419,22 @@ function M.fetch(task, redis_params, opts, callback) end end -function M.update_after_classification(task, redis_params, opts, result, sel_part) +function M.update_after_classification(task, redis_params, opts, result, sel_part, debug_module) + local N = debug_module or 'llm_context' opts = lua_util.override_defaults(DEFAULTS, opts or {}) if not opts.enabled then return end if not redis_params then return end - local ident = compute_identity(task, opts) + local ident = compute_identity(task, opts, N) if not ident then return end local function on_get(err, data) if err then - rspamd_logger.errx(task, 'llm_context: get for update failed: %s', err) + rspamd_logger.errx(task, 'llm_context: get for update failed: %s', tostring(err)) return end + lua_util.debugm(N, task, 'updating context for %s: %s', + tostring(ident.scope), tostring(ident.identity)) local ctx = ensure_defaults(select(1, parse_json(data)) or {}) local msg = build_message_summary(task, sel_part, opts) @@ -413,9 +471,33 @@ function M.update_after_classification(task, redis_params, opts, result, sel_par local payload = encode_json(ctx) local ttl = to_seconds(opts.ttl) + local expire_at = now() + ttl + + -- Log what we're storing in context + lua_util.debugm(N, task, + 'storing context for %s: %s messages, labels=%s, top_senders=%s, flagged=%s, payload_size=%s bytes, expiring at %s', + tostring(ident.identity or '(none)'), + tostring(#ctx.recent_messages), + table.concat(ctx.last_spam_labels or {}, ','), + table.concat(ctx.top_senders or {}, ','), + table.concat(ctx.flagged_phrases or {}, ','), + tostring(#payload), + os.date('%Y-%m-%d %H:%M:%S', expire_at)) + + if msg then + lua_util.debugm(N, task, + 'added message: from=%s, subject=%s, keywords=%s', + tostring(msg.from or '(none)'), + tostring(msg.subject or '(none)'), + table.concat(msg.keywords or {}, ',')) + end + local function on_set(set_err) if set_err then - rspamd_logger.errx(task, 'llm_context: set failed: %s', set_err) + rspamd_logger.errx(task, 'llm_context: set failed: %s', tostring(set_err)) + else + lua_util.debugm(N, task, 'context saved to redis: key=%s, ttl=%s, expiring at %s', + tostring(ident.key), tostring(ttl), os.date('%Y-%m-%d %H:%M:%S', expire_at)) end end local ok = lua_redis.redis_make_request(task, redis_params, ident.key, true, on_set, 'SETEX', -- 2.47.3