-- Compose the effective LLM settings for the neural 'llm' provider by
-- layering per-provider config (pcfg) over the global 'gpt' module options,
-- with hard-coded fallbacks ('openai' backend, 2s timeout, 1-day cache TTL).
-- NOTE(review): this chunk is a diff fragment; the lines between `url` and
-- the returned-table fields (api_key resolution, the table constructor
-- opener, and the `return`) are elided from this view — confirm against the
-- full file before relying on the exact return shape.
local function compose_llm_settings(pcfg)
local gpt_settings = rspamd_config:get_all_opt('gpt') or {}
- local llm_type = pcfg.type or gpt_settings.type or 'openai'
+ -- Provider identity is pcfg.type=='llm'; backend type is specified via one of these keys
+ local llm_type = pcfg.llm_type or pcfg.api or pcfg.backend or gpt_settings.type or 'openai'
local model = pcfg.model or gpt_settings.model
local timeout = pcfg.timeout or gpt_settings.timeout or 2.0
local url = pcfg.url
api_key = api_key,
-- Cache defaults: 86400s (24h) TTL, 'neural_llm' key prefix
cache_ttl = pcfg.cache_ttl or 86400,
cache_prefix = pcfg.cache_prefix or 'neural_llm',
- cache_hash_len = pcfg.cache_hash_len or 16,
- cache_use_hashing = pcfg.cache_use_hashing ~= false,
+ -- Hash length default raised 16 -> 32; hashing stays on unless explicitly
+ -- disabled (only a literal `false` turns it off — nil means enabled)
+ cache_hash_len = pcfg.cache_hash_len or 32,
+ cache_use_hashing = (pcfg.cache_use_hashing ~= false),
}
end
end
-- Register the 'llm' embedding provider with neural_common.
-- The diff converts the provider from a synchronous `collect` to an
-- asynchronous `collect_async(task, ctx, cont)` shape: every exit path now
-- reports its result through the `cont` continuation instead of `return`.
-- NOTE(review): several original lines are elided in this diff fragment
-- (the start of the input_string build, the branch between the two
-- `body = ...` assignments — presumably an `elseif` for a second backend —
-- and the http_params table opener). Confirm against the full file.
neural_common.register_provider('llm', {
- collect = function(task, ctx)
+ collect_async = function(task, ctx, cont)
local pcfg = ctx.config or {}
local llm = compose_llm_settings(pcfg)
-- No model configured: cannot request embeddings; signal "no vector"
if not llm.model then
rspamd_logger.debugm(N, task, 'llm provider missing model; skip')
- return nil
+ cont(nil)
+ return
end
-- Nothing to embed (no usable text part): also signal "no vector"
local input_tbl = select_text(task)
if not input_tbl then
rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip')
- return nil
+ cont(nil)
+ return
end
-- Build request input string (text then optional subject), keeping rspamd_text intact
input_string = input_string .. "\nSubject: " .. input_tbl.subject
end
+ rspamd_logger.debugm(N, task, 'llm embedding request: model=%s url=%s len=%s', tostring(llm.model), tostring(llm.url),
+ tostring(#tostring(input_string)))
+
-- Backend-specific request body; OpenAI expects `input`, the elided
-- branch uses `prompt` (diff context hides which backend that is)
local body
if llm.type == 'openai' then
body = { model = llm.model, input = input_string }
body = { model = llm.model, prompt = input_string }
else
rspamd_logger.debugm(N, task, 'unsupported llm type: %s', llm.type)
- return nil
+ cont(nil)
+ return
end
- -- Redis cache: hash the final input string only (IUF is trivial here)
+ -- Redis cache: hash the final input string only
local cache_ctx = lua_cache.create_cache_context(neural_common.redis_params, {
cache_prefix = llm.cache_prefix,
cache_ttl = llm.cache_ttl,
cache_use_hashing = llm.cache_use_hashing,
}, N)
- local hasher = require 'rspamd_cryptobox_hash'
- local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', hasher.create(input_string):hex())
+ -- Use raw key and allow cache module to hash/shorten it per context
+ local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', input_string)
+
+ -- Deliver a vector (plus metadata) or nil to the continuation; a
+ -- non-empty table counts as success, anything else as "no vector"
+ local function finish_with_vec(vec)
+ if type(vec) == 'table' and #vec > 0 then
+ local meta = { name = pcfg.name or 'llm', type = 'llm', dim = #vec, weight = ctx.weight or 1.0 }
+ rspamd_logger.debugm(N, task, 'llm embedding result: dim=%s', #vec)
+ cont(vec, meta)
+ else
+ rspamd_logger.debugm(N, task, 'llm embedding result: empty')
+ cont(nil)
+ end
+ end
+
+ -- HTTP reply handler: every failure path calls cont(nil) so the
+ -- provider never leaves the continuation dangling (the removed sync
+ -- version below just returned, which would hang an async caller)
+ local function http_cb(err, code, resp, _)
+ if err then
+ rspamd_logger.debugm(N, task, 'llm http error: %s', err)
+ cont(nil)
+ return
+ end
+ if code ~= 200 or not resp then
+ rspamd_logger.debugm(N, task, 'llm bad http code: %s', code)
+ cont(nil)
+ return
+ end
+
+ local parser = ucl.parser()
+ local ok, perr = parser:parse_string(resp)
+ if not ok then
+ rspamd_logger.debugm(N, task, 'llm cannot parse reply: %s', perr)
+ cont(nil)
+ return
+ end
+ local parsed = parser:get_object()
+ local emb = extract_embedding(llm.type, parsed)
+ if type(emb) == 'table' then
+ -- Store in Redis before delivering so the next scan hits the cache
+ lua_cache.cache_set(task, key, emb, cache_ctx)
+ finish_with_vec(emb)
+ else
+ rspamd_logger.debugm(N, task, 'llm embedding parse: no embedding field')
+ cont(nil)
+ end
+ end
-- Fire the HTTP request; the callback is now passed inside http_params
-- (the old two-argument rspamd_http.request form is removed below)
local function do_request_and_cache()
local headers = { ['Content-Type'] = 'application/json' }
task = task,
method = 'POST',
use_gzip = true,
+ keepalive = true,
+ callback = http_cb,
}
- local function http_cb(err, code, resp, _)
- if err then
- rspamd_logger.debugm(N, task, 'llm http error: %s', err)
- return
- end
- if code ~= 200 or not resp then
- rspamd_logger.debugm(N, task, 'llm bad http code: %s', code)
- return
- end
-
- local parser = ucl.parser()
- local ok, perr = parser:parse_string(resp)
- if not ok then
- rspamd_logger.debugm(N, task, 'llm cannot parse reply: %s', perr)
- return
- end
- local parsed = parser:get_object()
- local emb = extract_embedding(llm.type, parsed)
- if type(emb) == 'table' then
- cache_ctx:set_cached(key, emb)
- neural_common.append_provider_vector(ctx, { provider = 'llm', vector = emb })
- end
- end
-
- rspamd_http.request(http_params, http_cb)
- end
-
- local cached = cache_ctx:get_cached(key)
- if type(cached) == 'table' then
- neural_common.append_provider_vector(ctx, { provider = 'llm', vector = cached })
- return
+ rspamd_http.request(http_params)
end
- do_request_and_cache()
+ -- Use async cache API
+ -- cache_get takes an uncached handler and a data handler; on a miss,
+ -- or on an error/non-table payload in the data handler, fall back to
+ -- the live HTTP request
+ lua_cache.cache_get(task, key, cache_ctx, llm.timeout or 2.0,
+ function()
+ -- Uncached path
+ do_request_and_cache()
+ end,
+ function(_, err, data)
+ if data and type(data) == 'table' then
+ finish_with_vec(data)
+ else
+ do_request_and_cache()
+ end
+ end)
end,
})
-- - Rule: rule name (optional, default 'default')
local function handle_learn_message(task, conn)
lua_util.debugm(N, task, 'controller.neural: learn_message called')
+
+ -- Ensure the message is parsed so LLM providers can access text parts
+ local ok_parse = task:process_message()
+ if not ok_parse then
+ lua_util.debugm(N, task, 'controller.neural: cannot process message MIME, abort')
+ conn:send_error(400, 'cannot parse message for learning')
+ return
+ end
+
local cls = task:get_request_header('ANN-Train') or task:get_request_header('Class')
if not cls then
conn:send_error(400, 'missing class header (ANN-Train or Class)')
end
local set = neural_common.get_rule_settings(task, rule)
- if not set or not set.ann or not set.ann.redis_key then
+ if not set then
+ lua_util.debugm(N, task, 'controller.neural: no settings resolved for rule=%s; falling back to first available set',
+ rule_name)
+ for sid, s in pairs(rule.settings or {}) do
+ if type(s) == 'table' then
+ set = s
+ set.name = set.name or sid
+ break
+ end
+ end
+ end
+
+ -- Derive redis base key even if ANN not yet initialized
+ local redis_base
+ if set and set.ann and set.ann.redis_key then
+ redis_base = set.ann.redis_key
+ elseif set then
+ local ok, prefix = pcall(neural_common.redis_ann_prefix, rule, set.name)
+ if ok and prefix then
+ redis_base = prefix
+ lua_util.debugm(N, task, 'controller.neural: derived redis base key for rule=%s set=%s -> %s', rule_name, set.name,
+ redis_base)
+ end
+ end
+
+ if not set or not redis_base then
+ lua_util.debugm(N, task, 'controller.neural: invalid set or redis key for learning; set=%s ann=%s',
+ tostring(set ~= nil), set and tostring(set.ann ~= nil) or 'nil')
conn:send_error(400, 'invalid rule settings for learning')
return
end
local function after_collect(vec)
lua_util.debugm(N, task, 'controller.neural: learn_message after_collect, vector=%s', type(vec))
if not vec then
- vec = neural_common.result_to_vector(task, set)
+ if rule.providers and #rule.providers > 0 then
+ lua_util.debugm(N, task,
+ 'controller.neural: no vector from providers; skip training to keep dimensions consistent')
+ conn:send_error(400, 'no vector collected from providers')
+ return
+ else
+ vec = neural_common.result_to_vector(task, set)
+ end
end
if type(vec) ~= 'table' then
return
end
+ -- Preview vector for debugging
+ local function preview_vector(v)
+ local n = #v
+ local limit = math.min(n, 8)
+ local parts = {}
+ for i = 1, limit do
+ parts[#parts + 1] = string.format('%.4f', tonumber(v[i]) or 0)
+ end
+ return n, table.concat(parts, ',')
+ end
+
+ local vlen, vhead = preview_vector(vec)
+ lua_util.debugm(N, task, 'controller.neural: vector size=%s head=[%s]', vlen, vhead)
+
local compressed = rspamd_util.zstd_compress(table.concat(vec, ';'))
- local target_key = string.format('%s_%s_set', set.ann.redis_key, learn_type)
+ local target_key = string.format('%s_%s_set', redis_base, learn_type)
local function learn_vec_cb(redis_err)
if redis_err then