WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-]]--
+]] --
--[[[
-- @module lua_cache
-- Default options
local default_opts = {
cache_prefix = "rspamd_cache",
- cache_ttl = 3600, -- 1 hour
- cache_probes = 5, -- Number of times to check a pending key
- cache_format = "json", -- Serialization format
- cache_hash_len = 16, -- Number of hex symbols to use for hashed keys
+ cache_ttl = 3600, -- 1 hour
+ cache_probes = 5, -- Number of times to check a pending key
+ cache_format = "json", -- Serialization format
+ cache_hash_len = 16, -- Number of hex symbols to use for hashed keys
cache_use_hashing = false -- Whether to hash keys by default
}
end
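-- Sketch of typical usage (assumes this module's usual create_cache_context export):
-- local cache_context = lua_cache.create_cache_context(redis_params, {
--   cache_prefix = "my_module_cache",
--   cache_ttl = 600, -- override the one hour default
--   cache_use_hashing = true, -- hash keys down to cache_hash_len hex symbols
-- }, N)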
if should_hash then
- lua_util.debugm(N, rspamd_config, "hashing key '%s' with hash length %s",
- raw_key, cache_context.opts.cache_hash_len)
+ local raw_len = (type(raw_key) == 'string') and #raw_key or -1
+ lua_util.debugm(N, rspamd_config, "hashing cache key (len=%s) with hash length %s",
+ raw_len, cache_context.opts.cache_hash_len)
return hash_key(raw_key, cache_context.opts.cache_hash_len)
else
return raw_key
-- Register Redis prefix
lua_redis.register_prefix(cache_context.opts.cache_prefix,
- "caching",
- "Cache API prefix")
+ "caching",
+ "Cache API prefix")
lua_util.debugm(N, rspamd_config, "registered redis prefix: %s", cache_context.opts.cache_prefix)
}
lua_util.debugm(cache_context.N, rspamd_config, "creating PENDING marker for host %s, timeout %s",
- hostname, timeout)
+ hostname, timeout)
return "PENDING:" .. encode_data(pending_data, cache_context)
end
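-- The value stored for an in-flight computation is the literal prefix
-- "PENDING:" followed by the encoded {hostname, timeout} payload; this is how
-- parse_pending_value() below tells markers apart from real cached data.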
return false
end
- local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, false)
- lua_util.debugm(cache_context.N, task, "cache lookup for key: %s (%s)", key, full_key)
+ local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, nil)
+ lua_util.debugm(cache_context.N, task, "cache lookup for key: %s", full_key)
-- Function to check a pending key
local function check_pending(pending_info)
local probe_interval = timeout / (cache_context.opts.cache_probes or 5)
lua_util.debugm(cache_context.N, task, "setting up probes for pending key %s, interval: %s seconds",
- full_key, probe_interval)
+ full_key, probe_interval)
-- Set up a timer to probe the key
local function probe_key()
probe_count = probe_count + 1
lua_util.debugm(cache_context.N, task, "probe #%s/%s for pending key %s",
- probe_count, cache_context.opts.cache_probes, full_key)
+ probe_count, cache_context.opts.cache_probes, full_key)
if probe_count >= cache_context.opts.cache_probes then
logger.infox(task, "maximum probes reached for key %s, considering it failed", full_key)
lua_util.debugm(cache_context.N, task, "probing redis for key %s", full_key)
lua_redis.redis_make_request(task, cache_context.redis_params, key, false,
- function(err, data)
- if err then
- logger.errx(task, "redis error while probing key %s: %s", full_key, err)
- lua_util.debugm(cache_context.N, task, "redis error during probe: %s, retrying later", err)
- task:add_timer(probe_interval, probe_key)
- return
- end
+ function(err, data)
+ if err then
+ logger.errx(task, "redis error while probing key %s: %s", full_key, err)
+ lua_util.debugm(cache_context.N, task, "redis error during probe: %s, retrying later", err)
+ task:add_timer(probe_interval, probe_key)
+ return
+ end
- if not data or type(data) == 'userdata' then
- lua_util.debugm(cache_context.N, task, "pending key %s disappeared, calling uncached handler", full_key)
- callback_uncached(task)
- return
- end
+ if not data or type(data) == 'userdata' then
+ lua_util.debugm(cache_context.N, task, "pending key %s disappeared, calling uncached handler", full_key)
+ callback_uncached(task)
+ return
+ end
- local pending = parse_pending_value(data, cache_context)
- if pending then
- lua_util.debugm(cache_context.N, task, "key %s still pending (host: %s), retrying later",
- full_key, pending.hostname)
- task:add_timer(probe_interval, probe_key)
- else
- lua_util.debugm(cache_context.N, task, "pending key %s resolved to actual data", full_key)
- callback_data(task, nil, decode_data(data, cache_context))
- end
- end,
- 'GET', { full_key }
+ local pending = parse_pending_value(data, cache_context)
+ if pending then
+ lua_util.debugm(cache_context.N, task, "key %s still pending (host: %s), retrying later",
+ full_key, pending.hostname)
+ task:add_timer(probe_interval, probe_key)
+ else
+ lua_util.debugm(cache_context.N, task, "pending key %s resolved to actual data", full_key)
+ callback_data(task, nil, decode_data(data, cache_context))
+ end
+ end,
+ 'GET', { full_key }
)
end
-- Start the first probe after the initial probe interval
lua_util.debugm(cache_context.N, task, "scheduling first probe for %s in %s seconds",
- full_key, probe_interval)
+ full_key, probe_interval)
task:add_timer(probe_interval, probe_key)
end
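-- Worked example: with timeout = 10s and the default cache_probes = 5,
-- probe_interval = 10 / 5 = 2s, so a pending key is re-checked every two
-- seconds and given up on after the fifth probe.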
-- Initial cache lookup
lua_util.debugm(cache_context.N, task, "making initial redis GET request for key: %s", full_key)
lua_redis.redis_make_request(task, cache_context.redis_params, key, false,
- function(err, data)
- if err then
- logger.errx(task, "redis error looking up key %s: %s", full_key, err)
- lua_util.debugm(cache_context.N, task, "redis error: %s, calling uncached handler", err)
- callback_uncached(task)
- return
- end
-
- if not data or type(data) == 'userdata' then
- -- Key not found, set pending and call the uncached callback
- lua_util.debugm(cache_context.N, task, "key %s not found in cache, creating pending marker", full_key)
- local pending_marker = create_pending_marker(timeout, cache_context)
+ function(err, data)
+ if err then
+ logger.errx(task, "redis error looking up key %s: %s", full_key, err)
+ lua_util.debugm(cache_context.N, task, "redis error: %s, calling uncached handler", err)
+ callback_uncached(task)
+ return
+ end
- lua_util.debugm(cache_context.N, task, "setting pending marker for key %s with TTL %s",
- full_key, timeout * 2)
+ if not data or type(data) == 'userdata' then
+ -- Key not found, set pending and call the uncached callback
+ lua_util.debugm(cache_context.N, task, "key %s not found in cache, creating pending marker", full_key)
+ local pending_marker = create_pending_marker(timeout, cache_context)
+
+ lua_util.debugm(cache_context.N, task, "setting pending marker for key %s with TTL %s",
+ full_key, timeout * 2)
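+ -- The marker TTL is twice the computation timeout so that a marker left
+ -- behind by a crashed worker eventually expires instead of blocking the key.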
+ lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
+ function(set_err, set_data)
+ if set_err then
+ logger.errx(task, "redis error setting pending marker for %s: %s", full_key, set_err)
+ lua_util.debugm(cache_context.N, task, "failed to set pending marker: %s", set_err)
+ else
+ lua_util.debugm(cache_context.N, task, "successfully set pending marker for %s", full_key)
+ end
+ lua_util.debugm(cache_context.N, task, "calling uncached handler for %s", full_key)
+ callback_uncached(task)
+ end,
+ 'SETEX', { full_key, tostring(timeout * 2), pending_marker }
+ )
+ else
+ -- Key found, check if it's a pending marker or actual data
+ local pending = parse_pending_value(data, cache_context)
+
+ if pending then
+ -- Key is being processed by another worker
+ lua_util.debugm(cache_context.N, task, "key %s is pending on host %s, waiting for result",
+ full_key, pending.hostname)
+ check_pending(pending)
+ else
+ -- Extend TTL and return data
+ lua_util.debugm(cache_context.N, task, "found cached data for key %s, extending TTL to %s",
+ full_key, cache_context.opts.cache_ttl)
lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
- function(set_err, set_data)
- if set_err then
- logger.errx(task, "redis error setting pending marker for %s: %s", full_key, set_err)
- lua_util.debugm(cache_context.N, task, "failed to set pending marker: %s", set_err)
- else
- lua_util.debugm(cache_context.N, task, "successfully set pending marker for %s", full_key)
- end
- lua_util.debugm(cache_context.N, task, "calling uncached handler for %s", full_key)
- callback_uncached(task)
- end,
- 'SETEX', { full_key, tostring(timeout * 2), pending_marker }
+ function(expire_err, _)
+ if expire_err then
+ logger.errx(task, "redis error extending TTL for %s: %s", full_key, expire_err)
+ lua_util.debugm(cache_context.N, task, "failed to extend TTL: %s", expire_err)
+ else
+ lua_util.debugm(cache_context.N, task, "successfully extended TTL for %s", full_key)
+ end
+ end,
+ 'EXPIRE', { full_key, tostring(cache_context.opts.cache_ttl) }
)
- else
- -- Key found, check if it's a pending marker or actual data
- local pending = parse_pending_value(data, cache_context)
- if pending then
- -- Key is being processed by another worker
- lua_util.debugm(cache_context.N, task, "key %s is pending on host %s, waiting for result",
- full_key, pending.hostname)
- check_pending(pending)
- else
- -- Extend TTL and return data
- lua_util.debugm(cache_context.N, task, "found cached data for key %s, extending TTL to %s",
- full_key, cache_context.opts.cache_ttl)
- lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
- function(expire_err, _)
- if expire_err then
- logger.errx(task, "redis error extending TTL for %s: %s", full_key, expire_err)
- lua_util.debugm(cache_context.N, task, "failed to extend TTL: %s", expire_err)
- else
- lua_util.debugm(cache_context.N, task, "successfully extended TTL for %s", full_key)
- end
- end,
- 'EXPIRE', { full_key, tostring(cache_context.opts.cache_ttl) }
- )
-
- lua_util.debugm(cache_context.N, task, "returning cached data for key %s", full_key)
- callback_data(task, nil, decode_data(data, cache_context))
- end
+ lua_util.debugm(cache_context.N, task, "returning cached data for key %s", full_key)
+ callback_data(task, nil, decode_data(data, cache_context))
end
- end,
- 'GET', { full_key }
+ end
+ end,
+ 'GET', { full_key }
)
return true
return false
end
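-- Usage sketch (assumes the module's usual cache_get export):
-- lua_cache.cache_get(task, 'some_key', cache_context, 10,
--   function(task) ... end,            -- uncached: compute the value, then cache_set()
--   function(task, err, data) ... end) -- cached or resolved: consume the decoded data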
- local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, false)
- lua_util.debugm(cache_context.N, task, "caching data for key: %s (%s) with TTL: %s",
- full_key, key, cache_context.opts.cache_ttl)
+ local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, nil)
+ lua_util.debugm(cache_context.N, task, "caching data for key: %s with TTL: %s",
+ full_key, cache_context.opts.cache_ttl)
local encoded_data = encode_data(data, cache_context)
-- Store the data with expiration
lua_util.debugm(cache_context.N, task, "making redis SETEX request for key: %s", full_key)
return lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
- function(err, result)
- if err then
- logger.errx(task, "redis error setting cached data for %s: %s", full_key, err)
- lua_util.debugm(cache_context.N, task, "failed to cache data: %s", err)
- else
- lua_util.debugm(cache_context.N, task, "successfully cached data for key %s", full_key)
- end
- end,
- 'SETEX', { full_key, tostring(cache_context.opts.cache_ttl), encoded_data }
+ function(err, result)
+ if err then
+ logger.errx(task, "redis error setting cached data for %s: %s", full_key, err)
+ lua_util.debugm(cache_context.N, task, "failed to cache data: %s", err)
+ else
+ lua_util.debugm(cache_context.N, task, "successfully cached data for key %s", full_key)
+ end
+ end,
+ 'SETEX', { full_key, tostring(cache_context.opts.cache_ttl), encoded_data }
)
end
return false
end
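-- Write-path sketch (assumes the usual cache_set export):
-- lua_cache.cache_set(task, 'some_key', { result = 'ham' }, cache_context)
-- The delete path below removes the key so the next cache_get() recomputes it.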
- local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, false)
+ local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, nil)
lua_util.debugm(cache_context.N, task, "deleting cache key: %s", full_key)
return lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
- function(err, result)
- if err then
- logger.errx(task, "redis error deleting cache key %s: %s", full_key, err)
- lua_util.debugm(cache_context.N, task, "failed to delete cache key: %s", err)
- else
- local count = tonumber(result) or 0
- lua_util.debugm(cache_context.N, task, "successfully deleted cache key %s (%s keys removed)",
- full_key, count)
- end
- end,
- 'DEL', { full_key }
+ function(err, result)
+ if err then
+ logger.errx(task, "redis error deleting cache key %s: %s", full_key, err)
+ lua_util.debugm(cache_context.N, task, "failed to delete cache key: %s", err)
+ else
+ local count = tonumber(result) or 0
+ lua_util.debugm(cache_context.N, task, "successfully deleted cache key %s (%s keys removed)",
+ full_key, count)
+ end
+ end,
+ 'DEL', { full_key }
)
end
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-]]--
+]] --
local exports = {}
nlarge = 1.0 * nlarge / ntotal
nsmall = 1.0 * nsmall / ntotal
end
- return { ntotal, njpg, npng, nlarge, nsmall }
+ return { ntotal, npng, njpg, nlarge, nsmall } -- Fixed order to match names
end
local function meta_nparts_function(task)
local fun = require "fun"
if rh and #rh > 0 then
-
local ntotal = 0.0
local init_time = 0
fun.each(function(rc)
- ntotal = ntotal + 1.0
+ ntotal = ntotal + 1.0
- if not rc.by_hostname then
- invalid_factor = invalid_factor + 1.0
- end
- if init_time == 0 and rc.timestamp then
- init_time = rc.timestamp
- elseif rc.timestamp then
- time_factor = time_factor + math.abs(init_time - rc.timestamp)
- init_time = rc.timestamp
- end
- if rc.flags and (rc.flags['ssl'] or rc.flags['authenticated']) then
- secure_factor = secure_factor + 1.0
- end
- end,
- fun.filter(function(rc)
- return not rc.flags or not rc.flags['artificial']
- end, rh))
+ if not rc.by_hostname then
+ invalid_factor = invalid_factor + 1.0
+ end
+ if init_time == 0 and rc.timestamp then
+ init_time = rc.timestamp
+ elseif rc.timestamp then
+ time_factor = time_factor + math.abs(init_time - rc.timestamp)
+ init_time = rc.timestamp
+ end
+ if rc.flags and (rc.flags['ssl'] or rc.flags['authenticated']) then
+ secure_factor = secure_factor + 1.0
+ end
+ end,
+ fun.filter(function(rc)
+ return not rc.flags or not rc.flags['artificial']
+ end, rh))
if ntotal > 0 then
invalid_factor = invalid_factor / ntotal
end
local ret = {
- short_words,
- ret_len,
+ ret_len, -- avg_words_len (moved to match the names array)
+ short_words, -- nshort_words
}
local divisor = 1.0
local ct = mt.cb(task)
for i, tok in ipairs(ct) do
lua_util.debugm(N, task, "metatoken: %s = %s",
- mt.names[i], tok)
+ mt.names[i], tok)
if tok ~= tok or tok == math.huge then
logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity',
- mt.names[i], tok)
+ mt.names[i], tok)
tok = 0.0
end
table.insert(metatokens, tok)
task:cache_set('metatokens', metatokens)
end
-
else
for _, n in ipairs(names) do
if metatokens_by_name[n] then
local tok = metatokens_by_name[n](task)
if tok ~= tok or tok == math.huge then
logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity',
- n, tok)
+ n, tok)
tok = 0.0
end
table.insert(metatokens, tok)
for i, tok in ipairs(ct) do
if tok ~= tok or tok == math.huge then
logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity',
- mt.names[i], tok)
+ mt.names[i], tok)
tok = 0.0
end
end,
})
+-- Metatokens-only provider for contexts where symbols are not available
+register_provider('metatokens', {
+ collect = function(task, ctx)
+ local mt = meta_functions.rspamd_gen_metatokens(task)
+ -- Convert to table of numbers
+ local vec = {}
+ for i = 1, #mt do
+ vec[i] = tonumber(mt[i]) or 0.0
+ end
+ return vec, { name = 'metatokens', type = 'metatokens', dim = #vec, weight = ctx.weight or 1.0 }
+ end,
+ collect_async = function(task, ctx, cont)
+ local mt = meta_functions.rspamd_gen_metatokens(task)
+ -- Convert to table of numbers
+ local vec = {}
+ for i = 1, #mt do
+ vec[i] = tonumber(mt[i]) or 0.0
+ end
+ cont(vec, { name = 'metatokens', type = 'metatokens', dim = #vec, weight = ctx.weight or 1.0 })
+ end,
+})
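+-- Provider contract assumed above: collect(task, ctx) -> vec, meta and
+-- collect_async(task, ctx, cont) -> cont(vec, meta), where meta describes the
+-- vector as { name, type, dim, weight } for fusion and digest bookkeeping.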
+
local function load_scripts()
redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len,
redis_params)
local function redis_ann_prefix(rule, settings_name)
-- We also need to count metatokens:
+ -- Note: meta_functions.version represents the metatoken format version
local n = meta_functions.version
return string.format('%s%d_%s_%d_%s',
settings.prefix, plugin_ver, rule.prefix, n, settings_name)
end
prov.collect_async(task, {
profile = profile_or_set,
+ set = profile_or_set,
rule = rule,
config = pcfg,
weight = pcfg.weight or 1.0,
end)
end
- -- Include metatokens as an extra provider when configured
- local include_meta = rule.fusion and rule.fusion.include_meta
+ -- Include the metatokens provider as an extra input. The option name
+ -- 'include_meta' is historical: it used to pull in the full symbols provider,
+ -- but the code below now uses the metatokens-only provider so dimensions
+ -- stay consistent between controller and full-scan paths.
+ -- For backward compatibility, it defaults to enabled unless explicitly disabled.
+ local include_meta = false
+ if not providers_cfg or #providers_cfg == 0 then
+ -- No providers configured: fall back to the metatokens vector
+ include_meta = true
+ elseif rule.fusion then
+ -- Explicit fusion config takes precedence
+ include_meta = rule.fusion.include_meta
+ if include_meta == nil then
+ -- Default to true for backward compatibility when fusion is configured but include_meta not specified
+ include_meta = true
+ end
+ else
+ -- Providers configured but no fusion settings: default to including metatokens
+ include_meta = true
+ end
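+ -- Example: providers = [{ type = "llm" }] with no fusion block still appends
+ -- the metatokens vector with the default meta_weight of 1.0.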
+
local meta_weight = (rule.fusion and rule.fusion.meta_weight) or 1.0
remaining = #providers_cfg + (include_meta and 1 or 0)
+
+ -- Start all configured providers
for i, pcfg in ipairs(providers_cfg) do
start_provider(i, pcfg)
end
if include_meta then
- local prov = get_provider('symbols')
+ -- Always use metatokens provider for consistency
+ -- This ensures same dimensions whether called from controller or full scan
+ local prov = get_provider('metatokens')
+
if prov and prov.collect_async then
- prov.collect_async(task, { profile = profile_or_set, weight = meta_weight, phase = phase }, function(vec, meta)
- if vec then
- metas[#metas + 1] = { name = 'symbols', type = 'symbols', dim = #vec, weight = meta_weight }
- vectors[#vectors + 1] = vec
- end
- maybe_finish()
- end)
+ local meta_index = #providers_cfg + 1 -- Metatokens always come after providers
+ prov.collect_async(task, { profile = profile_or_set, set = profile_or_set, weight = meta_weight, phase = phase },
+ function(vec, meta)
+ if vec then
+ metas[meta_index] = meta
+ vectors[meta_index] = vec
+ end
+ maybe_finish()
+ end)
else
maybe_finish()
end
local function spawn_train(params)
-- Check training data sanity
-- Now we need to join inputs and create the appropriate test vectors
- local n = #params.set.symbols +
- meta_functions.rspamd_count_metatokens()
+ local n
+
+ -- When using providers, derive dimension from actual vectors
+ if params.rule.providers and #params.rule.providers > 0 and
+ (#params.spam_vec > 0 or #params.ham_vec > 0) then
+ -- Use dimension from stored vectors
+ if #params.spam_vec > 0 then
+ n = #params.spam_vec[1]
+ else
+ n = #params.ham_vec[1]
+ end
+ lua_util.debugm(N, rspamd_config, 'spawn_train: using vector dimension %s from stored vectors', n)
+ else
+ -- Traditional symbol-based dimension
+ n = #params.set.symbols + meta_functions.rspamd_count_metatokens()
+ lua_util.debugm(N, rspamd_config, 'spawn_train: using symbol dimension %s symbols + %s metatokens = %s',
+ #params.set.symbols, meta_functions.rspamd_count_metatokens(), n)
+ end
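+ -- Example: an llm provider emitting 32-dim embeddings fused with metatokens
+ -- produces stored vectors whose length already reflects both inputs, so n is
+ -- read from params.spam_vec[1] rather than recomputed from symbols.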
-- Now we can train ann
local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule)
if not profile.zeros then
-- Fill zeros vector
local zeros = {}
- for i = 1, meta_functions.count_metatokens() do
+ for i = 1, meta_functions.rspamd_count_metatokens() do
zeros[i] = 0.0
end
for _, _ in ipairs(profile.symbols) do
return
end
+ -- Do not run embeddings on infer if ANN is not loaded for this set/profile
+ if ctx.phase == 'infer' then
+ local set_or_profile = ctx.profile or ctx.set
+ if not set_or_profile or not set_or_profile.ann then
+ rspamd_logger.debugm(N, task, 'skip llm on infer: ANN not loaded for current settings')
+ cont(nil)
+ return
+ end
+ end
+
local input_tbl = select_text(task)
if not input_tbl then
rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip')
return
end
- -- If no providers or symbols provider configured, require full scan path
+ -- Check if this configuration requires full scan
+ -- Only symbols collection requires full scan; metatokens can be computed directly
local has_providers = type(rule.providers) == 'table' and #rule.providers > 0
- if not has_providers then
- lua_util.debugm(N, task, 'controller.neural: learn_message refused: no providers (assume symbols) for rule=%s',
+
+ if not has_providers and not rule.disable_symbols_input then
+ -- No providers means full symbols will be used (not just metatokens)
+ lua_util.debugm(N, task,
+ 'controller.neural: learn_message refused: no providers configured, symbols collection requires full scan for rule=%s',
rule_name)
- conn:send_error(400, 'rule requires full /checkv2 scan (no providers configured)')
+ conn:send_error(400, 'rule requires full /checkv2 scan (no providers configured, full symbols collection required)')
return
end
- for _, p in ipairs(rule.providers) do
- if p.type == 'symbols' then
- lua_util.debugm(N, task, 'controller.neural: learn_message refused due to symbols provider for rule=%s', rule_name)
- conn:send_error(400, 'rule requires full /checkv2 scan (symbols provider present)')
- return
+
+ -- Check if any provider requires full scan (only symbols provider does)
+ if has_providers then
+ for _, p in ipairs(rule.providers) do
+ if p.type == 'symbols' then
+ lua_util.debugm(N, task,
+ 'controller.neural: learn_message refused due to symbols provider requiring full scan for rule=%s',
+ rule_name)
+ conn:send_error(400, 'rule requires full /checkv2 scan (symbols provider present)')
+ return
+ end
end
end
+ -- At this point:
+ -- - We have providers that don't require full scan (e.g., LLM)
+ -- - Metatokens can be computed directly from the message
+ -- - Controller training is allowed
+
local set = neural_common.get_rule_settings(task, rule)
if not set then
lua_util.debugm(N, task, 'controller.neural: no settings resolved for rule=%s; falling back to first available set',
end
end
+ if set then
+ lua_util.debugm(N, task, 'controller.neural: set found for rule=%s, symbols=%s, name=%s',
+ rule_name, set.symbols and #set.symbols or "nil", set.name)
+ end
+
-- Derive redis base key even if ANN not yet initialized
local redis_base
if set and set.ann and set.ann.redis_key then
return
end
+ -- Ensure profile exists for this set
+ if not set.ann then
+ local version = 0
+ local ann_key = neural_common.new_ann_key(rule, set, version)
+
+ local profile = {
+ symbols = set.symbols,
+ redis_key = ann_key,
+ version = version,
+ digest = set.digest,
+ distance = 0,
+ providers_digest = neural_common.providers_config_digest(rule.providers),
+ }
+
+ local ucl = require "ucl"
+ local profile_serialized = ucl.to_format(profile, 'json-compact', true)
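+ -- profile_serialized is compact JSON along the lines of
+ -- {"symbols":[...],"redis_key":"<ann_key>","version":0,"digest":"...","distance":0,"providers_digest":"..."}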
+
+ lua_util.debugm(N, task, 'controller.neural: creating new profile for %s:%s at %s',
+ rule.prefix, set.name, ann_key)
+
+ -- Store the profile in Redis sorted set
+ lua_redis.redis_make_request(task,
+ rule.redis,
+ nil,
+ true, -- is write
+ function(err, _)
+ if err then
+ rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
+ rule.prefix, set.name, profile.redis_key, err)
+ else
+ lua_util.debugm(N, task, 'created new ANN profile for %s:%s, data stored at prefix %s',
+ rule.prefix, set.name, profile.redis_key)
+ end
+ end,
+ 'ZADD', -- command
+ { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+ )
+
+ -- Update redis_base to use the new ann_key
+ redis_base = ann_key
+ end
+
local function after_collect(vec)
lua_util.debugm(N, task, 'controller.neural: learn_message after_collect, vector=%s', type(vec))
if not vec then
- if rule.providers and #rule.providers > 0 then
- lua_util.debugm(N, task,
- 'controller.neural: no vector from providers; skip training to keep dimensions consistent')
- conn:send_error(400, 'no vector collected from providers')
- return
- else
- vec = neural_common.result_to_vector(task, set)
- end
+ lua_util.debugm(N, task,
+ 'controller.neural: no vector collected; skip training')
+ conn:send_error(400, 'no vector collected')
+ return
end
if type(vec) ~= 'table' then
static void rspamc_stat_output(FILE *out, ucl_object_t *obj);
+static void
+rspamc_neural_learn_output(FILE *out, ucl_object_t *obj)
+{
+ bool is_success = true;
+ const char *filename = nullptr;
+ double scan_time = -1.0;
+ const char *redis_key = nullptr;
+ std::uintmax_t stored_bytes = 0;
+ bool have_stored = false;
+
+ if (obj != nullptr) {
+ const auto *ok = ucl_object_lookup(obj, "success");
+ if (ok) {
+ is_success = ucl_object_toboolean(ok);
+ }
+ const auto *fn = ucl_object_lookup(obj, "filename");
+ if (fn) {
+ filename = ucl_object_tostring(fn);
+ }
+ const auto *st = ucl_object_lookup(obj, "scan_time");
+ if (st) {
+ scan_time = ucl_object_todouble(st);
+ }
+ const auto *rb = ucl_object_lookup(obj, "stored");
+ if (rb) {
+ stored_bytes = (std::uintmax_t) ucl_object_toint(rb);
+ have_stored = true;
+ }
+ const auto *rk = ucl_object_lookup(obj, "key");
+ if (rk) {
+ redis_key = ucl_object_tostring(rk);
+ }
+ }
+
+ // First line: success
+ fprintf(out, "success = %s;\n", is_success ? "true" : "false");
+
+ // Then other fields in k = v; format
+ if (filename) {
+ fprintf(out, "filename = \"%s\";\n", filename);
+ }
+ if (scan_time >= 0) {
+ fprintf(out, "scan_time = %.6f;\n", scan_time);
+ }
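+ // neural_train is assumed to hold the spam/ham class requested on the
+ // command line; it is defined elsewhere in this patch.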
+ if (!neural_train.empty()) {
+ fprintf(out, "class = \"%s\";\n", neural_train.c_str());
+ }
+ if (have_stored) {
+ fprintf(out, "stored = %ju bytes;\n", stored_bytes);
+ }
+ if (redis_key) {
+ fprintf(out, "key = \"%s\";\n", redis_key);
+ }
+}
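+// Illustrative output of the formatter above (values are examples only):
+//   success = true;
+//   filename = "message.eml";
+//   scan_time = 0.123456;
+//   class = "spam";
+//   stored = 2048 bytes;
+//   key = "<redis key>";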
+
enum rspamc_command_type {
RSPAMC_COMMAND_UNKNOWN = 0,
RSPAMC_COMMAND_CHECK,
.is_controller = FALSE,
.is_privileged = FALSE,
.need_input = TRUE,
- .command_output_func = rspamc_symbols_output},
+ .command_output_func = rspamc_neural_learn_output},
rspamc_command{
.cmd = RSPAMC_COMMAND_FUZZY_ADD,
.name = "fuzzy_add",
local function add_cb(err, _)
if err then
rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
- rule.prefix, set.name, profile.redis_key, err)
+ rule.prefix, set.name, profile.redis_key, err)
else
rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
- rule.prefix, set.name, profile.redis_key)
+ rule.prefix, set.name, profile.redis_key)
end
end
lua_redis.redis_make_request(task,
- rule.redis,
- nil,
- true, -- is write
- add_cb, --callback
- 'ZADD', -- command
- { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+ rule.redis,
+ nil,
+ true, -- is write
+ add_cb, --callback
+ 'ZADD', -- command
+ { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
)
return profile
profile = set.ann
else
lua_util.debugm(N, task, 'no ann loaded for %s:%s',
- rule.prefix, set.name)
+ rule.prefix, set.name)
end
else
lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
- rule.prefix, sid)
+ rule.prefix, sid)
end
if ann then
local function after_features(vec, meta)
if profile.providers_digest and meta and meta.digest and profile.providers_digest ~= meta.digest then
lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply',
- rule.prefix, set.name)
+ rule.prefix, set.name)
vec = nil
end
local symscore = string.format('%.3f', score)
task:cache_set(rule.prefix .. '_neural_score', score)
lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
- rule.prefix, set.name, set.ann.version, symscore)
+ rule.prefix, set.name, set.ann.version, symscore)
if score > 0 then
local result = score
end
else
lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
- rule.prefix, set.name, set.ann.version, symscore,
- spam_threshold)
+ rule.prefix, set.name, set.ann.version, symscore,
+ spam_threshold)
end
else
local result = -(score)
end
else
lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
- rule.prefix, set.name, set.ann.version, result,
- ham_threshold)
+ rule.prefix, set.name, set.ann.version, result,
+ ham_threshold)
end
end
end
end
end
+ -- If LLM provider is configured, suppress autotrain unless manual training requested
+ if not manual_train and rule.providers and #rule.providers > 0 then
+ for _, p in ipairs(rule.providers) do
+ if p.type == 'llm' then
+ lua_util.debugm(N, task, 'suppress autotrain: llm provider present and no manual header')
+ learn_spam = false
+ learn_ham = false
+ skip_reason = 'llm provider requires manual training'
+ break
+ end
+ end
+ end
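+ -- Manual training (manual_train) still bypasses this suppression, so
+ -- llm-backed rules can be trained explicitly via the controller.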
+
if not manual_train and (not train_opts.store_pool_only and train_opts.autotrain) then
if train_opts.spam_score then
learn_spam = score >= train_opts.spam_score
if not learn_spam then
skip_reason = string.format('score < spam_score: %f < %f',
- score, train_opts.spam_score)
+ score, train_opts.spam_score)
end
else
learn_spam = verdict == 'spam' or verdict == 'junk'
if not learn_spam then
skip_reason = string.format('verdict: %s',
- verdict)
+ verdict)
end
end
learn_ham = score <= train_opts.ham_score
if not learn_ham then
skip_reason = string.format('score > ham_score: %f > %f',
- score, train_opts.ham_score)
+ score, train_opts.ham_score)
end
else
learn_ham = verdict == 'ham'
if not learn_ham then
skip_reason = string.format('verdict: %s',
- verdict)
+ verdict)
end
end
else
local nspam, nham = data[1], data[2]
if manual_train or neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
- local vec
- if rule.providers and #rule.providers > 0 then
- -- Note: this training path remains sync for now; vectors are pushed when computed
- -- fall back to legacy vector; async training push will be added later
- vec = neural_common.result_to_vector(task, set)
- else
- vec = neural_common.result_to_vector(task, set)
- end
+ local function store_train_vec(vec)
+ if not vec then
+ lua_util.debugm(N, task, "no vector collected for training")
+ return
+ end
- local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
- local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
+ local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
+ local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
- local function learn_vec_cb(redis_err)
- if redis_err then
- rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
+ local function learn_vec_cb(redis_err)
+ if redis_err then
+ rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
rule.prefix, set.name, redis_err)
- else
- lua_util.debugm(N, task,
+ else
+ lua_util.debugm(N, task,
"add train data for ANN rule " ..
- "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
+ "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
rule.prefix, set.name, learn_type, #vec, target_key, #str)
+ end
end
- end
- lua_redis.redis_make_request(task,
+ lua_redis.redis_make_request(task,
rule.redis,
nil,
- true, -- is write
- learn_vec_cb, --callback
- 'SADD', -- command
+ true, -- is write
+ learn_vec_cb, --callback
+ 'SADD', -- command
{ target_key, str } -- arguments
- )
+ )
+ end
+
+ if rule.providers and #rule.providers > 0 then
+ -- Use async feature collection with providers, same as inference
+ neural_common.collect_features_async(task, rule, set, 'train', store_train_vec)
+ else
+ -- Traditional symbol-based vector
+ local vec = neural_common.result_to_vector(task, set)
+ store_train_vec(vec)
+ end
else
lua_util.debugm(N, task,
- "do not add %s train data for ANN rule " ..
- "%s:%s",
- learn_type, rule.prefix, set.name)
+ "do not add %s train data for ANN rule " ..
+ "%s:%s",
+ learn_type, rule.prefix, set.name)
end
else
if err then
rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
- rule.prefix, set.name, err)
+ rule.prefix, set.name, err)
elseif type(data) == 'string' then
-- nil return value
rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
- learn_type, rule.prefix, set.name, set.ann.redis_key, data)
+ learn_type, rule.prefix, set.name, set.ann.redis_key, data)
else
rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
- 'please remove this key from Redis manually if you perform upgrade from the previous version',
- rule.prefix, set.name, set.ann.redis_key, type(data))
+ '; please remove this key from Redis manually if you upgrade from a previous version',
+ rule.prefix, set.name, set.ann.redis_key, type(data))
end
end
end
-- Need to create or load a profile corresponding to the current configuration
set.ann = new_ann_profile(task, rule, set, 0)
lua_util.debugm(N, task,
- 'requested new profile for %s, set.ann is missing',
- set.name)
+ 'requested new profile for %s, set.ann is missing',
+ set.name)
end
lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
- { task = task, is_write = false },
- vectors_len_cb,
- {
- set.ann.redis_key,
- })
+ { task = task, is_write = false },
+ vectors_len_cb,
+ {
+ set.ann.redis_key,
+ })
else
lua_util.debugm(N, task,
- 'do not push data: train condition not satisfied; reason: not checked existing ANNs')
+ 'do not push data: train condition not satisfied; reason: not checked existing ANNs')
end
else
lua_util.debugm(N, task,
- 'do not push data to key %s: train condition not satisfied; reason: %s',
- (set.ann or {}).redis_key,
- skip_reason)
+ 'do not push data to key %s: train condition not satisfied; reason: %s',
+ (set.ann or {}).redis_key,
+ skip_reason)
end
end
local function do_train_ann(worker, ev_base, rule, set, ann_key)
local spam_elts = {}
local ham_elts = {}
+ lua_util.debugm(N, rspamd_config, 'do_train_ann: start for %s:%s key=%s', rule.prefix, set.name, ann_key)
local function redis_ham_cb(err, data)
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
- ann_key, err)
+ ann_key, err)
-- Unlock on error
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- neural_common.gen_unlock_cb(rule, set, ann_key), --callback
- 'HDEL', -- command
- { ann_key, 'lock' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+ 'HDEL', -- command
+ { ann_key, 'lock' }
)
else
-- Decompress and convert to numbers each training vector
local function redis_spam_cb(err, data)
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
- ann_key, err)
+ ann_key, err)
-- Unlock ANN on error
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- neural_common.gen_unlock_cb(rule, set, ann_key), --callback
- 'HDEL', -- command
- { ann_key, 'lock' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+ 'HDEL', -- command
+ { ann_key, 'lock' }
)
else
-- Decompress and convert to numbers each training vector
spam_elts = process_training_vectors(data)
-- Now get ham vectors...
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_ham_cb, --callback
- 'SMEMBERS', -- command
- { ann_key .. '_ham_set' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_ham_cb, --callback
+ 'SMEMBERS', -- command
+ { ann_key .. '_ham_set' }
)
end
end
local function redis_lock_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
- ann_key, err)
+ ann_key, err)
elseif type(data) == 'number' and data == 1 then
-- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_spam_cb, --callback
- 'SMEMBERS', -- command
- { ann_key .. '_spam_set' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_spam_cb, --callback
+ 'SMEMBERS', -- command
+ { ann_key .. '_spam_set' }
)
rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
- rule.prefix, set.name, ann_key)
+ rule.prefix, set.name, ann_key)
else
local lock_tm = tonumber(data[1])
rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
- 'locked by another host %s at %s', rule.prefix, set.name, ann_key,
- data[2], os.date('%c', lock_tm))
+ 'locked by another host %s at %s', rule.prefix, set.name, ann_key,
+ data[2], os.date('%c', lock_tm))
end
end
-- Check if we are already learning this network
if set.learning_spawned then
rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
- ann_key)
+ ann_key)
return
end
-- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
-- ANN is locked by another host (or a process, meh)
lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
- { ev_base = ev_base, is_write = true },
- redis_lock_cb,
- {
- ann_key,
- tostring(os.time()),
- tostring(math.max(10.0, rule.watch_interval * 2)),
- rspamd_util.get_hostname()
- })
+ { ev_base = ev_base, is_write = true },
+ redis_lock_cb,
+ {
+ ann_key,
+ tostring(os.time()),
+ tostring(math.max(10.0, rule.watch_interval * 2)),
+ rspamd_util.get_hostname()
+ })
end
-- This function loads new ann from Redis
local function data_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
- ann_key, err)
+ ann_key, err)
else
if type(data) == 'table' then
if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then
if _err or not ann_data then
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
- rule.prefix .. ':' .. set.name, ann_key, _err)
+ rule.prefix .. ':' .. set.name, ann_key, _err)
return
else
ann = rspamd_kann.load(ann_data)
end
-- Also update rank for the loaded ANN to avoid removal
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- true, -- is write
- rank_cb, --callback
- 'ZADD', -- command
- { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ rank_cb, --callback
+ 'ZADD', -- command
+ { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
)
rspamd_logger.infox(rspamd_config,
- 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
- rule.prefix, set.name, ann_key, #data[1], profile.version)
+ 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
+ rule.prefix, set.name, ann_key, #data[1], profile.version)
else
rspamd_logger.errx(rspamd_config,
- 'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
- rule.prefix, set.name, ann_key)
+ 'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
+ rule.prefix, set.name, ann_key)
end
end
else
lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
- rule.prefix, set.name, ann_key)
+ rule.prefix, set.name, ann_key)
end
if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
local roc_thresholds = parser:get_object()
set.ann.roc_thresholds = roc_thresholds
rspamd_logger.infox(rspamd_config,
- 'loaded ROC thresholds for %s:%s; version=%s',
- rule.prefix, set.name, profile.version)
+ 'loaded ROC thresholds for %s:%s; version=%s',
+ rule.prefix, set.name, profile.version)
rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
end
end
-- We can use PCA
set.ann.pca = rspamd_tensor.load(pca_data)
rspamd_logger.infox(rspamd_config,
- 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
- rule.prefix, set.name, ann_key, #data[3], profile.version)
+ 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
+ rule.prefix, set.name, ann_key, #data[3], profile.version)
else
-- no need in pca, why is it there?
rspamd_logger.warnx(rspamd_config,
- 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
- rule.prefix, set.name, ann_key)
+ 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
+ rule.prefix, set.name, ann_key)
end
else
-- pca can be missing merely if we have no max_inputs
if rule.max_inputs then
rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s',
- rule.prefix, set.name, ann_key, _err)
+ rule.prefix, set.name, ann_key, _err)
set.ann.ann = nil
else
-- It is okay
end
else
lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
- rule.prefix, set.name, ann_key)
+ rule.prefix, set.name, ann_key)
end
end
end
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- data_cb, --callback
- 'HMGET', -- command
- { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments
- { opaque_data = true }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ data_cb, --callback
+ 'HMGET', -- command
+ { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments
+ { opaque_data = true }
)
end
local my_symbols = set.symbols
local min_diff = math.huge
local sel_elt
+ lua_util.debugm(N, rspamd_config, 'process_existing_ann: have %s profiles for %s:%s',
+ type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name)
for _, elt in fun.iter(profiles) do
if elt and elt.symbols then
if set.ann.version < sel_elt.version then
-- Load new ann
rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
- 'our version = %s, remote version = %s',
- rule.prefix .. ':' .. set.name,
- set.ann.version,
- sel_elt.version)
+ 'our version = %s, remote version = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.version,
+ sel_elt.version)
load_new_ann(rule, ev_base, set, sel_elt, min_diff)
else
lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' ..
- 'our version = %s, remote version = %s',
- rule.prefix .. ':' .. set.name,
- set.ann.version,
- sel_elt.version)
+ 'our version = %s, remote version = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.version,
+ sel_elt.version)
end
else
-- We have some different ANN, so we need to compare distance
if set.ann.distance > min_diff then
-- Load more specific ANN
rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
- 'our distance = %s, remote distance = %s',
- rule.prefix .. ':' .. set.name,
- set.ann.distance,
- min_diff)
+ 'our distance = %s, remote distance = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.distance,
+ min_diff)
load_new_ann(rule, ev_base, set, sel_elt, min_diff)
else
lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' ..
- 'our distance = %s, remote distance = %s',
- rule.prefix .. ':' .. set.name,
- set.ann.distance,
- min_diff)
+ 'our distance = %s, remote distance = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.distance,
+ min_diff)
end
end
else
load_new_ann(rule, ev_base, set, sel_elt, min_diff)
end
end
+ if sel_elt then
+ lua_util.debugm(N, rspamd_config, 'process_existing_ann: selected profile version=%s key=%s', sel_elt.version,
+ sel_elt.redis_key)
+ else
+ lua_util.debugm(N, rspamd_config, 'process_existing_ann: no suitable profile found')
+ end
end
spam = 0,
ham = 0,
}
+ lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: %s profiles for %s:%s',
+ type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name)
for _, elt in fun.iter(profiles) do
if elt and elt.symbols then
local ann_key = sel_elt.redis_key
lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
- ann_key)
+ ann_key)
-- Create continuation closure
local redis_len_cb_gen = function(cont_cb, what, is_final)
return function(err, data)
if err then
rspamd_logger.errx(rspamd_config,
- 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
+ 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
elseif data and type(data) == 'number' or type(data) == 'string' then
local ntrains = tonumber(data) or 0
lens[what] = ntrains
end
if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
lua_util.debugm(N, rspamd_config,
- 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
- ann_key, lens, rule.train.max_trains, what)
+ 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+ ann_key, lens, rule.train.max_trains, what)
cont_cb()
else
lua_util.debugm(N, rspamd_config,
- 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
- ann_key, what, lens, rule.train.max_trains)
+ 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+ ann_key, what, lens, rule.train.max_trains)
end
else
-- Probabilistic mode, just ensure that at least one vector is okay
if min_len > 0 and max_len >= rule.train.max_trains then
lua_util.debugm(N, rspamd_config,
- 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
- ann_key, lens, rule.train.max_trains, what)
+ 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+ ann_key, lens, rule.train.max_trains, what)
cont_cb()
else
lua_util.debugm(N, rspamd_config,
- 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
- ann_key, what, lens, rule.train.max_trains)
+ 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+ ann_key, what, lens, rule.train.max_trains)
end
end
else
lua_util.debugm(N, rspamd_config,
- 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
- what, ann_key, ntrains, rule.train.max_trains)
+ 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
+ what, ann_key, ntrains, rule.train.max_trains)
cont_cb()
end
end
local function initiate_train()
rspamd_logger.infox(rspamd_config,
- 'need to learn ANN %s after %s required learn vectors',
- ann_key, lens)
+ 'need to learn ANN %s after %s required learn vectors',
+ ann_key, lens)
+ lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: initiating train for key=%s spam=%s ham=%s', ann_key,
+ lens.spam or -1, lens.ham or -1)
do_train_ann(worker, ev_base, rule, set, ann_key)
end
-- Spam vector is OK, check ham vector length
local function check_ham_len()
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_len_cb_gen(initiate_train, 'ham', true), --callback
- 'SCARD', -- command
- { ann_key .. '_ham_set' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_len_cb_gen(initiate_train, 'ham', true), --callback
+ 'SCARD', -- command
+ { ann_key .. '_ham_set' }
)
end
lua_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
- rule.redis,
- nil,
- false, -- is write
- redis_len_cb_gen(check_ham_len, 'spam', false), --callback
- 'SCARD', -- command
- { ann_key .. '_spam_set' }
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_len_cb_gen(check_ham_len, 'spam', false), --callback
+ 'SCARD', -- command
+ { ann_key .. '_spam_set' }
)
end
end
local res, ucl_err = parser:parse_string(element)
if not res then
rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
- ucl_err)
+ ucl_err)
return nil
else
local profile = parser:get_object()
local function members_cb(err, data)
if err then
rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
- err)
+ err)
set.can_store_vectors = true
elseif type(data) == 'table' then
- lua_util.debugm(N, cfg, '%s: process element %s:%s',
- what, rule.prefix, set.name)
+ lua_util.debugm(N, cfg, '%s: process element %s:%s (profiles=%s)',
+ what, rule.prefix, set.name, #data)
process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
set.can_store_vectors = true
+ else
+ lua_util.debugm(N, cfg, '%s: no profiles for %s:%s', what, rule.prefix, set.name)
+ set.can_store_vectors = true
end
end
-- Select the most appropriate to our profile but it should not differ by more
-- than 30% of symbols
lua_redis.redis_make_request_taskless(ev_base,
- cfg,
- rule.redis,
- nil,
- false, -- is write
- members_cb, --callback
- 'ZREVRANGE', -- command
- { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
+ cfg,
+ rule.redis,
+ nil,
+ false, -- is write
+ members_cb, --callback
+ 'ZREVRANGE', -- command
+ { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
)
end
end -- Cycle over all settings
local function invalidate_cb(err, data)
if err then
rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
- err)
+ err)
elseif type(data) == 'table' then
for _, expired in ipairs(data) do
local profile = load_ann_profile(expired)
rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
- rule.prefix .. ':' .. set.name,
- profile.redis_key,
- profile.version)
+ rule.prefix .. ':' .. set.name,
+ profile.redis_key,
+ profile.version)
end
end
end
if type(set) == 'table' then
lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
- { ev_base = ev_base, is_write = true },
- invalidate_cb,
- { set.prefix, tostring(settings.max_profiles) })
+ { ev_base = ev_base, is_write = true },
+ invalidate_cb,
+ { set.prefix, tostring(settings.max_profiles) })
end
end
end
if verdict == 'passthrough' then
lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
- verdict, score)
+ verdict, score)
return
end
if score ~= score then
lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
- verdict)
+ verdict)
return
end
if rule_elt.max_inputs and not has_blas then
rspamd_logger.errx('cannot set max inputs to %s as BLAS is not compiled in',
+ rule_elt.max_inputs)
+ rule_elt.name, rule_elt.max_inputs)
rule_elt.max_inputs = nil
end
end
if (pcfg.type == 'llm' or pcfg.name == 'llm') and not (pcfg.model or (rspamd_config:get_all_opt('gpt') or {}).model) then
rspamd_logger.errx(rspamd_config,
- 'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k)
+ 'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k)
end
end
end
rspamd_config:add_on_load(function(cfg, ev_base, worker)
if worker:is_scanner() then
rspamd_config:add_periodic(ev_base, 0.0,
- function(_, _)
- return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
- 'try_load_ann')
- end)
+ function(_, _)
+ return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
+ 'try_load_ann')
+ end)
end
if worker:is_primary_controller() then
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,
- function(_, _)
- -- Clean old ANNs
- cleanup_anns(rule, cfg, ev_base)
- return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
- 'try_train_ann')
- end)
+ function(_, _)
+ -- Clean old ANNs
+ cleanup_anns(rule, cfg, ev_base)
+ return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
+ 'try_train_ann')
+ end)
end
end)
end
--- /dev/null
+*** Settings ***
+Suite Setup Rspamd Redis Setup
+Suite Teardown Rspamd Redis Teardown
+Library Process
+Library ${RSPAMD_TESTDIR}/lib/rspamd.py
+Resource ${RSPAMD_TESTDIR}/lib/rspamd.robot
+Variables ${RSPAMD_TESTDIR}/lib/vars.py
+
+*** Variables ***
+${CONFIG} ${RSPAMD_TESTDIR}/configs/neural_llm.conf
+${SPAM_MSG} ${RSPAMD_TESTDIR}/messages/spam_message.eml
+${HAM_MSG} ${RSPAMD_TESTDIR}/messages/ham.eml
+${REDIS_SCOPE} Suite
+${RSPAMD_SCOPE} Suite
+${RSPAMD_URL_TLD} ${RSPAMD_TESTDIR}/../lua/unit/test_tld.dat
+
+*** Test Cases ***
+Train LLM-backed neural and verify
+ Run Dummy Llm
+
+ # Learn spam
+ ${result} = Run Rspamc -P secret -h ${RSPAMD_LOCAL_ADDR}:${RSPAMD_PORT_CONTROLLER} neural_learn:spam ${SPAM_MSG}
+ Check Rspamc ${result}
+
+ # Learn ham
+ ${result} = Run Rspamc -P secret -h ${RSPAMD_LOCAL_ADDR}:${RSPAMD_PORT_CONTROLLER} neural_learn:ham ${HAM_MSG}
+ Check Rspamc ${result}
+
+ Sleep 5s
+
+ # Check spam inference (dummy_llm returns ones vector for "spam" content)
+ Scan File ${SPAM_MSG} Settings={groups_enabled=["neural"]}
+ Expect Symbol NEURAL_SPAM
+
+ # Check ham inference (zeros vector)
+ Scan File ${HAM_MSG} Settings={groups_enabled=["neural"]}
+ Expect Symbol NEURAL_HAM
+
+ Dummy Llm Teardown
--- /dev/null
+options = {
+ url_tld = "{= env.URL_TLD =}"
+ pidfile = "{= env.TMPDIR =}/rspamd.pid"
+ lua_path = "{= env.INSTALLROOT =}/share/rspamd/lib/?.lua"
+ filters = [];
+ explicit_modules = ["settings"];
+}
+
+logging = {
+ type = "file",
+ level = "debug"
+ filename = "{= env.TMPDIR =}/rspamd.log"
+ log_usec = true;
+}
+metric = {
+ name = "default",
+ actions = {
+ reject = 100500,
+ add_header = 50500,
+ }
+ unknown_weight = 1
+}
+worker {
+ type = normal
+ bind_socket = "{= env.LOCAL_ADDR =}:{= env.PORT_NORMAL =}"
+ count = 1
+ task_timeout = 10s;
+}
+worker {
+ type = controller
+ bind_socket = "{= env.LOCAL_ADDR =}:{= env.PORT_CONTROLLER =}"
+ count = 1
+ secure_ip = ["127.0.0.1", "::1"];
+ stats_path = "{= env.TMPDIR =}/stats.ucl"
+}
+
+modules {
+ path = "{= env.TESTDIR =}/../../src/plugins/lua/"
+}
+
+lua = "{= env.TESTDIR =}/lua/test_coverage.lua";
+
+neural {
+ rules {
+ default {
+ train {
+ learning_rate = 0.001;
+ max_trains = 1;
+ max_iterations = 250;
+ }
+ symbol_spam = "NEURAL_SPAM";
+ symbol_ham = "NEURAL_HAM";
+ ann_expire = 86400;
+ watch_interval = 0.5;
+ providers = [{ type = "llm"; model = "dummy-embed"; url = "http://127.0.0.1:18080"; weight = 1.0; }];
+ fusion { normalization = "none"; }
+ roc_enabled = false;
+ }
+ }
+ allow_local = true;
+}
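+# The llm provider above points at the dummy embedding server the test suite
+# starts on 127.0.0.1:18080 ("dummy-embed" is a placeholder model name).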
+
+redis {
+ servers = "{= env.REDIS_ADDR =}:{= env.REDIS_PORT =}";
+ expand_keys = true;
+}
+
+lua = "{= env.TESTDIR =}/lua/neural.lua";
--- /dev/null
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import json
+import sys
+from http.server import BaseHTTPRequestHandler, HTTPServer
+
+import dummy_killer
+
+PID = "/tmp/dummy_llm.pid"
+
+
+def make_embedding(text: str, dim: int = 32):
+ # Deterministic: if text contains 'SPAM' (case-insensitive) -> ones; else zeros
+ if 'spam' in text.lower():
+ return [1.0] * dim
+ return [0.0] * dim
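+# e.g. make_embedding("Buy SPAM now!") -> [1.0] * 32, make_embedding("hello") -> [0.0] * 32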
+
+
+class EmbeddingHandler(BaseHTTPRequestHandler):
+ # OpenAI-like embeddings API
+ def do_POST(self):
+ length = int(self.headers.get('Content-Length', '0'))
+ raw = self.rfile.read(length) if length > 0 else b''
+ try:
+ data = json.loads(raw.decode('utf-8') or '{}')
+ except Exception:
+ data = {}
+
+ # Support both OpenAI ({input, model}) and Ollama ({prompt, model}) shapes
+ text = data.get('input') or data.get('prompt') or ''
+ # OpenAI allows a list of inputs; flatten it so make_embedding() sees a string
+ if isinstance(text, list):
+ text = ' '.join(map(str, text))
+ # Optional dimension override for tests
+ dim = int(data.get('dim') or 32)
+ emb = make_embedding(text, dim)
+
+ # Always reply in the OpenAI-like format expected by the neural provider
+ body = {"data": [{"embedding": emb}]}
+
+ reply = json.dumps(body).encode('utf-8')
+ self.send_response(200)
+ self.send_header('Content-Type', 'application/json')
+ self.send_header('Content-Length', str(len(reply)))
+ self.end_headers()
+ self.wfile.write(reply)
+
+ def log_message(self, fmt, *args):
+ # Keep test output quiet
+ return
+
+
+if __name__ == "__main__":
+ alen = len(sys.argv)
+ if alen > 1:
+ port = int(sys.argv[1])
+ else:
+ port = 18080
+ server = HTTPServer(("127.0.0.1", port), EmbeddingHandler)
+ dummy_killer.write_pid(PID)
+ try:
+ server.serve_forever()
+ except KeyboardInterrupt:
+ pass
+ finally:
+ server.server_close()