From: Vsevolod Stakhov
Date: Thu, 28 Aug 2025 15:32:16 +0000 (+0100)
Subject: [Project] Add tests for LLM provider, fix various issues with metatokens
X-Git-Tag: 3.13.0~22^2~1
X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=69db7c992f5eade0c25ced44067410ed57a6a4b9;p=thirdparty%2Frspamd.git

[Project] Add tests for LLM provider, fix various issues with metatokens
---
diff --git a/lualib/lua_cache.lua b/lualib/lua_cache.lua
index c87a9dc78d..5fb1fbbe79 100644
--- a/lualib/lua_cache.lua
+++ b/lualib/lua_cache.lua
@@ -12,7 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-]]--
+]] --
 
 --[[[
 -- @module lua_cache
@@ -82,10 +82,10 @@ local exports = {}
 -- Default options
 local default_opts = {
   cache_prefix = "rspamd_cache",
-  cache_ttl = 3600, -- 1 hour
-  cache_probes = 5, -- Number of times to check a pending key
-  cache_format = "json", -- Serialization format
-  cache_hash_len = 16, -- Number of hex symbols to use for hashed keys
+  cache_ttl = 3600,         -- 1 hour
+  cache_probes = 5,         -- Number of times to check a pending key
+  cache_format = "json",    -- Serialization format
+  cache_hash_len = 16,      -- Number of hex symbols to use for hashed keys
   cache_use_hashing = false -- Whether to hash keys by default
 }
 
@@ -110,8 +110,9 @@ local function get_cache_key(raw_key, cache_context, force_hashing)
   end
 
   if should_hash then
-    lua_util.debugm(N, rspamd_config, "hashing key '%s' with hash length %s",
-      raw_key, cache_context.opts.cache_hash_len)
+    local raw_len = (type(raw_key) == 'string') and #raw_key or -1
+    lua_util.debugm(N, rspamd_config, "hashing cache key (len=%s) with hash length %s",
+      raw_len, cache_context.opts.cache_hash_len)
     return hash_key(raw_key, cache_context.opts.cache_hash_len)
   else
     return raw_key
@@ -133,8 +134,8 @@ local function create_cache_context(redis_params, opts, module_name)
 
   -- Register Redis prefix
   lua_redis.register_prefix(cache_context.opts.cache_prefix,
-      "caching",
-      "Cache API prefix")
+    "caching",
+    "Cache API prefix")
 
   lua_util.debugm(N, rspamd_config, "registered redis prefix: %s", cache_context.opts.cache_prefix)
 
@@ -233,7 +234,7 @@ local function create_pending_marker(timeout, cache_context)
   }
 
   lua_util.debugm(cache_context.N, rspamd_config, "creating PENDING marker for host %s, timeout %s",
-      hostname, timeout)
+    hostname, timeout)
 
   return "PENDING:" .. encode_data(pending_data, cache_context)
 end
@@ -245,8 +246,8 @@ local function cache_get(task, key, cache_context, timeout, callback_uncached, c
     return false
   end
 
-  local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, false)
-  lua_util.debugm(cache_context.N, task, "cache lookup for key: %s (%s)", key, full_key)
+  local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, nil)
+  lua_util.debugm(cache_context.N, task, "cache lookup for key: %s", full_key)
 
   -- Function to check a pending key
   local function check_pending(pending_info)
@@ -254,13 +255,13 @@ local function cache_get(task, key, cache_context, timeout, callback_uncached, c
     local probe_interval = timeout / (cache_context.opts.cache_probes or 5)
 
     lua_util.debugm(cache_context.N, task, "setting up probes for pending key %s, interval: %s seconds",
-        full_key, probe_interval)
+      full_key, probe_interval)
 
     -- Set up a timer to probe the key
     local function probe_key()
       probe_count = probe_count + 1
      lua_util.debugm(cache_context.N, task, "probe #%s/%s for pending key %s",
-          probe_count, cache_context.opts.cache_probes, full_key)
+        probe_count, cache_context.opts.cache_probes, full_key)
 
       if probe_count >= cache_context.opts.cache_probes then
         logger.infox(task, "maximum probes reached for key %s, considering it failed", full_key)
@@ -271,102 +272,102 @@ local function cache_get(task, key, cache_context, timeout, callback_uncached, c
 
       lua_util.debugm(cache_context.N, task, "probing redis for key %s", full_key)
       lua_redis.redis_make_request(task, cache_context.redis_params, key, false,
-          function(err, data)
-            if err then
-              logger.errx(task, "redis error while probing key %s: %s", full_key, err)
-              lua_util.debugm(cache_context.N, task, "redis error during probe: %s, retrying later", err)
-              task:add_timer(probe_interval, probe_key)
-              return
-            end
-
-            if not data or type(data) == 'userdata' then
-              lua_util.debugm(cache_context.N, task, "pending key %s disappeared, calling uncached handler", full_key)
-              callback_uncached(task)
-              return
-            end
-
-            local pending = parse_pending_value(data, cache_context)
-            if pending then
-              lua_util.debugm(cache_context.N, task, "key %s still pending (host: %s), retrying later",
-                  full_key, pending.hostname)
-              task:add_timer(probe_interval, probe_key)
-            else
-              lua_util.debugm(cache_context.N, task, "pending key %s resolved to actual data", full_key)
-              callback_data(task, nil, decode_data(data, cache_context))
-            end
-          end,
-          'GET', { full_key }
+        function(err, data)
+          if err then
+            logger.errx(task, "redis error while probing key %s: %s", full_key, err)
+            lua_util.debugm(cache_context.N, task, "redis error during probe: %s, retrying later", err)
+            task:add_timer(probe_interval, probe_key)
+            return
+          end
+
+          if not data or type(data) == 'userdata' then
+            lua_util.debugm(cache_context.N, task, "pending key %s disappeared, calling uncached handler", full_key)
+            callback_uncached(task)
+            return
+          end
+
+          local pending = parse_pending_value(data, cache_context)
+          if pending then
+            lua_util.debugm(cache_context.N, task, "key %s still pending (host: %s), retrying later",
+              full_key, pending.hostname)
+            task:add_timer(probe_interval, probe_key)
+          else
+            lua_util.debugm(cache_context.N, task, "pending key %s resolved to actual data", full_key)
+            callback_data(task, nil, decode_data(data, cache_context))
+          end
+        end,
+        'GET', { full_key }
       )
     end
 
     -- Start the first probe after the initial probe interval
     lua_util.debugm(cache_context.N, task, "scheduling first probe for %s in %s seconds",
-        full_key, probe_interval)
+      full_key, probe_interval)
     task:add_timer(probe_interval, probe_key)
   end
 
   -- Initial cache lookup
   lua_util.debugm(cache_context.N, task, "making initial redis GET request for key: %s", full_key)
   lua_redis.redis_make_request(task, cache_context.redis_params, key, false,
-      function(err, data)
-        if err then
-          logger.errx(task, "redis error looking up key %s: %s", full_key, err)
-          lua_util.debugm(cache_context.N, task, "redis error: %s, calling uncached handler", err)
-          callback_uncached(task)
-          return
-        end
-
-        if not data or type(data) == 'userdata' then
-          -- Key not found, set pending and call the uncached callback
-          lua_util.debugm(cache_context.N, task, "key %s not found in cache, creating pending marker", full_key)
-          local pending_marker = create_pending_marker(timeout, cache_context)
-
-          lua_util.debugm(cache_context.N, task, "setting pending marker for key %s with TTL %s",
-              full_key, timeout * 2)
-          lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
-              function(set_err, set_data)
-                if set_err then
-                  logger.errx(task, "redis error setting pending marker for %s: %s", full_key, set_err)
-                  lua_util.debugm(cache_context.N, task, "failed to set pending marker: %s", set_err)
-                else
-                  lua_util.debugm(cache_context.N, task, "successfully set pending marker for %s", full_key)
-                end
-                lua_util.debugm(cache_context.N, task, "calling uncached handler for %s", full_key)
-                callback_uncached(task)
-              end,
-              'SETEX', { full_key, tostring(timeout * 2), pending_marker }
-          )
-        else
-          -- Key found, check if it's a pending marker or actual data
-          local pending = parse_pending_value(data, cache_context)
-          if pending then
-            -- Key is being processed by another worker
-            lua_util.debugm(cache_context.N, task, "key %s is pending on host %s, waiting for result",
-                full_key, pending.hostname)
-            check_pending(pending)
-          else
-            -- Extend TTL and return data
-            lua_util.debugm(cache_context.N, task, "found cached data for key %s, extending TTL to %s",
-                full_key, cache_context.opts.cache_ttl)
-            lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
-                function(expire_err, _)
-                  if expire_err then
-                    logger.errx(task, "redis error extending TTL for %s: %s", full_key, expire_err)
-                    lua_util.debugm(cache_context.N, task, "failed to extend TTL: %s", expire_err)
-                  else
-                    lua_util.debugm(cache_context.N, task, "successfully extended TTL for %s", full_key)
-                  end
-                end,
-                'EXPIRE', { full_key, tostring(cache_context.opts.cache_ttl) }
-            )
-
-            lua_util.debugm(cache_context.N, task, "returning cached data for key %s", full_key)
-            callback_data(task, nil, decode_data(data, cache_context))
-          end
-        end
-      end,
-      'GET', { full_key }
+    function(err, data)
+      if err then
+        logger.errx(task, "redis error looking up key %s: %s", full_key, err)
+        lua_util.debugm(cache_context.N, task, "redis error: %s, calling uncached handler", err)
+        callback_uncached(task)
+        return
+      end
+
+      if not data or type(data) == 'userdata' then
+        -- Key not found, set pending and call the uncached callback
+        lua_util.debugm(cache_context.N, task, "key %s not found in cache, creating pending marker", full_key)
+        local pending_marker = create_pending_marker(timeout, cache_context)
+
+        lua_util.debugm(cache_context.N, task, "setting pending marker for key %s with TTL %s",
+          full_key, timeout * 2)
+        lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
+          function(set_err, set_data)
+            if set_err then
+              logger.errx(task, "redis error setting pending marker for %s: %s", full_key, set_err)
+              lua_util.debugm(cache_context.N, task, "failed to set pending marker: %s", set_err)
+            else
+              lua_util.debugm(cache_context.N, task, "successfully set pending marker for %s", full_key)
+            end
+            lua_util.debugm(cache_context.N, task, "calling uncached handler for %s", full_key)
+            callback_uncached(task)
+          end,
+          'SETEX', { full_key, tostring(timeout * 2), pending_marker }
+        )
+      else
+        -- Key found, check if it's a pending marker or actual data
+        local pending = parse_pending_value(data, cache_context)
+
+        if pending then
+          -- Key is being processed by another worker
+          lua_util.debugm(cache_context.N, task, "key %s is pending on host %s, waiting for result",
+            full_key, pending.hostname)
+          check_pending(pending)
+        else
+          -- Extend TTL and return data
+          lua_util.debugm(cache_context.N, task, "found cached data for key %s, extending TTL to %s",
+            full_key, cache_context.opts.cache_ttl)
+          lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
+            function(expire_err, _)
+              if expire_err then
+                logger.errx(task, "redis error extending TTL for %s: %s", full_key, expire_err)
+                lua_util.debugm(cache_context.N, task, "failed to extend TTL: %s", expire_err)
+              else
+                lua_util.debugm(cache_context.N, task, "successfully extended TTL for %s", full_key)
+              end
+            end,
+            'EXPIRE', { full_key, tostring(cache_context.opts.cache_ttl) }
+          )
+
+          lua_util.debugm(cache_context.N, task, "returning cached data for key %s", full_key)
+          callback_data(task, nil, decode_data(data, cache_context))
+        end
+      end
+    end,
+    'GET', { full_key }
   )
 
   return true
@@ -379,24 +380,24 @@ local function cache_set(task, key, data, cache_context)
     return false
   end
 
-  local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, false)
-  lua_util.debugm(cache_context.N, task, "caching data for key: %s (%s) with TTL: %s",
-      full_key, key, cache_context.opts.cache_ttl)
+  local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, nil)
+  lua_util.debugm(cache_context.N, task, "caching data for key: %s with TTL: %s",
+    full_key, cache_context.opts.cache_ttl)
 
   local encoded_data = encode_data(data, cache_context)
 
   -- Store the data with expiration
   lua_util.debugm(cache_context.N, task, "making redis SETEX request for key: %s", full_key)
   return lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
-      function(err, result)
-        if err then
-          logger.errx(task, "redis error setting cached data for %s: %s", full_key, err)
-          lua_util.debugm(cache_context.N, task, "failed to cache data: %s", err)
-        else
-          lua_util.debugm(cache_context.N, task, "successfully cached data for key %s", full_key)
-        end
-      end,
-      'SETEX', { full_key, tostring(cache_context.opts.cache_ttl), encoded_data }
+    function(err, result)
+      if err then
+        logger.errx(task, "redis error setting cached data for %s: %s", full_key, err)
+        lua_util.debugm(cache_context.N, task, "failed to cache data: %s", err)
+      else
+        lua_util.debugm(cache_context.N, task, "successfully cached data for key %s", full_key)
+      end
+    end,
+    'SETEX', { full_key, tostring(cache_context.opts.cache_ttl), encoded_data }
   )
 end
 
@@ -407,21 +408,21 @@ local function cache_del(task, key, cache_context)
     return false
   end
 
-  local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, false)
+  local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, nil)
   lua_util.debugm(cache_context.N, task, "deleting cache key: %s", full_key)
 
   return lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
-      function(err, result)
-        if err then
-          logger.errx(task, "redis error deleting cache key %s: %s", full_key, err)
-          lua_util.debugm(cache_context.N, task, "failed to delete cache key: %s", err)
-        else
-          local count = tonumber(result) or 0
-          lua_util.debugm(cache_context.N, task, "successfully deleted cache key %s (%s keys removed)",
-              full_key, count)
-        end
-      end,
-      'DEL', { full_key }
+    function(err, result)
+      if err then
+        logger.errx(task, "redis error deleting cache key %s: %s", full_key, err)
+        lua_util.debugm(cache_context.N, task, "failed to delete cache key: %s", err)
+      else
+        local count = tonumber(result) or 0
+        lua_util.debugm(cache_context.N, task, "successfully deleted cache key %s (%s keys removed)",
+          full_key, count)
+      end
+    end,
+    'DEL', { full_key }
  )
end
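
The module above implements a read-through cache with cross-worker PENDING markers: a miss writes a PENDING marker via SETEX (with doubled TTL), and other workers probe it instead of recomputing the same value. A minimal usage sketch, assuming the functions shown are exported via the module's exports table; the redis params, the msg_digest key and the expensive_check callback are hypothetical placeholders:

  local lua_cache = require "lua_cache"

  local cache_context = lua_cache.create_cache_context(redis_params, {
    cache_prefix = "my_module_cache",
    cache_ttl = 600,          -- keep results for 10 minutes
    cache_use_hashing = true, -- hash long keys down to cache_hash_len hex chars
  }, 'my_module')

  lua_cache.cache_get(task, msg_digest, cache_context, 10.0,
    function(task)
      -- Uncached path: compute the value, then publish it so that other
      -- workers polling the PENDING marker can pick it up
      local result = expensive_check(task)
      lua_cache.cache_set(task, msg_digest, result, cache_context)
    end,
    function(task, err, data)
      -- Cached (or resolved-from-pending) path
    end)
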
diff --git a/lualib/lua_meta.lua b/lualib/lua_meta.lua
index 340d89ee81..de006df8e7 100644
--- a/lualib/lua_meta.lua
+++ b/lualib/lua_meta.lua
@@ -12,7 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-]]--
+]] --
 
 local exports = {}
 
@@ -87,7 +87,7 @@ local function meta_images_function(task)
     nlarge = 1.0 * nlarge / ntotal
     nsmall = 1.0 * nsmall / ntotal
   end
-  return { ntotal, njpg, npng, nlarge, nsmall }
+  return { ntotal, npng, njpg, nlarge, nsmall } -- Fixed order to match names
 end
 
 local function meta_nparts_function(task)
@@ -164,29 +164,28 @@ local function meta_received_function(task)
   local fun = require "fun"
 
   if rh and #rh > 0 then
-
     local ntotal = 0.0
     local init_time = 0
 
     fun.each(function(rc)
-        ntotal = ntotal + 1.0
+      ntotal = ntotal + 1.0
 
-        if not rc.by_hostname then
-          invalid_factor = invalid_factor + 1.0
-        end
-        if init_time == 0 and rc.timestamp then
-          init_time = rc.timestamp
-        elseif rc.timestamp then
-          time_factor = time_factor + math.abs(init_time - rc.timestamp)
-          init_time = rc.timestamp
-        end
-        if rc.flags and (rc.flags['ssl'] or rc.flags['authenticated']) then
-          secure_factor = secure_factor + 1.0
-        end
-      end,
-      fun.filter(function(rc)
-        return not rc.flags or not rc.flags['artificial']
-      end, rh))
+      if not rc.by_hostname then
+        invalid_factor = invalid_factor + 1.0
+      end
+      if init_time == 0 and rc.timestamp then
+        init_time = rc.timestamp
+      elseif rc.timestamp then
+        time_factor = time_factor + math.abs(init_time - rc.timestamp)
+        init_time = rc.timestamp
+      end
+      if rc.flags and (rc.flags['ssl'] or rc.flags['authenticated']) then
+        secure_factor = secure_factor + 1.0
+      end
+    end,
+      fun.filter(function(rc)
+        return not rc.flags or not rc.flags['artificial']
+      end, rh))
 
     if ntotal > 0 then
       invalid_factor = invalid_factor / ntotal
@@ -263,8 +262,8 @@ local function meta_words_function(task)
   end
 
   local ret = {
-    short_words,
-    ret_len,
+    ret_len, -- avg_words_len (moved to match the names array)
+    short_words, -- nshort_words
   }
 
   local divisor = 1.0
@@ -460,10 +459,10 @@ local function rspamd_gen_metatokens(task, names)
       local ct = mt.cb(task)
       for i, tok in ipairs(ct) do
         lua_util.debugm(N, task, "metatoken: %s = %s",
-            mt.names[i], tok)
+          mt.names[i], tok)
         if tok ~= tok or tok == math.huge then
           logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity',
-              mt.names[i], tok)
+            mt.names[i], tok)
           tok = 0.0
         end
         table.insert(metatokens, tok)
@@ -472,14 +471,13 @@ local function rspamd_gen_metatokens(task, names)
 
       task:cache_set('metatokens', metatokens)
     end
-
   else
     for _, n in ipairs(names) do
      if metatokens_by_name[n] then
        local tok = metatokens_by_name[n](task)
        if tok ~= tok or tok == math.huge then
          logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity',
-              n, tok)
+            n, tok)
          tok = 0.0
        end
        table.insert(metatokens, tok)
@@ -503,7 +501,7 @@ local function rspamd_gen_metatokens_table(task)
     for i, tok in ipairs(ct) do
       if tok ~= tok or tok == math.huge then
         logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity',
-            mt.names[i], tok)
+          mt.names[i], tok)
         tok = 0.0
       end
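
Both reorderings in this file (npng/njpg for images, avg_words_len/nshort_words for words) matter because metatokens are consumed positionally: a trained network binds its weights to vector indices, not names, so a swapped pair silently shifts every downstream weight. A small sketch of how a caller consumes these tokens — lua_util, task and the module come from the surrounding rspamd environment, and the name list in the second call is illustrative:

  local meta = require "lua_meta"

  -- Full vector; each position lines up with the declared names arrays
  local tokens = meta.rspamd_gen_metatokens(task)
  for i, tok in ipairs(tokens) do
    lua_util.debugm('my_module', task, 'metatoken %s = %s', i, tok)
  end

  -- Or request a subset by name; NaN/inf values are replaced with 0 for sanity
  local subset = meta.rspamd_gen_metatokens(task, { 'avg_words_len', 'nshort_words' })
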
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
index 17661c1e94..5fcb75fcf9 100644
--- a/lualib/plugins/neural.lua
+++ b/lualib/plugins/neural.lua
@@ -139,6 +139,28 @@ register_provider('symbols', {
   end,
 })
 
+-- Metatokens-only provider for contexts where symbols are not available
+register_provider('metatokens', {
+  collect = function(task, ctx)
+    local mt = meta_functions.rspamd_gen_metatokens(task)
+    -- Convert to table of numbers
+    local vec = {}
+    for i = 1, #mt do
+      vec[i] = tonumber(mt[i]) or 0.0
+    end
+    return vec, { name = 'metatokens', type = 'metatokens', dim = #vec, weight = ctx.weight or 1.0 }
+  end,
+  collect_async = function(task, ctx, cont)
+    local mt = meta_functions.rspamd_gen_metatokens(task)
+    -- Convert to table of numbers
+    local vec = {}
+    for i = 1, #mt do
+      vec[i] = tonumber(mt[i]) or 0.0
+    end
+    cont(vec, { name = 'metatokens', type = 'metatokens', dim = #vec, weight = ctx.weight or 1.0 })
+  end,
+})
+
 local function load_scripts()
   redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len,
     redis_params)
@@ -546,6 +568,7 @@ end
 
 local function redis_ann_prefix(rule, settings_name)
   -- We also need to count metatokens:
+  -- Note: meta_functions.version represents the metatoken format version
   local n = meta_functions.version
   return string.format('%s%d_%s_%d_%s',
     settings.prefix, plugin_ver, rule.prefix, n, settings_name)
@@ -669,6 +692,7 @@ local function collect_features_async(task, rule, profile_or_set, phase, cb)
       end
       prov.collect_async(task, {
         profile = profile_or_set,
+        set = profile_or_set,
         rule = rule,
         config = pcfg,
         weight = pcfg.weight or 1.0,
@@ -682,25 +706,49 @@ local function collect_features_async(task, rule, profile_or_set, phase, cb)
     end)
   end
 
-  -- Include metatokens as an extra provider when configured
-  local include_meta = rule.fusion and rule.fusion.include_meta
+  -- Include symbols provider (which includes both symbols AND metatokens) as an extra provider
+  -- The name 'include_meta' is historical but it actually includes the full symbols provider
+  -- For backward compatibility, include symbols by default unless explicitly disabled
+  local include_meta = false
+  if not providers_cfg or #providers_cfg == 0 then
+    -- No providers, always use symbols (which includes metatokens)
+    include_meta = true
+  elseif rule.fusion then
+    -- Explicit fusion config takes precedence
+    include_meta = rule.fusion.include_meta
+    if include_meta == nil then
+      -- Default to true for backward compatibility when fusion is configured but include_meta not specified
+      include_meta = true
+    end
+  else
+    -- Providers configured but no fusion settings - default to including symbols+metatokens
+    include_meta = true
+  end
+
   local meta_weight = (rule.fusion and rule.fusion.meta_weight) or 1.0
 
   remaining = #providers_cfg + (include_meta and 1 or 0)
+
+  -- Start all configured providers
   for i, pcfg in ipairs(providers_cfg) do
     start_provider(i, pcfg)
   end
 
   if include_meta then
-    local prov = get_provider('symbols')
+    -- Always use metatokens provider for consistency
+    -- This ensures same dimensions whether called from controller or full scan
+    local prov = get_provider('metatokens')
+
     if prov and prov.collect_async then
-      prov.collect_async(task, { profile = profile_or_set, weight = meta_weight, phase = phase }, function(vec, meta)
-        if vec then
-          metas[#metas + 1] = { name = 'symbols', type = 'symbols', dim = #vec, weight = meta_weight }
-          vectors[#vectors + 1] = vec
-        end
-        maybe_finish()
-      end)
+      local meta_index = #providers_cfg + 1 -- Metatokens always come after providers
+      prov.collect_async(task, { profile = profile_or_set, set = profile_or_set, weight = meta_weight, phase = phase },
+        function(vec, meta)
+          if vec then
+            metas[meta_index] = meta
+            vectors[meta_index] = vec
+          end
+          maybe_finish()
+        end)
     else
       maybe_finish()
     end
@@ -711,8 +759,24 @@ end
 local function spawn_train(params)
   -- Check training data sanity
   -- Now we need to join inputs and create the appropriate test vectors
-  local n = #params.set.symbols +
-      meta_functions.rspamd_count_metatokens()
+  local n
+
+  -- When using providers, derive dimension from actual vectors
+  if params.rule.providers and #params.rule.providers > 0 and
+      (#params.spam_vec > 0 or #params.ham_vec > 0) then
+    -- Use dimension from stored vectors
+    if #params.spam_vec > 0 then
+      n = #params.spam_vec[1]
+    else
+      n = #params.ham_vec[1]
+    end
+    lua_util.debugm(N, rspamd_config, 'spawn_train: using vector dimension %s from stored vectors', n)
+  else
+    -- Traditional symbol-based dimension
+    n = #params.set.symbols + meta_functions.rspamd_count_metatokens()
+    lua_util.debugm(N, rspamd_config, 'spawn_train: using symbol dimension %s symbols + %s metatokens = %s',
+      #params.set.symbols, meta_functions.rspamd_count_metatokens(), n)
+  end
 
   -- Now we can train ann
   local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule)
@@ -1148,7 +1212,7 @@ result_to_vector = function(task, profile)
   if not profile.zeros then
     -- Fill zeros vector
     local zeros = {}
-    for i = 1, meta_functions.count_metatokens() do
+    for i = 1, meta_functions.rspamd_count_metatokens() do
       zeros[i] = 0.0
     end
     for _, _ in ipairs(profile.symbols) do
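
The new 'metatokens' provider follows the same provider contract used throughout this file: collect/collect_async return a numeric vector plus a meta table describing the segment. A sketch of a custom provider under the same contract — the name 'message_stats' and its two features are hypothetical, not part of this commit:

  register_provider('message_stats', {
    collect_async = function(task, ctx, cont)
      local vec = {
        math.min(1.0, (task:get_size() or 0) / 1048576.0), -- normalized message size
        #(task:get_urls() or {}),                          -- URL count
      }
      -- The meta table describes the segment: name, type, dimension and weight
      cont(vec, { name = 'message_stats', type = 'message_stats', dim = #vec, weight = ctx.weight or 1.0 })
    end,
  })

Deterministic placement is the point of meta_index in collect_features_async above: each provider's segment lands at a fixed position, so train and infer vectors keep identical dimensions and ordering.
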
diff --git a/lualib/plugins/neural/providers/llm.lua b/lualib/plugins/neural/providers/llm.lua
index 8f08fbb57b..1bc1063aae 100644
--- a/lualib/plugins/neural/providers/llm.lua
+++ b/lualib/plugins/neural/providers/llm.lua
@@ -74,6 +74,16 @@ neural_common.register_provider('llm', {
       return
     end
 
+    -- Do not run embeddings on infer if ANN is not loaded for this set/profile
+    if ctx.phase == 'infer' then
+      local set_or_profile = ctx.profile or ctx.set
+      if not set_or_profile or not set_or_profile.ann then
+        rspamd_logger.debugm(N, task, 'skip llm on infer: ANN not loaded for current settings')
+        cont(nil)
+        return
+      end
+    end
+
     local input_tbl = select_text(task)
     if not input_tbl then
       rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip')
diff --git a/rules/controller/neural.lua b/rules/controller/neural.lua
index 0aace1cc1d..13530beffe 100644
--- a/rules/controller/neural.lua
+++ b/rules/controller/neural.lua
@@ -195,22 +195,37 @@ local function handle_learn_message(task, conn)
     return
   end
 
-  -- If no providers or symbols provider configured, require full scan path
+  -- Check if this configuration requires full scan
+  -- Only symbols collection requires full scan; metatokens can be computed directly
   local has_providers = type(rule.providers) == 'table' and #rule.providers > 0
-  if not has_providers then
-    lua_util.debugm(N, task, 'controller.neural: learn_message refused: no providers (assume symbols) for rule=%s',
+
+  if not has_providers and not rule.disable_symbols_input then
+    -- No providers means full symbols will be used (not just metatokens)
+    lua_util.debugm(N, task,
+      'controller.neural: learn_message refused: no providers configured, symbols collection requires full scan for rule=%s',
       rule_name)
-    conn:send_error(400, 'rule requires full /checkv2 scan (no providers configured)')
+    conn:send_error(400, 'rule requires full /checkv2 scan (no providers configured, full symbols collection required)')
     return
   end
-  for _, p in ipairs(rule.providers) do
-    if p.type == 'symbols' then
-      lua_util.debugm(N, task, 'controller.neural: learn_message refused due to symbols provider for rule=%s', rule_name)
-      conn:send_error(400, 'rule requires full /checkv2 scan (symbols provider present)')
-      return
+
+  -- Check if any provider requires full scan (only symbols provider does)
+  if has_providers then
+    for _, p in ipairs(rule.providers) do
+      if p.type == 'symbols' then
+        lua_util.debugm(N, task,
+          'controller.neural: learn_message refused due to symbols provider requiring full scan for rule=%s',
+          rule_name)
+        conn:send_error(400, 'rule requires full /checkv2 scan (symbols provider present)')
+        return
+      end
     end
   end
 
+  -- At this point:
+  -- - We have providers that don't require full scan (e.g., LLM)
+  -- - Metatokens can be computed directly from the message
+  -- - Controller training is allowed
+
   local set = neural_common.get_rule_settings(task, rule)
   if not set then
     lua_util.debugm(N, task, 'controller.neural: no settings resolved for rule=%s; falling back to first available set',
@@ -224,6 +239,11 @@ local function handle_learn_message(task, conn)
     end
   end
 
+  if set then
+    lua_util.debugm(N, task, 'controller.neural: set found for rule=%s, symbols=%s, name=%s',
+      rule_name, set.symbols and #set.symbols or "nil", set.name)
+  end
+
   -- Derive redis base key even if ANN not yet initialized
   local redis_base
   if set and set.ann and set.ann.redis_key then
@@ -244,17 +264,55 @@ local function handle_learn_message(task, conn)
     return
   end
 
+  -- Ensure profile exists for this set
+  if not set.ann then
+    local version = 0
+    local ann_key = neural_common.new_ann_key(rule, set, version)
+
+    local profile = {
+      symbols = set.symbols,
+      redis_key = ann_key,
+      version = version,
+      digest = set.digest,
+      distance = 0,
+      providers_digest = neural_common.providers_config_digest(rule.providers),
+    }
+
+    local ucl = require "ucl"
+    local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+
+    lua_util.debugm(N, task, 'controller.neural: creating new profile for %s:%s at %s',
+      rule.prefix, set.name, ann_key)
+
+    -- Store the profile in Redis sorted set
+    lua_redis.redis_make_request(task,
+      rule.redis,
+      nil,
+      true, -- is write
+      function(err, _)
+        if err then
+          rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
+            rule.prefix, set.name, profile.redis_key, err)
+        else
+          lua_util.debugm(N, task, 'created new ANN profile for %s:%s, data stored at prefix %s',
+            rule.prefix, set.name, profile.redis_key)
+        end
+      end,
+      'ZADD', -- command
+      { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+    )
+
+    -- Update redis_base to use the new ann_key
+    redis_base = ann_key
+  end
+
   local function after_collect(vec)
+    lua_util.debugm(N, task, 'controller.neural: learn_message after_collect, vector=%s', type(vec))
     if not vec then
-      if rule.providers and #rule.providers > 0 then
-        lua_util.debugm(N, task,
-          'controller.neural: no vector from providers; skip training to keep dimensions consistent')
-        conn:send_error(400, 'no vector collected from providers')
-        return
-      else
-        vec = neural_common.result_to_vector(task, set)
-      end
+      lua_util.debugm(N, task,
+        'controller.neural: no vector collected; skip training')
+      conn:send_error(400, 'no vector collected')
+      return
     end
 
     if type(vec) ~= 'table' then
diff --git a/src/client/rspamc.cxx b/src/client/rspamc.cxx
index e2128f357a..c42d301429 100644
--- a/src/client/rspamc.cxx
+++ b/src/client/rspamc.cxx
@@ -212,6 +212,61 @@ static void rspamc_counters_output(FILE *out, ucl_object_t *obj);
 
 static void rspamc_stat_output(FILE *out, ucl_object_t *obj);
 
+static void
+rspamc_neural_learn_output(FILE *out, ucl_object_t *obj)
+{
+    bool is_success = true;
+    const char *filename = nullptr;
+    double scan_time = -1.0;
+    const char *redis_key = nullptr;
+    std::uintmax_t stored_bytes = 0;
+    bool have_stored = false;
+
+    if (obj != nullptr) {
+        const auto *ok = ucl_object_lookup(obj, "success");
+        if (ok) {
+            is_success = ucl_object_toboolean(ok);
+        }
+        const auto *fn = ucl_object_lookup(obj, "filename");
+        if (fn) {
+            filename = ucl_object_tostring(fn);
+        }
+        const auto *st = ucl_object_lookup(obj, "scan_time");
+        if (st) {
+            scan_time = ucl_object_todouble(st);
+        }
+        const auto *rb = ucl_object_lookup(obj, "stored");
+        if (rb) {
+            stored_bytes = (std::uintmax_t) ucl_object_toint(rb);
+            have_stored = true;
+        }
+        const auto *rk = ucl_object_lookup(obj, "key");
+        if (rk) {
+            redis_key = ucl_object_tostring(rk);
+        }
+    }
+
+    // First line: success
+    fprintf(out, "success = %s;\n", is_success ? "true" : "false");
+
+    // Then other fields in k = v; format
+    if (filename) {
+        fprintf(out, "filename = \"%s\";\n", filename);
+    }
+    if (scan_time >= 0) {
+        fprintf(out, "scan_time = %.6f;\n", scan_time);
+    }
+    if (!neural_train.empty()) {
+        fprintf(out, "class = \"%s\";\n", neural_train.c_str());
+    }
+    if (have_stored) {
+        fprintf(out, "stored = %ju bytes;\n", stored_bytes);
+    }
+    if (redis_key) {
+        fprintf(out, "key = \"%s\";\n", redis_key);
+    }
+}
+
 enum rspamc_command_type {
     RSPAMC_COMMAND_UNKNOWN = 0,
     RSPAMC_COMMAND_CHECK,
@@ -288,7 +343,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
         .is_controller = FALSE,
         .is_privileged = FALSE,
         .need_input = TRUE,
-        .command_output_func = rspamc_symbols_output},
+        .command_output_func = rspamc_neural_learn_output},
     rspamc_command{
         .cmd = RSPAMC_COMMAND_FUZZY_ADD,
         .name = "fuzzy_add",
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 1e8a135f18..3f0a3c7aa7 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -69,20 +69,20 @@ local function new_ann_profile(task, rule, set, version)
   local function add_cb(err, _)
     if err then
       rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
-          rule.prefix, set.name, profile.redis_key, err)
+        rule.prefix, set.name, profile.redis_key, err)
     else
       rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
-          rule.prefix, set.name, profile.redis_key)
+        rule.prefix, set.name, profile.redis_key)
     end
   end
 
   lua_redis.redis_make_request(task,
-      rule.redis,
-      nil,
-      true, -- is write
-      add_cb, --callback
-      'ZADD', -- command
-      { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+    rule.redis,
+    nil,
+    true, -- is write
+    add_cb, --callback
+    'ZADD', -- command
+    { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
   )
 
   return profile
@@ -103,18 +103,18 @@ local function ann_scores_filter(task)
       profile = set.ann
     else
       lua_util.debugm(N, task, 'no ann loaded for %s:%s',
-          rule.prefix, set.name)
+        rule.prefix, set.name)
     end
   else
     lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
-        rule.prefix, sid)
+      rule.prefix, sid)
   end
 
   if ann then
     local function after_features(vec, meta)
       if profile.providers_digest and meta and meta.digest and profile.providers_digest ~= meta.digest then
         lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply',
-            rule.prefix, set.name)
+          rule.prefix, set.name)
         vec = nil
       end
 
@@ -131,7 +131,7 @@ local function ann_scores_filter(task)
         local symscore = string.format('%.3f', score)
         task:cache_set(rule.prefix .. '_neural_score', score)
         lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
-            rule.prefix, set.name, set.ann.version, symscore)
+          rule.prefix, set.name, set.ann.version, symscore)
 
         if score > 0 then
           local result = score
@@ -152,8 +152,8 @@ local function ann_scores_filter(task)
             end
           else
             lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
-                rule.prefix, set.name, set.ann.version, symscore,
-                spam_threshold)
+              rule.prefix, set.name, set.ann.version, symscore,
+              spam_threshold)
           end
         else
           local result = -(score)
@@ -174,8 +174,8 @@ local function ann_scores_filter(task)
             end
           else
             lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
-                rule.prefix, set.name, set.ann.version, result,
-                ham_threshold)
+              rule.prefix, set.name, set.ann.version, result,
+              ham_threshold)
           end
         end
       end
@@ -214,20 +214,33 @@ local function ann_push_task_result(rule, task, verdict, score, set)
     end
   end
 
+  -- If LLM provider is configured, suppress autotrain unless manual training requested
+  if not manual_train and rule.providers and #rule.providers > 0 then
+    for _, p in ipairs(rule.providers) do
+      if p.type == 'llm' then
+        lua_util.debugm(N, task, 'suppress autotrain: llm provider present and no manual header')
+        learn_spam = false
+        learn_ham = false
+        skip_reason = 'llm provider requires manual training'
+        break
+      end
+    end
+  end
+
   if not manual_train and (not train_opts.store_pool_only and train_opts.autotrain) then
     if train_opts.spam_score then
       learn_spam = score >= train_opts.spam_score
 
       if not learn_spam then
         skip_reason = string.format('score < spam_score: %f < %f',
-            score, train_opts.spam_score)
+          score, train_opts.spam_score)
       end
     else
       learn_spam = verdict == 'spam' or verdict == 'junk'
       if not learn_spam then
         skip_reason = string.format('verdict: %s',
-            verdict)
+          verdict)
       end
     end
 
@@ -235,14 +248,14 @@ local function ann_push_task_result(rule, task, verdict, score, set)
       learn_ham = score <= train_opts.ham_score
       if not learn_ham then
         skip_reason = string.format('score > ham_score: %f > %f',
-            score, train_opts.ham_score)
+          score, train_opts.ham_score)
       end
     else
       learn_ham = verdict == 'ham'
 
       if not learn_ham then
         skip_reason = string.format('verdict: %s',
-            verdict)
+          verdict)
       end
     end
   else
@@ -281,56 +294,63 @@ local function ann_push_task_result(rule, task, verdict, score, set)
       local nspam, nham = data[1], data[2]
 
       if manual_train or neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
-        local vec
-        if rule.providers and #rule.providers > 0 then
-          -- Note: this training path remains sync for now; vectors are pushed when computed
-          -- fall back to legacy vector; async training push will be added later
-          vec = neural_common.result_to_vector(task, set)
-        else
-          vec = neural_common.result_to_vector(task, set)
-        end
-
-        local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
-        local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
-
-        local function learn_vec_cb(redis_err)
-          if redis_err then
-            rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
-                rule.prefix, set.name, redis_err)
-          else
-            lua_util.debugm(N, task,
-                "add train data for ANN rule " ..
-                "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
-                rule.prefix, set.name, learn_type, #vec, target_key, #str)
-          end
-        end
-
-        lua_redis.redis_make_request(task,
-            rule.redis,
-            nil,
-            true, -- is write
-            learn_vec_cb, --callback
-            'SADD', -- command
-            { target_key, str } -- arguments
-        )
+        local function store_train_vec(vec)
+          if not vec then
+            lua_util.debugm(N, task, "no vector collected for training")
+            return
+          end
+
+          local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
+          local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
+
+          local function learn_vec_cb(redis_err)
+            if redis_err then
+              rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
+                rule.prefix, set.name, redis_err)
+            else
+              lua_util.debugm(N, task,
+                "add train data for ANN rule " ..
+                "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
+                rule.prefix, set.name, learn_type, #vec, target_key, #str)
+            end
+          end
+
+          lua_redis.redis_make_request(task,
+            rule.redis,
+            nil,
+            true, -- is write
+            learn_vec_cb, --callback
+            'SADD', -- command
+            { target_key, str } -- arguments
+          )
+        end
+
+        if rule.providers and #rule.providers > 0 then
+          -- Use async feature collection with providers, same as inference
+          neural_common.collect_features_async(task, rule, set, 'train', store_train_vec)
+        else
+          -- Traditional symbol-based vector
+          local vec = neural_common.result_to_vector(task, set)
+          store_train_vec(vec)
+        end
       else
         lua_util.debugm(N, task,
-            "do not add %s train data for ANN rule " ..
-            "%s:%s",
-            learn_type, rule.prefix, set.name)
+          "do not add %s train data for ANN rule " ..
+          "%s:%s",
+          learn_type, rule.prefix, set.name)
       end
     else
       if err then
         rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
-            rule.prefix, set.name, err)
+          rule.prefix, set.name, err)
      elseif type(data) == 'string' then
        -- nil return value
        rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
-            learn_type, rule.prefix, set.name, set.ann.redis_key, data)
+          learn_type, rule.prefix, set.name, set.ann.redis_key, data)
      else
        rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
-            'please remove this key from Redis manually if you perform upgrade from the previous version',
-            rule.prefix, set.name, set.ann.redis_key, type(data))
+          'please remove this key from Redis manually if you perform upgrade from the previous version',
+          rule.prefix, set.name, set.ann.redis_key, type(data))
      end
    end
  end
@@ -341,25 +361,25 @@ local function ann_push_task_result(rule, task, verdict, score, set)
         -- Need to create or load a profile corresponding to the current configuration
         set.ann = new_ann_profile(task, rule, set, 0)
         lua_util.debugm(N, task,
-            'requested new profile for %s, set.ann is missing',
-            set.name)
+          'requested new profile for %s, set.ann is missing',
+          set.name)
       end
 
       lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
-          { task = task, is_write = false },
-          vectors_len_cb,
-          {
-            set.ann.redis_key,
-          })
+        { task = task, is_write = false },
+        vectors_len_cb,
+        {
+          set.ann.redis_key,
+        })
     else
       lua_util.debugm(N, task,
-          'do not push data: train condition not satisfied; reason: not checked existing ANNs')
+        'do not push data: train condition not satisfied; reason: not checked existing ANNs')
     end
   else
     lua_util.debugm(N, task,
-        'do not push data to key %s: train condition not satisfied; reason: %s',
-        (set.ann or {}).redis_key,
-        skip_reason)
+      'do not push data to key %s: train condition not satisfied; reason: %s',
+      (set.ann or {}).redis_key,
+      skip_reason)
   end
 end
 
@@ -380,20 +400,21 @@ end
 local function do_train_ann(worker, ev_base, rule, set, ann_key)
   local spam_elts = {}
   local ham_elts = {}
+  lua_util.debugm(N, rspamd_config, 'do_train_ann: start for %s:%s key=%s', rule.prefix, set.name, ann_key)
 
   local function redis_ham_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
-          ann_key, err)
+        ann_key, err)
       -- Unlock on error
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          true, -- is write
-          neural_common.gen_unlock_cb(rule, set, ann_key), --callback
-          'HDEL', -- command
-          { ann_key, 'lock' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        true, -- is write
+        neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+        'HDEL', -- command
+        { ann_key, 'lock' }
       )
     else
       -- Decompress and convert to numbers each training vector
@@ -414,29 +435,29 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   local function redis_spam_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
-          ann_key, err)
+        ann_key, err)
       -- Unlock ANN on error
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          true, -- is write
-          neural_common.gen_unlock_cb(rule, set, ann_key), --callback
-          'HDEL', -- command
-          { ann_key, 'lock' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        true, -- is write
+        neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+        'HDEL', -- command
+        { ann_key, 'lock' }
       )
     else
       -- Decompress and convert to numbers each training vector
       spam_elts = process_training_vectors(data)
       -- Now get ham vectors...
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          false, -- is write
-          redis_ham_cb, --callback
-          'SMEMBERS', -- command
-          { ann_key .. '_ham_set' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        false, -- is write
+        redis_ham_cb, --callback
+        'SMEMBERS', -- command
+        { ann_key .. '_ham_set' }
      )
    end
  end
@@ -444,33 +465,33 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   local function redis_lock_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
-          ann_key, err)
+        ann_key, err)
     elseif type(data) == 'number' and data == 1 then
       -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          false, -- is write
-          redis_spam_cb, --callback
-          'SMEMBERS', -- command
-          { ann_key .. '_spam_set' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        false, -- is write
+        redis_spam_cb, --callback
+        'SMEMBERS', -- command
+        { ann_key .. '_spam_set' }
       )
 
       rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
-          rule.prefix, set.name, ann_key)
+        rule.prefix, set.name, ann_key)
     else
       local lock_tm = tonumber(data[1])
       rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
-          'locked by another host %s at %s', rule.prefix, set.name, ann_key,
-          data[2], os.date('%c', lock_tm))
+        'locked by another host %s at %s', rule.prefix, set.name, ann_key,
+        data[2], os.date('%c', lock_tm))
     end
   end
 
   -- Check if we are already learning this network
   if set.learning_spawned then
     rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
-        ann_key)
+      ann_key)
     return
   end
 
@@ -478,14 +499,14 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
   -- ANN is locked by another host (or a process, meh)
   lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
-      { ev_base = ev_base, is_write = true },
-      redis_lock_cb,
-      {
-        ann_key,
-        tostring(os.time()),
-        tostring(math.max(10.0, rule.watch_interval * 2)),
-        rspamd_util.get_hostname()
-      })
+    { ev_base = ev_base, is_write = true },
+    redis_lock_cb,
+    {
+      ann_key,
+      tostring(os.time()),
+      tostring(math.max(10.0, rule.watch_interval * 2)),
+      rspamd_util.get_hostname()
+    })
 end
 
 -- This function loads new ann from Redis
@@ -500,7 +521,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
   local function data_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
-          ann_key, err)
+        ann_key, err)
     else
       if type(data) == 'table' then
         if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then
@@ -509,7 +530,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
           if _err or not ann_data then
             rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
-                rule.prefix .. ':' .. set.name, ann_key, _err)
+              rule.prefix .. ':' .. set.name, ann_key, _err)
             return
           else
             ann = rspamd_kann.load(ann_data)
@@ -533,26 +554,26 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
           end
           -- Also update rank for the loaded ANN to avoid removal
           lua_redis.redis_make_request_taskless(ev_base,
-              rspamd_config,
-              rule.redis,
-              nil,
-              true, -- is write
-              rank_cb, --callback
-              'ZADD', -- command
-              { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+            rspamd_config,
+            rule.redis,
+            nil,
+            true, -- is write
+            rank_cb, --callback
+            'ZADD', -- command
+            { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
           )
           rspamd_logger.infox(rspamd_config,
-              'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
-              rule.prefix, set.name, ann_key, #data[1], profile.version)
+            'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
+            rule.prefix, set.name, ann_key, #data[1], profile.version)
         else
           rspamd_logger.errx(rspamd_config,
-              'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
-              rule.prefix, set.name, ann_key)
+            'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
+            rule.prefix, set.name, ann_key)
         end
       end
     else
       lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
-          rule.prefix, set.name, ann_key)
+        rule.prefix, set.name, ann_key)
     end
 
     if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
@@ -564,8 +585,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
           local roc_thresholds = parser:get_object()
           set.ann.roc_thresholds = roc_thresholds
           rspamd_logger.infox(rspamd_config,
-              'loaded ROC thresholds for %s:%s; version=%s',
-              rule.prefix, set.name, profile.version)
+            'loaded ROC thresholds for %s:%s; version=%s',
+            rule.prefix, set.name, profile.version)
           rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
         end
       end
@@ -578,19 +599,19 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
             -- We can use PCA
             set.ann.pca = rspamd_tensor.load(pca_data)
             rspamd_logger.infox(rspamd_config,
-                'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
-                rule.prefix, set.name, ann_key, #data[3], profile.version)
+              'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
+              rule.prefix, set.name, ann_key, #data[3], profile.version)
           else
             -- no need in pca, why is it there?
             rspamd_logger.warnx(rspamd_config,
-                'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
-                rule.prefix, set.name, ann_key)
+              'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
+              rule.prefix, set.name, ann_key)
           end
         else
           -- pca can be missing merely if we have no max_inputs
           if rule.max_inputs then
             rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s',
-                rule.prefix, set.name, ann_key, _err)
+              rule.prefix, set.name, ann_key, _err)
             set.ann.ann = nil
           else
             -- It is okay
@@ -619,19 +640,19 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
       end
     else
       lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
-          rule.prefix, set.name, ann_key)
+        rule.prefix, set.name, ann_key)
     end
   end
 
   lua_redis.redis_make_request_taskless(ev_base,
-      rspamd_config,
-      rule.redis,
-      nil,
-      false, -- is write
-      data_cb, --callback
-      'HMGET', -- command
-      { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments
-      { opaque_data = true }
+    rspamd_config,
+    rule.redis,
+    nil,
+    false, -- is write
+    data_cb, --callback
+    'HMGET', -- command
+    { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments
+    { opaque_data = true }
   )
 end
 
@@ -644,6 +665,8 @@ local function process_existing_ann(_, ev_base, rule, set, profiles)
   local my_symbols = set.symbols
   local min_diff = math.huge
   local sel_elt
+  lua_util.debugm(N, rspamd_config, 'process_existing_ann: have %s profiles for %s:%s',
+    type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name)
 
   for _, elt in fun.iter(profiles) do
     if elt and elt.symbols then
@@ -667,34 +690,34 @@ local function process_existing_ann(_, ev_base, rule, set, profiles)
       if set.ann.version < sel_elt.version then
         -- Load new ann
         rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
-            'our version = %s, remote version = %s',
-            rule.prefix .. ':' .. set.name,
-            set.ann.version,
-            sel_elt.version)
+          'our version = %s, remote version = %s',
+          rule.prefix .. ':' .. set.name,
+          set.ann.version,
+          sel_elt.version)
         load_new_ann(rule, ev_base, set, sel_elt, min_diff)
       else
         lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' ..
-            'our version = %s, remote version = %s',
-            rule.prefix .. ':' .. set.name,
-            set.ann.version,
-            sel_elt.version)
+          'our version = %s, remote version = %s',
+          rule.prefix .. ':' .. set.name,
+          set.ann.version,
+          sel_elt.version)
       end
     else
       -- We have some different ANN, so we need to compare distance
       if set.ann.distance > min_diff then
         -- Load more specific ANN
         rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
-            'our distance = %s, remote distance = %s',
-            rule.prefix .. ':' .. set.name,
-            set.ann.distance,
-            min_diff)
+          'our distance = %s, remote distance = %s',
+          rule.prefix .. ':' .. set.name,
+          set.ann.distance,
+          min_diff)
         load_new_ann(rule, ev_base, set, sel_elt, min_diff)
       else
         lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' ..
-            'our distance = %s, remote distance = %s',
-            rule.prefix .. ':' .. set.name,
-            set.ann.distance,
-            min_diff)
+          'our distance = %s, remote distance = %s',
+          rule.prefix .. ':' .. set.name,
+          set.ann.distance,
+          min_diff)
       end
     end
   else
@@ -702,6 +725,12 @@ local function process_existing_ann(_, ev_base, rule, set, profiles)
       load_new_ann(rule, ev_base, set, sel_elt, min_diff)
     end
   end
+  if sel_elt then
+    lua_util.debugm(N, rspamd_config, 'process_existing_ann: selected profile version=%s key=%s', sel_elt.version,
+      sel_elt.redis_key)
+  else
+    lua_util.debugm(N, rspamd_config, 'process_existing_ann: no suitable profile found')
+  end
 end
 
@@ -715,6 +744,8 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
     spam = 0,
     ham = 0,
   }
+  lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: %s profiles for %s:%s',
+    type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name)
 
   for _, elt in fun.iter(profiles) do
     if elt and elt.symbols then
@@ -732,14 +763,14 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
     local ann_key = sel_elt.redis_key
 
     lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
-        ann_key)
+      ann_key)
 
     -- Create continuation closure
     local redis_len_cb_gen = function(cont_cb, what, is_final)
       return function(err, data)
         if err then
           rspamd_logger.errx(rspamd_config,
-              'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
+            'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
         elseif data and type(data) == 'number' or type(data) == 'string' then
           local ntrains = tonumber(data) or 0
           lens[what] = ntrains
@@ -760,31 +791,31 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
             end
             if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
               lua_util.debugm(N, rspamd_config,
-                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                  ann_key, lens, rule.train.max_trains, what)
+                'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+                ann_key, lens, rule.train.max_trains, what)
               cont_cb()
             else
               lua_util.debugm(N, rspamd_config,
-                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                  ann_key, what, lens, rule.train.max_trains)
+                'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+                ann_key, what, lens, rule.train.max_trains)
             end
           else
             -- Probabilistic mode, just ensure that at least one vector is okay
            if min_len > 0 and max_len >= rule.train.max_trains then
              lua_util.debugm(N, rspamd_config,
-                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                  ann_key, lens, rule.train.max_trains, what)
+                'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+                ann_key, lens, rule.train.max_trains, what)
              cont_cb()
            else
              lua_util.debugm(N, rspamd_config,
-                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                  ann_key, what, lens, rule.train.max_trains)
+                'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+                ann_key, what, lens, rule.train.max_trains)
            end
          end
        else
          lua_util.debugm(N, rspamd_config,
-              'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
-              what, ann_key, ntrains, rule.train.max_trains)
+            'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
+            what, ann_key, ntrains, rule.train.max_trains)
          cont_cb()
        end
      end
@@ -793,32 +824,34 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
 
     local function initiate_train()
       rspamd_logger.infox(rspamd_config,
-          'need to learn ANN %s after %s required learn vectors',
-          ann_key, lens)
+        'need to learn ANN %s after %s required learn vectors',
+        ann_key, lens)
+      lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: initiating train for key=%s spam=%s ham=%s', ann_key,
+        lens.spam or -1, lens.ham or -1)
       do_train_ann(worker, ev_base, rule, set, ann_key)
     end
 
     -- Spam vector is OK, check ham vector length
     local function check_ham_len()
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          false, -- is write
-          redis_len_cb_gen(initiate_train, 'ham', true), --callback
-          'SCARD', -- command
-          { ann_key .. '_ham_set' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        false, -- is write
+        redis_len_cb_gen(initiate_train, 'ham', true), --callback
+        'SCARD', -- command
+        { ann_key .. '_ham_set' }
       )
     end
 
     lua_redis.redis_make_request_taskless(ev_base,
-        rspamd_config,
-        rule.redis,
-        nil,
-        false, -- is write
-        redis_len_cb_gen(check_ham_len, 'spam', false), --callback
-        'SCARD', -- command
-        { ann_key .. '_spam_set' }
+      rspamd_config,
+      rule.redis,
+      nil,
+      false, -- is write
+      redis_len_cb_gen(check_ham_len, 'spam', false), --callback
+      'SCARD', -- command
+      { ann_key .. '_spam_set' }
     )
   end
 end
@@ -831,7 +864,7 @@ local function load_ann_profile(element)
   local res, ucl_err = parser:parse_string(element)
   if not res then
     rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
-        ucl_err)
+      ucl_err)
     return nil
   else
     local profile = parser:get_object()
@@ -851,13 +884,16 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
     local function members_cb(err, data)
       if err then
         rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
-            err)
+          err)
         set.can_store_vectors = true
       elseif type(data) == 'table' then
-        lua_util.debugm(N, cfg, '%s: process element %s:%s',
-            what, rule.prefix, set.name)
+        lua_util.debugm(N, cfg, '%s: process element %s:%s (profiles=%s)',
+          what, rule.prefix, set.name, #data)
         process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
         set.can_store_vectors = true
+      else
+        lua_util.debugm(N, cfg, '%s: no profiles for %s:%s', what, rule.prefix, set.name)
+        set.can_store_vectors = true
       end
     end
 
@@ -867,13 +903,13 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
       -- Select the most appropriate to our profile but it should not differ by more
       -- than 30% of symbols
       lua_redis.redis_make_request_taskless(ev_base,
-          cfg,
-          rule.redis,
-          nil,
-          false, -- is write
-          members_cb, --callback
-          'ZREVRANGE', -- command
-          { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
+        cfg,
+        rule.redis,
+        nil,
+        false, -- is write
+        members_cb, --callback
+        'ZREVRANGE', -- command
+        { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
       )
     end
   end -- Cycle over all settings
@@ -887,23 +923,23 @@ local function cleanup_anns(rule, cfg, ev_base)
     local function invalidate_cb(err, data)
       if err then
         rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
-            err)
+          err)
       elseif type(data) == 'table' then
         for _, expired in ipairs(data) do
           local profile = load_ann_profile(expired)
          rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
-              rule.prefix .. ':' .. set.name,
-              profile.redis_key,
-              profile.version)
+            rule.prefix .. ':' .. set.name,
+            profile.redis_key,
+            profile.version)
        end
      end
    end
 
    if type(set) == 'table' then
      lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
-          { ev_base = ev_base, is_write = true },
-          invalidate_cb,
-          { set.prefix, tostring(settings.max_profiles) })
+        { ev_base = ev_base, is_write = true },
+        invalidate_cb,
+        { set.prefix, tostring(settings.max_profiles) })
    end
  end
end
@@ -922,14 +958,14 @@ local function ann_push_vector(task)
 
   if verdict == 'passthrough' then
     lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
-        verdict, score)
+      verdict, score)
 
     return
   end
 
   if score ~= score then
     lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
-        verdict)
+      verdict)
 
     return
   end
@@ -999,7 +1035,7 @@ for k, r in pairs(rules) do
 
     if rule_elt.max_inputs and not has_blas then
       rspamd_logger.errx('cannot set max inputs to %s as BLAS is not compiled in',
-          rule_elt.name, rule_elt.max_inputs)
+        rule_elt.name, rule_elt.max_inputs)
       rule_elt.max_inputs = nil
     end
 
@@ -1011,7 +1047,7 @@ for k, r in pairs(rules) do
       end
      if (pcfg.type == 'llm' or pcfg.name == 'llm') and not (pcfg.model or (rspamd_config:get_all_opt('gpt') or {}).model) then
        rspamd_logger.errx(rspamd_config,
-            'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k)
+          'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k)
      end
    end
  end
@@ -1062,21 +1098,21 @@ for _, rule in pairs(settings.rules) do
   rspamd_config:add_on_load(function(cfg, ev_base, worker)
     if worker:is_scanner() then
       rspamd_config:add_periodic(ev_base, 0.0,
-          function(_, _)
-            return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
-                'try_load_ann')
-          end)
+        function(_, _)
+          return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
+            'try_load_ann')
+        end)
     end
 
     if worker:is_primary_controller() then
       -- We also want to train neural nets when they have enough data
       rspamd_config:add_periodic(ev_base, 0.0,
-          function(_, _)
-            -- Clean old ANNs
-            cleanup_anns(rule, cfg, ev_base)
-            return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
-                'try_train_ann')
-          end)
+        function(_, _)
+          -- Clean old ANNs
+          cleanup_anns(rule, cfg, ev_base)
+          return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
+            'try_train_ann')
+        end)
     end
   end)
end
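
Train vectors are stored as zstd-compressed, ';'-joined numbers in the <redis_key>_spam_set and <redis_key>_ham_set Redis sets (see store_train_vec above). A sketch of the inverse decode step, assuming the same encoding — this mirrors what process_training_vectors has to do when do_train_ann reads the sets back, and the helper name is hypothetical:

  local function decode_train_vector(compressed)
    -- rspamd_util.zstd_decompress returns an error first, as used elsewhere in this file
    local err, raw = rspamd_util.zstd_decompress(compressed)
    if err or not raw then return nil end
    local vec = {}
    for tok in tostring(raw):gmatch('[^;]+') do
      vec[#vec + 1] = tonumber(tok) or 0.0
    end
    return vec
  end
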
diff --git a/test/functional/cases/335_neural_llm/003_llm_train.robot b/test/functional/cases/335_neural_llm/003_llm_train.robot
new file mode 100644
index 0000000000..aa76a153e9
--- /dev/null
+++ b/test/functional/cases/335_neural_llm/003_llm_train.robot
@@ -0,0 +1,39 @@
+*** Settings ***
+Suite Setup       Rspamd Redis Setup
+Suite Teardown    Rspamd Redis Teardown
+Library           Process
+Library           ${RSPAMD_TESTDIR}/lib/rspamd.py
+Resource          ${RSPAMD_TESTDIR}/lib/rspamd.robot
+Variables         ${RSPAMD_TESTDIR}/lib/vars.py
+
+*** Variables ***
+${CONFIG}             ${RSPAMD_TESTDIR}/configs/neural_llm.conf
+${SPAM_MSG}           ${RSPAMD_TESTDIR}/messages/spam_message.eml
+${HAM_MSG}            ${RSPAMD_TESTDIR}/messages/ham.eml
+${REDIS_SCOPE}        Suite
+${RSPAMD_SCOPE}       Suite
+${RSPAMD_URL_TLD}     ${RSPAMD_TESTDIR}/../lua/unit/test_tld.dat
+
+*** Test Cases ***
+Train LLM-backed neural and verify
+    Run Dummy Llm
+
+    # Learn spam
+    ${result} =    Run Rspamc    -P    secret    -h    ${RSPAMD_LOCAL_ADDR}:${RSPAMD_PORT_CONTROLLER}    neural_learn:spam    ${SPAM_MSG}
+    Check Rspamc    ${result}
+
+    # Learn ham
+    ${result} =    Run Rspamc    -P    secret    -h    ${RSPAMD_LOCAL_ADDR}:${RSPAMD_PORT_CONTROLLER}    neural_learn:ham    ${HAM_MSG}
+    Check Rspamc    ${result}
+
+    Sleep    5s
+
+    # Check spam inference (dummy_llm returns ones vector for "spam" content)
+    Scan File    ${SPAM_MSG}    Settings={groups_enabled=["neural"]}
+    Expect Symbol    NEURAL_SPAM
+
+    # Check ham inference (zeros vector)
+    Scan File    ${HAM_MSG}    Settings={groups_enabled=["neural"]}
+    Expect Symbol    NEURAL_HAM
+
+    Dummy Llm Teardown
diff --git a/test/functional/configs/neural_llm.conf b/test/functional/configs/neural_llm.conf
new file mode 100644
index 0000000000..b6745adee3
--- /dev/null
+++ b/test/functional/configs/neural_llm.conf
@@ -0,0 +1,68 @@
+options = {
+  url_tld = "{= env.URL_TLD =}"
+  pidfile = "{= env.TMPDIR =}/rspamd.pid"
+  lua_path = "{= env.INSTALLROOT =}/share/rspamd/lib/?.lua"
+  filters = [];
+  explicit_modules = ["settings"];
+}
+
+logging = {
+  type = "file",
+  level = "debug"
+  filename = "{= env.TMPDIR =}/rspamd.log"
+  log_usec = true;
+}
+metric = {
+  name = "default",
+  actions = {
+    reject = 100500,
+    add_header = 50500,
+  }
+  unknown_weight = 1
+}
+worker {
+  type = normal
+  bind_socket = "{= env.LOCAL_ADDR =}:{= env.PORT_NORMAL =}"
+  count = 1
+  task_timeout = 10s;
+}
+worker {
+  type = controller
+  bind_socket = "{= env.LOCAL_ADDR =}:{= env.PORT_CONTROLLER =}"
+  count = 1
+  secure_ip = ["127.0.0.1", "::1"];
+  stats_path = "{= env.TMPDIR =}/stats.ucl"
+}
+
+modules {
+  path = "{= env.TESTDIR =}/../../src/plugins/lua/"
+}
+
+lua = "{= env.TESTDIR =}/lua/test_coverage.lua";
+
+neural {
+  rules {
+    default {
+      train {
+        learning_rate = 0.001;
+        max_trains = 1;
+        max_iterations = 250;
+      }
+      symbol_spam = "NEURAL_SPAM";
+      symbol_ham = "NEURAL_HAM";
+      ann_expire = 86400;
+      watch_interval = 0.5;
+      providers = [{ type = "llm"; model = "dummy-embed"; url = "http://127.0.0.1:18080"; weight = 1.0; }];
+      fusion { normalization = "none"; }
+      roc_enabled = false;
+    }
+  }
+  allow_local = true;
+}
+
+redis {
+  servers = "{= env.REDIS_ADDR =}:{= env.REDIS_PORT =}";
+  expand_keys = true;
+}
+
+lua = "{= env.TESTDIR =}/lua/neural.lua";
diff --git a/test/functional/util/dummy_llm.py b/test/functional/util/dummy_llm.py
new file mode 100644
index 0000000000..9ee0f17726
--- /dev/null
+++ b/test/functional/util/dummy_llm.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import json
+import sys
+from http.server import BaseHTTPRequestHandler, HTTPServer
+
+import dummy_killer
+
+PID = "/tmp/dummy_llm.pid"
+
+
+def make_embedding(text: str, dim: int = 32):
+    # Deterministic: if text contains 'SPAM' (case-insensitive) -> ones; else zeros
+    if 'spam' in text.lower():
+        return [1.0] * dim
+    return [0.0] * dim
+
+
+class EmbeddingHandler(BaseHTTPRequestHandler):
+    # OpenAI-like embeddings API
+    def do_POST(self):
+        length = int(self.headers.get('Content-Length', '0'))
+        raw = self.rfile.read(length) if length > 0 else b''
+        try:
+            data = json.loads(raw.decode('utf-8') or '{}')
+        except Exception:
+            data = {}
+
+        # Support both OpenAI ({input, model}) and Ollama ({prompt, model}) shapes
+        text = data.get('input') or data.get('prompt') or ''
+        # Optional dimension override for tests
+        dim = int(data.get('dim') or 32)
+        emb = make_embedding(text, dim)
+
+        # Always reply in the OpenAI-like format expected by the neural provider
+        body = {
+            "data": [
+                {"embedding": emb}
+            ]
+        }
+
+        reply = json.dumps(body).encode('utf-8')
+        self.send_response(200)
+        self.send_header('Content-Type', 'application/json')
+        self.send_header('Content-Length', str(len(reply)))
+        self.end_headers()
+        self.wfile.write(reply)
+
+    def log_message(self, fmt, *args):
+        # Keep test output quiet
+        return
+
+
+if __name__ == "__main__":
+    alen = len(sys.argv)
+    if alen > 1:
+        port = int(sys.argv[1])
+    else:
+        port = 18080
+    server = HTTPServer(("127.0.0.1", port), EmbeddingHandler)
+    dummy_killer.write_pid(PID)
+    try:
+        server.serve_forever()
+    except KeyboardInterrupt:
+        pass
+    finally:
+        server.server_close()
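
A hedged smoke-test sketch for dummy_llm.py from rspamd's Lua side, using the rspamd_http client. Start the server first (e.g. python3 dummy_llm.py 18080); the payload shape follows what the server accepts, and task must be a live task object (or be replaced with an ev_base/config pair as rspamd_http allows):

  local rspamd_http = require "rspamd_http"
  local ucl = require "ucl"

  rspamd_http.request({
    url = 'http://127.0.0.1:18080/',
    task = task,
    method = 'POST',
    mime_type = 'application/json',
    body = '{"input": "this is SPAM", "model": "dummy-embed"}',
    callback = function(err, code, body)
      if not err and code == 200 then
        local parser = ucl.parser()
        if parser:parse_string(body) then
          local obj = parser:get_object()
          -- Expect a ones vector for 'spam' content in obj.data[1].embedding
        end
      end
    end,
  })
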