From 1a14c828c27c8296ddf099e39977fb68805d01b0 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Wed, 1 Oct 2025 10:49:41 +0100 Subject: [PATCH] [Feature] Add user/domain context support for LLM-based classification * Add llm_context.lua module for Redis-based conversation context * Context features: sliding window, top senders, keywords, flagged phrases * Use low-level word API (get_words('full')) with stop_word flags * Flexible gating via maps/selectors (enable_map/enable_expression) * Update context even when GPT condition not met (BAYES_SPAM/HAM) * Add min_messages warm-up threshold to prevent weak context injection * Configurable scope: user/domain/esld with TTL and sliding window --- .cursor/rules/commits-and-tags.mdc | 151 ++++++++++ lualib/llm_context.lua | 434 +++++++++++++++++++++++++++++ src/plugins/lua/gpt.lua | 271 +++++++++++++++--- 3 files changed, 812 insertions(+), 44 deletions(-) create mode 100644 .cursor/rules/commits-and-tags.mdc create mode 100644 lualib/llm_context.lua diff --git a/.cursor/rules/commits-and-tags.mdc b/.cursor/rules/commits-and-tags.mdc new file mode 100644 index 0000000000..35b6157be5 --- /dev/null +++ b/.cursor/rules/commits-and-tags.mdc @@ -0,0 +1,151 @@ +--- +description: "Rspamd: commit message format, tagging, and release procedures" +globs: ["**"] +alwaysApply: true +--- + +# Commit Message Format + +All commits in the Rspamd project follow a structured format with tags that indicate the type of change: + +## Commit Tags + +Use one of the following tags at the beginning of commit messages: + +- `[Feature]` - New features and capabilities +- `[Fix]` - Bug fixes and corrections +- `[CritFix]` - Critical bug fixes that need immediate attention +- `[Minor]` - Minor changes, tweaks, or version updates +- `[Project]` - Project-wide changes, refactoring, or infrastructure updates +- `[Rework]` - Major reworking of existing functionality +- `[Conf]` - Configuration changes or updates +- `[Test]` - Test additions or 
modifications +- `[Rules]` - Changes to spam detection rules + +## Commit Message Examples + +**Version updates:** +``` +[Minor] Update version of rspamd to X.Y.Z +``` + +**Single-line changes:** +``` +[Fix] Fix memory leak in dkim module +[Feature] Add support for encrypted maps +[Minor] Add missing cmath include +``` + +**Multi-line changes (for releases or complex changes):** +``` +Release X.Y.Z + +* [Feature] First feature description +* [Feature] Second feature description +* [Fix] First fix description +* [Fix] Second fix description +``` + +## GPG Signing Requirements + +**All commits and tags MUST be signed with GPG:** + +- Use `git commit -S` to sign commits +- Use `git tag -s <tagname>` to sign tags +- Verify signatures with `git log --show-signature` or `git tag -v <tagname>` + +## Release Process + +### 1. Update ChangeLog + +Add release notes to `ChangeLog` following the existing format: + +``` +X.Y.Z: DD MMM YYYY + * [Feature] Feature description + * [Fix] Fix description + * [Project] Project-level changes +``` + +Format rules: +- Date format: `DD MMM YYYY` (e.g., `30 Sep 2025`) +- Each entry starts with ` * [Tag]` (two spaces, asterisk, space, tag) +- Group entries by tag type +- Keep descriptions concise but informative + +### 2. Create Release Commit + +Create a commit with the full release notes: + +```bash +git add ChangeLog +git commit --no-verify -S -m "Release X.Y.Z + +* [Feature] Feature 1 +* [Feature] Feature 2 +* [Fix] Fix 1 +* [Fix] Fix 2 +..." +``` + +### 3. Create Release Tag + +Create an annotated, signed tag: + +```bash +git tag -s X.Y.Z -m "Rspamd X.Y.Z + +Brief release summary highlighting main features and fixes. + +Main features: +* Feature 1 +* Feature 2 + +Critical fixes: +* Fix 1 +* Fix 2 + +Additional context or notes about the release." +``` + +### 4. 
Update Version for Next Development Cycle + +After creating a release tag, update the version in `CMakeLists.txt`: + +```bash +# Edit CMakeLists.txt: increment RSPAMD_VERSION_PATCH +git add CMakeLists.txt +git commit --no-verify -S -m "[Minor] Update version of rspamd to X.Y.Z" +``` + +## Version Numbers + +Version numbers are defined in `CMakeLists.txt`: + +```cmake +set(RSPAMD_VERSION_MAJOR X) +set(RSPAMD_VERSION_MINOR Y) +set(RSPAMD_VERSION_PATCH Z) +``` + +- **MAJOR**: Incompatible API changes or major breaking changes +- **MINOR**: New features in a backward-compatible manner +- **PATCH**: Backward-compatible bug fixes + +## Pre-commit Hooks + +- If pre-commit hooks fail on unrelated issues, use `--no-verify` flag +- Only use `--no-verify` when necessary and ensure code quality manually +- Pre-commit hooks check: + - Trailing whitespace + - Line endings + - ClangFormat + - LuaCheck + +## General Guidelines + +- Write clear, descriptive commit messages +- One logical change per commit +- Reference issue numbers when applicable +- Keep commit history clean and meaningful +- Always sign commits and tags with GPG \ No newline at end of file diff --git a/lualib/llm_context.lua b/lualib/llm_context.lua new file mode 100644 index 0000000000..4e79c632a2 --- /dev/null +++ b/lualib/llm_context.lua @@ -0,0 +1,434 @@ +--[[ +Context management for LLM-based spam detection + +Provides: + - fetch(task, redis_params, opts, callback): load context JSON from Redis and format prompt snippet + - update_after_classification(task, redis_params, opts, result, sel_part): update context after LLM result + +Opts (all optional, safe defaults applied): + enabled: boolean + level: 'user' | 'domain' | 'esld' (scope for context key) + key_prefix: string (prefix before scope) + key_suffix: string (suffix after identity) + max_messages: number (sliding window size) + message_ttl: seconds + ttl: seconds (Redis key TTL) + top_senders: number (how many to keep in top_senders) + 
summary_max_chars: number (truncate stored text) + flagged_phrases: array of strings (case-insensitive match) + last_labels_count: number +]] + +local M = {} + +local lua_redis = require "lua_redis" +local lua_util = require "lua_util" +local rspamd_logger = require "rspamd_logger" +local ucl = require "ucl" +local rspamd_util = require "rspamd_util" +local llm_common = require "llm_common" + +local EMPTY = {} + +local DEFAULTS = { + enabled = false, + level = 'user', + key_prefix = 'user', + key_suffix = 'mail_context', + max_messages = 40, + min_messages = 5, -- minimum messages in context before injecting into prompt + message_ttl = 14 * 24 * 3600, + ttl = 30 * 24 * 3600, + top_senders = 5, + summary_max_chars = 512, + flagged_phrases = { + 'reset your password', + 'click here to verify', + 'confirm your account', + 'urgent invoice', + 'wire transfer', + }, + last_labels_count = 10, +} + +local function to_seconds(v) + if type(v) == 'number' then return v end + return tonumber(v) or 0 +end + +local function get_principal_recipient(task) + return task:get_principal_recipient() +end + +local function get_domain_from_addr(addr) + if not addr then return nil end + return string.match(addr, '.*@(.+)') +end + +local function compute_identity(task, opts) + local scope = opts.level or DEFAULTS.level + local identity + if scope == 'user' then + identity = task:get_user() or get_principal_recipient(task) + if not identity then + local from = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr'] + identity = from + end + elseif scope == 'domain' then + local rcpt = get_principal_recipient(task) + identity = get_domain_from_addr(rcpt) + if not identity then + identity = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain'] + end + elseif scope == 'esld' then + local rcpt = get_principal_recipient(task) + local d = get_domain_from_addr(rcpt) + if d then + identity = rspamd_util.get_tld(d) + end + if not identity then + local fd = ((task:get_from('smtp') or EMPTY)[1] or 
EMPTY)['domain'] + if fd then identity = rspamd_util.get_tld(fd) end + end + else + scope = 'user' + identity = task:get_user() or get_principal_recipient(task) + end + + if not identity or identity == '' then + return nil + end + + local key_prefix = opts.key_prefix or DEFAULTS.key_prefix + local key_suffix = opts.key_suffix or DEFAULTS.key_suffix + local key = string.format('%s:%s:%s', key_prefix, identity, key_suffix) + + return { + scope = scope, + identity = identity, + key = key, + } +end + +local function parse_json(str) + if not str or str == '' then return nil end + local parser = ucl.parser() + local ok, err = parser:parse_string(str) + if not ok then return nil, err end + return parser:get_object() +end + +local function encode_json(obj) + return ucl.to_format(obj, 'json-compact', true) +end + +local function now() + return os.time() +end + +local function truncate_text(txt, limit) + if not txt then return '' end + if #txt <= limit then return txt end + return string.sub(txt, 1, limit) +end + +local function has_flag(flags, flag_name) + if type(flags) ~= 'table' then return false end + for _, f in ipairs(flags) do + if f == flag_name then return true end + end + return false +end + +local function extract_keywords(text_part, limit) + if not text_part then return {} end + local words = text_part:get_words('full') + if not words or #words == 0 then return {} end + + local counts = {} + for _, w in ipairs(words) do + local norm_word = w[2] or '' -- normalized + local flags = w[4] or {} + -- Skip stop words, too short, or non-text + if not has_flag(flags, 'stop_word') and #norm_word > 2 and has_flag(flags, 'text') then + counts[norm_word] = (counts[norm_word] or 0) + 1 + end + end + + local arr = {} + for word, cnt in pairs(counts) do + table.insert(arr, { w = word, c = cnt }) + end + table.sort(arr, function(a, b) + if a.c == b.c then return a.w < b.w end + return a.c > b.c + end) + + local res = {} + for i = 1, math.min(limit or 12, #arr) do + 
table.insert(res, arr[i].w) + end + return res +end + +local function safe_array(arr) + if type(arr) ~= 'table' then return {} end + return arr +end + +local function build_message_summary(task, sel_part, opts) + local model_cfg = { max_tokens = 256 } + local content_tbl + if sel_part then + local itbl = llm_common.build_llm_input(task, { max_tokens = model_cfg.max_tokens }) + content_tbl = itbl + else + content_tbl = llm_common.build_llm_input(task, { max_tokens = model_cfg.max_tokens }) + end + if type(content_tbl) ~= 'table' then + return nil + end + local txt = content_tbl.text or '' + local summary_max = opts.summary_max_chars or DEFAULTS.summary_max_chars + local msg = { + from = content_tbl.from or ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr'], + subject = content_tbl.subject or '', + ts = now(), + keywords = extract_keywords(sel_part, 12), + } + if txt and #txt > 0 then + msg.text = truncate_text(txt, summary_max) + end + return msg +end + +local function trim_messages(recent_messages, max_messages, min_ts) + local res = {} + for _, m in ipairs(recent_messages) do + if not min_ts or (m.ts and m.ts >= min_ts) then + table.insert(res, m) + end + end + table.sort(res, function(a, b) + local ta = a.ts or 0 + local tb = b.ts or 0 + return ta > tb + end) + while #res > max_messages do + table.remove(res) + end + return res +end + +local function recompute_top_senders(sender_counts, limit_n) + local arr = {} + for s, c in pairs(sender_counts or {}) do + table.insert(arr, { s = s, c = c }) + end + table.sort(arr, function(a, b) + if a.c == b.c then return a.s < b.s end + return a.c > b.c + end) + local res = {} + for i = 1, math.min(limit_n, #arr) do + table.insert(res, arr[i].s) + end + return res +end + +local function ensure_defaults(ctx) + if type(ctx) ~= 'table' then ctx = {} end + ctx.recent_messages = safe_array(ctx.recent_messages) + ctx.top_senders = safe_array(ctx.top_senders) + ctx.flagged_phrases = safe_array(ctx.flagged_phrases) + 
ctx.last_spam_labels = safe_array(ctx.last_spam_labels) + ctx.sender_counts = ctx.sender_counts or {} + return ctx +end + +local function contains_ci(haystack, needle) + if not haystack or not needle then return false end + return string.find(string.lower(haystack), string.lower(needle), 1, true) ~= nil +end + +local function update_flagged_phrases(ctx, text_part, opts) + local phrases = opts.flagged_phrases or DEFAULTS.flagged_phrases + if not text_part then return end + local words = text_part:get_words('norm') + if not words or #words == 0 then return end + local text_lower = table.concat(words, ' ') + for _, p in ipairs(phrases) do + if contains_ci(text_lower, p) then + local present = false + for _, e in ipairs(ctx.flagged_phrases) do + if string.lower(e) == string.lower(p) then + present = true + break + end + end + if not present then + table.insert(ctx.flagged_phrases, p) + end + end + end +end + +local function to_bullets_recent(recent_messages, limit_n) + local lines = {} + local n = math.min(limit_n, #recent_messages) + for i = 1, n do + local m = recent_messages[i] + local from = m.from or m.sender or '' + local subj = m.subject or '' + table.insert(lines, string.format('- %s: %s', from, subj)) + end + return table.concat(lines, '\n') +end + +local function join_list(arr) + if not arr or #arr == 0 then return '' end + return table.concat(arr, ', ') +end + +local function format_context_prompt(ctx) + local bullets = to_bullets_recent(ctx.recent_messages or {}, 5) + local top_senders = join_list(ctx.top_senders or {}) + local flagged = join_list(ctx.flagged_phrases or {}) + local spam_types = join_list(ctx.last_spam_labels or {}) + + local parts = {} + table.insert(parts, 'User recent correspondence summary:') + if bullets ~= '' then + table.insert(parts, bullets) + else + table.insert(parts, '- (no recent messages)') + end + table.insert(parts, string.format('Top senders in mailbox: %s', top_senders)) + if flagged ~= '' then + table.insert(parts, 
string.format('Recently flagged suspicious phrases: %s', flagged)) + end + if spam_types ~= '' then + table.insert(parts, string.format('Last detected spam types: %s', spam_types)) + end + + return table.concat(parts, '\n') +end + +function M.fetch(task, redis_params, opts, callback) + opts = lua_util.override_defaults(DEFAULTS, opts or {}) + if not opts.enabled then + callback(nil, nil, nil) + return + end + if not redis_params then + callback('no redis', nil, nil) + return + end + + local ident = compute_identity(task, opts) + if not ident then + callback('no identity', nil, nil) + return + end + + local function on_get(err, data) + if err then + rspamd_logger.errx(task, 'llm_context: get failed: %s', err) + callback(err, nil, nil) + return + end + local ctx + if data then + ctx = ensure_defaults(select(1, parse_json(data)) or {}) + else + ctx = ensure_defaults({}) + end + + -- Check if context has enough messages for warm-up + local min_msgs = opts.min_messages or DEFAULTS.min_messages + local msg_count = #(ctx.recent_messages or {}) + if msg_count < min_msgs then + lua_util.debugm('llm_context', task, 'context has only %s messages (min: %s), not injecting into prompt', + msg_count, min_msgs) + callback(nil, ctx, nil) -- return ctx but no prompt snippet + return + end + + local prompt_snippet = format_context_prompt(ctx) + callback(nil, ctx, prompt_snippet) + end + + local ok = lua_redis.redis_make_request(task, redis_params, ident.key, false, on_get, 'GET', { ident.key }) + if not ok then + callback('request not scheduled', nil, nil) + end +end + +function M.update_after_classification(task, redis_params, opts, result, sel_part) + opts = lua_util.override_defaults(DEFAULTS, opts or {}) + if not opts.enabled then return end + if not redis_params then return end + + local ident = compute_identity(task, opts) + if not ident then return end + + local function on_get(err, data) + if err then + rspamd_logger.errx(task, 'llm_context: get for update failed: %s', err) + 
return + end + local ctx = ensure_defaults(select(1, parse_json(data)) or {}) + + local msg = build_message_summary(task, sel_part, opts) + if msg then + table.insert(ctx.recent_messages, 1, msg) + local sender = msg.from or '' + if sender ~= '' then + ctx.sender_counts[sender] = (ctx.sender_counts[sender] or 0) + 1 + end + update_flagged_phrases(ctx, sel_part, opts) + end + + local min_ts = now() - to_seconds(opts.message_ttl) + ctx.recent_messages = trim_messages(ctx.recent_messages, opts.max_messages, min_ts) + ctx.top_senders = recompute_top_senders(ctx.sender_counts, opts.top_senders) + + local labels = {} + if result then + if result.categories and type(result.categories) == 'table' then + for _, c in ipairs(result.categories) do table.insert(labels, tostring(c)) end + end + if result.probability then + if result.probability > 0.5 then + table.insert(labels, 'spam') + else + table.insert(labels, 'ham') + end + end + end + for _, l in ipairs(labels) do table.insert(ctx.last_spam_labels, 1, l) end + while #ctx.last_spam_labels > opts.last_labels_count do table.remove(ctx.last_spam_labels) end + + ctx.updated_at = now() + + local payload = encode_json(ctx) + local ttl = to_seconds(opts.ttl) + local function on_set(set_err) + if set_err then + rspamd_logger.errx(task, 'llm_context: set failed: %s', set_err) + end + end + local ok = lua_redis.redis_make_request(task, redis_params, ident.key, true, on_set, 'SETEX', + { ident.key, tostring(ttl), payload }) + if not ok then + rspamd_logger.errx(task, 'llm_context: set request was not scheduled') + end + end + + local ok = lua_redis.redis_make_request(task, redis_params, ident.key, false, on_get, 'GET', { ident.key }) + if not ok then + rspamd_logger.errx(task, 'llm_context: initial get request was not scheduled') + end +end + +return M diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua index 0989657995..394923ed7c 100644 --- a/src/plugins/lua/gpt.lua +++ b/src/plugins/lua/gpt.lua @@ -62,6 +62,26 @@ if 
confighelp then reason_header = "X-GPT-Reason"; # Use JSON format for response json = false; + + # Optional user/domain context in Redis + context = { + enabled = true; # fetch and inject user/domain conversation context + # scope level for identity: user | domain | esld + level = "user"; + # redis key structure: <key_prefix>:<identity>:<key_suffix> + key_prefix = "user"; + key_suffix = "mail_context"; + # sliding window and TTLs + max_messages = 40; # keep up to N compact message summaries + min_messages = 5; # warm-up: inject context only after N messages collected + message_ttl = 14d; # forget messages older than this when recomputing + ttl = 30d; # Redis key TTL + top_senders = 5; # track top senders + summary_max_chars = 512; # compress body to this size for storage + flagged_phrases = ["reset your password", "click here to verify"]; # optional list + last_labels_count = 10; # keep last N labels + as_system = true; # place context snippet as additional system message + }; } ]]) return @@ -76,6 +96,10 @@ local lua_redis = require "lua_redis" local ucl = require "ucl" -- local fun = require "fun" -- no longer needed after llm_common usage local lua_cache = require "lua_cache" +local llm_context = require "llm_context" +local lua_maps_expressions = require "lua_maps_expressions" +local lua_maps = require "lua_maps" +local lua_selectors = require "lua_selectors" -- Exclude checks if one of those is found local default_symbols_to_except = { @@ -148,9 +172,90 @@ local settings = { json = false, extra_symbols = nil, cache_prefix = REDIS_PREFIX, + -- user/domain context options (nested table forwarded to llm_context) + context = { + enabled = false, + level = 'user', -- 'user' | 'domain' | 'esld' + key_prefix = 'user', + key_suffix = 'mail_context', + max_messages = 40, + min_messages = 5, -- warm-up threshold: minimum messages before injecting context into prompt + message_ttl = 1209600, -- 14d + ttl = 2592000, -- 30d + top_senders = 5, + summary_max_chars = 512, + flagged_phrases = { 'reset your 
password', 'click here to verify' }, + last_labels_count = 10, + as_system = true, -- inject context snippet as system message; false => user message + -- Optional gating using selectors and maps to enable/disable context dynamically + -- One can use either a simple enable_map or a full maps expression + -- Example enable_map: + -- enable_map = { selector = "esld_principal_recipient_domain", map = "/etc/rspamd/context-enabled-domains.map", type = "set" } + enable_map = nil, + -- Example enable_expression: + -- enable_expression = { + -- rules = { + -- dom = { selector = "esld_principal_recipient_domain", map = "/etc/rspamd/context-enabled-domains.map" }, + -- user = { selector = "user", map = "/etc/rspamd/context-enabled-users.map" }, + -- }, + -- expression = "dom | user" + -- } + enable_expression = nil, + -- Optional negative gating + disable_expression = nil, + }, } local redis_params local cache_context +local compiled_context_gating = { + enable_expr = nil, + disable_expr = nil, + enable_map = nil, -- { selector_fn, map } +} + +local function is_context_enabled_for_task(task) + local ctx = settings.context + if not ctx then return false end + + local enabled = ctx.enabled or false + + -- Positive gating via expression + if compiled_context_gating.enable_expr then + local res = compiled_context_gating.enable_expr:process(task) + if res then + enabled = true + end + end + + -- Positive gating via simple map + if compiled_context_gating.enable_map then + local vals = compiled_context_gating.enable_map.selector_fn(task) + local matched = false + if type(vals) == 'table' then + for _, v in ipairs(vals) do + if compiled_context_gating.enable_map.map:get_key(v) then + matched = true + break + end + end + elseif vals then + matched = compiled_context_gating.enable_map.map:get_key(vals) and true or false + end + if matched then + enabled = true + end + end + + -- Negative gating + if enabled and compiled_context_gating.disable_expr then + local res = 
compiled_context_gating.disable_expr:process(task) + if res then + enabled = false + end + end + + return enabled +end local function default_condition(task) -- Check result @@ -561,6 +666,11 @@ local function insert_results(task, result, sel_part) if cache_context then lua_cache.cache_set(task, redis_cache_key(sel_part), result, cache_context) end + + -- Update long-term user/domain context after classification + if redis_params and settings.context then + llm_context.update_after_classification(task, redis_params, settings.context, result, sel_part) + end end local function check_consensus_and_insert_results(task, results, sel_part) @@ -595,20 +705,22 @@ local function check_consensus_and_insert_results(task, results, sel_part) end lua_util.shuffle(reasons) - local reason = reasons[1] or nil + local reason_obj = reasons[1] + local reason_text = reason_obj and reason_obj.reason or nil + local reason_categories = reason_obj and reason_obj.categories or nil if nspam > nham and max_spam_prob > 0.75 then insert_results(task, { probability = max_spam_prob, - reason = reason.reason, - categories = reason.categories, + reason = reason_text, + categories = reason_categories, }, sel_part) elseif nham > nspam and max_ham_prob < 0.25 then insert_results(task, { probability = max_ham_prob, - reason = reason.reason, - categories = reason.categories, + reason = reason_text, + categories = reason_categories, }, sel_part) else @@ -619,15 +731,15 @@ end -- get_meta_llm_content moved to llm_common -local function check_llm_uncached(task, content, sel_part) - return settings.specific_check(task, content, sel_part) +local function check_llm_uncached(task, content, sel_part, context_snippet) + return settings.specific_check(task, content, sel_part, context_snippet) end -local function check_llm_cached(task, content, sel_part) +local function check_llm_cached(task, content, sel_part, context_snippet) local cache_key = redis_cache_key(sel_part) lua_cache.cache_get(task, cache_key, 
cache_context, settings.timeout * 1.5, function() - check_llm_uncached(task, content, sel_part) + check_llm_uncached(task, content, sel_part, context_snippet) end, function(_, err, data) if err then rspamd_logger.errx(task, 'cannot get cache: %s', err) @@ -638,12 +750,12 @@ local function check_llm_cached(task, content, sel_part) rspamd_logger.infox(task, 'found cached response %s', cache_key) insert_results(task, data, sel_part) else - check_llm_uncached(task, content, sel_part) + check_llm_uncached(task, content, sel_part, context_snippet) end end) end -local function openai_check(task, content, sel_part) +local function openai_check(task, content, sel_part, context_snippet) lua_util.debugm(N, task, "sending content to gpt: %s", content) local upstream @@ -684,7 +796,7 @@ local function openai_check(task, content, sel_part) end end - -- Build messages exactly as in the original code if structured table provided + -- Build messages with optional user/domain context local user_messages if type(content) == 'table' then local subject_line = 'Subject: ' .. 
(content.subject or '') @@ -700,22 +812,25 @@ local function openai_check(task, content, sel_part) } end + local sys_messages = { + { role = 'system', content = settings.prompt } + } + if context_snippet and settings.context and settings.context.as_system ~= false then + table.insert(sys_messages, { role = 'system', content = context_snippet }) + elseif context_snippet and settings.context and settings.context.as_system == false then + table.insert(user_messages, 1, { role = 'user', content = context_snippet }) + end + local body_base = { stream = false, - messages = { - { - role = 'system', - content = settings.prompt - }, - lua_util.unpack(user_messages) - } + messages = {} } + for _, m in ipairs(sys_messages) do table.insert(body_base.messages, m) end + for _, m in ipairs(user_messages) do table.insert(body_base.messages, m) end - if type(settings.model) == 'string' then - settings.model = { settings.model } - end + local models_list = type(settings.model) == 'string' and { settings.model } or settings.model - for idx, model in ipairs(settings.model) do + for idx, model in ipairs(models_list) do results[idx] = { success = false, checked = false @@ -766,7 +881,7 @@ local function openai_check(task, content, sel_part) end end -local function ollama_check(task, content, sel_part) +local function ollama_check(task, content, sel_part, context_snippet) lua_util.debugm(N, task, "sending content to gpt: %s", content) local upstream @@ -821,26 +936,25 @@ local function ollama_check(task, content, sel_part) } end - if type(settings.model) == 'string' then - settings.model = { settings.model } + local models_list = type(settings.model) == 'string' and { settings.model } or settings.model + + local sys_messages = { + { role = 'system', content = settings.prompt } + } + if context_snippet and settings.context and settings.context.as_system ~= false then + table.insert(sys_messages, { role = 'system', content = context_snippet }) + elseif context_snippet and settings.context 
and settings.context.as_system == false then + table.insert(user_messages, 1, { role = 'user', content = context_snippet }) end local body_base = { stream = false, - model = settings.model, - -- should not in body_base - -- max_tokens = settings.max_tokens, - -- temperature = settings.temperature, - messages = { - { - role = 'system', - content = settings.prompt - }, - table.unpack(user_messages) - } + messages = {} } + for _, m in ipairs(sys_messages) do table.insert(body_base.messages, m) end + for _, m in ipairs(user_messages) do table.insert(body_base.messages, m) end - for idx, model in ipairs(settings.model) do + for idx, model in ipairs(models_list) do results[idx] = { success = false, checked = false @@ -891,6 +1005,31 @@ end local function gpt_check(task) local ret, content, sel_part = settings.condition(task) + -- Always update context if enabled, even when condition is not met + local context_enabled = redis_params and settings.context and is_context_enabled_for_task(task) + if context_enabled and not ret then + -- Condition not met (e.g. BAYES_SPAM, passthrough, etc.) 
+ -- Update context without LLM call; infer result from task metrics + if not sel_part then + -- Try to get text part for context update + sel_part = lua_mime.get_displayed_text_part(task) + end + if sel_part then + local result = task:get_metric_result() + local inferred_result = nil + if result then + if result.action == 'reject' or (result.score and result.score > 10) then + inferred_result = { probability = 0.9, reason = 'rejected by filters', categories = {} } + elseif result.action == 'no action' and result.score and result.score < 0 then + inferred_result = { probability = 0.1, reason = 'ham by filters', categories = {} } + end + end + llm_context.update_after_classification(task, redis_params, settings.context, inferred_result, sel_part) + end + rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s; context updated", content) + return + end + if not ret then rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content) return @@ -901,11 +1040,21 @@ local function gpt_check(task) return end - if sel_part then - -- Check digest - check_llm_cached(task, content, sel_part) + local function proceed(context_snippet) + if sel_part then + -- Check digest + check_llm_cached(task, content, sel_part, context_snippet) + else + check_llm_uncached(task, content, nil, context_snippet) + end + end + + if context_enabled then + llm_context.fetch(task, redis_params, settings.context, function(_, _, snippet) + proceed(snippet) + end) else - check_llm_uncached(task, content) + proceed(nil) end end @@ -1018,8 +1167,8 @@ if opts then "Output ONLY 3 lines:\n" .. "1. Numeric score (0.00-1.00)\n" .. "2. One-sentence reason citing whether it is spam, the strongest red flag, or why it is ham\n" .. - "3. Empty line or mention ONLY the primary concern category if found from the list: " .. - table.concat(lua_util.keys(categories_map), ', ') + "3. Empty line or mention ONLY the primary concern category if found from the list: " .. 
+ table.concat(lua_util.keys(categories_map), ', ') else settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " .. "FROM and url domains. Evaluate spam probability (0-1). " .. @@ -1028,4 +1177,38 @@ if opts then "2. One-sentence reason citing whether it is spam, the strongest red flag, or why it is ham\n" end end + + -- Compile optional context gating + if settings.context then + local ctx = settings.context + if ctx.enable_expression then + local expr = lua_maps_expressions.create(rspamd_config, ctx.enable_expression, N .. "/context-enable") + if expr then + compiled_context_gating.enable_expr = expr + else + rspamd_logger.warnx(rspamd_config, 'failed to compile context enable_expression') + end + end + if ctx.disable_expression then + local expr = lua_maps_expressions.create(rspamd_config, ctx.disable_expression, N .. "/context-disable") + if expr then + compiled_context_gating.disable_expr = expr + else + rspamd_logger.warnx(rspamd_config, 'failed to compile context disable_expression') + end + end + if ctx.enable_map and type(ctx.enable_map) == 'table' and ctx.enable_map.selector and ctx.enable_map.map then + local sel = lua_selectors.create_selector_closure(rspamd_config, ctx.enable_map.selector) + local map = lua_maps.map_add_from_ucl(ctx.enable_map.map, ctx.enable_map.type or 'set', + 'GPT context enable map') + if sel and map then + compiled_context_gating.enable_map = { + selector_fn = sel, + map = map, + } + else + rspamd_logger.warnx(rspamd_config, 'failed to compile context enable_map: selector or map invalid') + end + end + end end -- 2.47.3