| Variable | Default | Description |
|----------|---------|-------------|
| `EMBEDDING_MODEL` | `BAAI/bge-small-en-v1.5` | FastEmbed model name |
| `EMBEDDING_PORT` | `8080` | Port to listen on |
-| `EMBEDDING_HOST` | `0.0.0.0` | Host to bind to |
+| `EMBEDDING_HOST` | `0.0.0.0` | Host to bind |
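For reference, these variables can be exercised without the server running; a minimal sketch of env-plus-flag resolution (the flag names here are assumptions, not the server's actual CLI):

```python
import argparse
import os

# Environment variables act as defaults; CLI flags (assumed names) override them.
parser = argparse.ArgumentParser(description="FastEmbed embedding server")
parser.add_argument("--model",
                    default=os.environ.get("EMBEDDING_MODEL", "BAAI/bge-small-en-v1.5"))
parser.add_argument("--port", type=int,
                    default=int(os.environ.get("EMBEDDING_PORT", "8080")))
parser.add_argument("--host",
                    default=os.environ.get("EMBEDDING_HOST", "0.0.0.0"))
args = parser.parse_args()
print(args.model, args.host, args.port)
```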
### Command Line Arguments
EMBEDDING_PORT: Port number (default: 8080)
EMBEDDING_HOST: Host to bind (default: 0.0.0.0)
"""
-
import argparse
import logging
import os
model_name: str = DEFAULT_MODEL
model_dim: int = 0
-
# Request/Response models
class OllamaEmbeddingRequest(BaseModel):
"""Ollama-compatible embedding request."""
dimensions: int
uptime_seconds: float
-
# Startup time for uptime calculation
startup_time: float = 0.0
logger.info(f"Loading embedding model: {model_name}")
start = time.time()
- try:
- model = TextEmbedding(model_name)
- # Get embedding dimension from a test inference
- test_embed = list(model.embed(["test"]))[0]
- model_dim = len(test_embed)
- elapsed = time.time() - start
- logger.info(f"Model loaded in {elapsed:.2f}s, dimensions: {model_dim}")
- startup_time = time.time()
- except Exception as e:
- logger.error(f"Failed to load model: {e}")
- raise
+ model = TextEmbedding(model_name)
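+ # Probe the embedding dimension with a single test inference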
+ test_embed = list(model.embed(["test"]))[0]
+ model_dim = len(test_embed)
+ elapsed = time.time() - start
+ logger.info(f"Model loaded in {elapsed:.2f}s, dimensions: {model_dim}")
+ startup_time = time.time()
yield
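The dimension probe also works standalone, which is handy for verifying a model before wiring it into the server; a minimal sketch assuming the `fastembed` package is installed:

```python
import time

from fastembed import TextEmbedding  # pip install fastembed

start = time.time()
model = TextEmbedding("BAAI/bge-small-en-v1.5")
# embed() yields one numpy vector per input; a single probe reveals the width
dim = len(list(model.embed(["test"]))[0])
print(f"loaded in {time.time() - start:.2f}s, dimensions: {dim}")
```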
end
-- Compute a stable digest for providers configuration
-local function providers_config_digest(providers_cfg)
+local function providers_config_digest(providers_cfg, rule)
if not providers_cfg then return nil end
-- Normalize a minimal subset of fields to keep the digest stable across equivalent configs
- local norm = {}
+ local norm = { providers = {} }
+
+ local fusion = rule and rule.fusion or nil
+ if rule then
+ local effective_fusion = {
+ normalization = (fusion and fusion.normalization) or 'none',
+ include_meta = fusion and fusion.include_meta,
+ meta_weight = fusion and fusion.meta_weight,
+ per_provider_pca = fusion and fusion.per_provider_pca,
+ }
+ if effective_fusion.include_meta == nil then
+ effective_fusion.include_meta = true
+ end
+ if effective_fusion.meta_weight == nil then
+ effective_fusion.meta_weight = 1.0
+ end
+ if effective_fusion.per_provider_pca == nil then
+ effective_fusion.per_provider_pca = false
+ end
+ norm.fusion = effective_fusion
+ end
+
+ if rule and rule.max_inputs then
+ norm.max_inputs = rule.max_inputs
+ end
+
+ local gpt_settings = rspamd_config:get_all_opt('gpt') or {}
+
for i, p in ipairs(providers_cfg) do
- norm[i] = {
- type = p.type,
- name = p.name,
+ local ptype = p.type or p.name or 'unknown'
+ local entry = {
+ type = ptype,
weight = p.weight or 1.0,
dim = p.dim,
}
+
+ if ptype == 'llm' then
+ local llm_type = p.llm_type or p.api or p.backend or gpt_settings.type
+ local model = p.model or gpt_settings.model
+ local max_tokens = p.max_tokens
+ if not max_tokens and gpt_settings.model_parameters and model then
+ local model_cfg = gpt_settings.model_parameters[model] or {}
+ max_tokens = model_cfg.max_completion_tokens or model_cfg.max_tokens
+ end
+ if not max_tokens then
+ max_tokens = gpt_settings.max_tokens
+ end
+
+ entry.llm_type = llm_type
+ entry.model = model
+ entry.max_tokens = max_tokens
+ end
+
+ norm.providers[i] = entry
end
return lua_util.table_digest(norm)
end
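The point of the normalization is that two cosmetically different but equivalent configs must hash identically, otherwise every config touch would orphan trained ANNs. A Python sketch of the same idea, with `lua_util.table_digest` approximated by canonical JSON plus SHA-256 (an assumption, not rspamd's actual hash):

```python
import hashlib
import json

def providers_config_digest(providers, fusion=None):
    # Keep only fields that affect the produced vectors; fill defaults explicitly
    norm = {
        "fusion": {
            "normalization": (fusion or {}).get("normalization", "none"),
            "include_meta": (fusion or {}).get("include_meta", True),
            "meta_weight": (fusion or {}).get("meta_weight", 1.0),
        },
        "providers": [
            {"type": p.get("type", "unknown"), "weight": p.get("weight", 1.0), "dim": p.get("dim")}
            for p in providers
        ],
    }
    # sort_keys gives a canonical serialization, so key order cannot change the digest
    return hashlib.sha256(json.dumps(norm, sort_keys=True).encode()).hexdigest()

a = providers_config_digest([{"type": "symbols"}], fusion={"include_meta": True})
b = providers_config_digest([{"type": "symbols", "weight": 1.0}], fusion={})
assert a == b  # equivalent configs, same digest
```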
local providers_cfg = rule.providers
if not providers_cfg or #providers_cfg == 0 then
if rule.disable_symbols_input then
- cb(nil, { providers = {}, total_dim = 0, digest = providers_config_digest(providers_cfg) })
+ cb(nil, { providers = {}, total_dim = 0, digest = providers_config_digest(providers_cfg, rule) })
return
end
local prov = get_provider('symbols')
cb(#fused > 0 and fused or nil, {
providers = build_providers_meta(metas) or metas,
total_dim = #fused,
- digest = providers_config_digest(providers_cfg),
+ digest = providers_config_digest(providers_cfg, rule),
})
end)
return
providers = build_providers_meta({ meta }) or { meta },
total_dim = #fused,
digest = providers_config_digest(
- providers_cfg)
+ providers_cfg, rule)
})
return
end
local meta = {
providers = build_providers_meta(metas) or metas,
total_dim = #fused,
- digest = providers_config_digest(providers_cfg),
+ digest = providers_config_digest(providers_cfg, rule),
}
if #fused == 0 then
cb(nil, meta)
digest = params.set.digest,
redis_key = params.set.ann.redis_key,
version = version,
- providers_digest = providers_config_digest(params.rule.providers),
+ providers_digest = providers_config_digest(params.rule.providers, params.rule),
}
local profile_serialized = ucl.to_format(profile, 'json-compact', true)
local N = "neural.llm"
-local function select_text(task)
- local input_tbl = llm_common.build_llm_input(task)
+local function select_text(task, llm)
+ local opts = {}
+ if llm and llm.max_tokens then
+ opts.max_tokens = llm.max_tokens
+ end
+ local input_tbl = llm_common.build_llm_input(task, opts)
return input_tbl
end
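Passing `max_tokens` through to `build_llm_input` lets the input be trimmed to what the embedding backend accepts. The real trimming lives in `llm_common`; purely as an illustration, a chars-per-token heuristic looks like this (the 4:1 ratio is a rough assumption, not the module's actual logic):

```python
def truncate_to_token_budget(text: str, max_tokens: int, chars_per_token: int = 4) -> str:
    # Rough heuristic: ~4 characters per token for English text
    budget = max_tokens * chars_per_token
    return text if len(text) <= budget else text[:budget]

print(len(truncate_to_token_budget("x" * 10000, max_tokens=512)))  # 2048
```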
-- Provider identity is pcfg.type=='llm'; backend type is specified via one of these keys
local llm_type = pcfg.llm_type or pcfg.api or pcfg.backend or gpt_settings.type or 'openai'
local model = pcfg.model or gpt_settings.model
+ local model_params = gpt_settings.model_parameters or {}
+ local model_cfg = model and model_params[model] or {}
+ local max_tokens = pcfg.max_tokens
+ if not max_tokens then
+ max_tokens = model_cfg.max_completion_tokens or model_cfg.max_tokens or gpt_settings.max_tokens
+ end
local timeout = pcfg.timeout or gpt_settings.timeout or 2.0
local url = pcfg.url
local api_key = pcfg.api_key or gpt_settings.api_key
return {
type = llm_type,
model = model,
+ max_tokens = max_tokens,
timeout = timeout,
url = url,
api_key = api_key,
}
end
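The resolution order for `max_tokens` is: provider config, then per-model `model_parameters`, then the global gpt setting. A sketch of the same fallback chain (the dict layout mirrors the Lua above, not a verified rspamd schema):

```python
def resolve_max_tokens(pcfg: dict, gpt_settings: dict):
    # 1. provider-level override wins
    if pcfg.get("max_tokens"):
        return pcfg["max_tokens"]
    # 2. per-model parameters from the gpt module
    model = pcfg.get("model") or gpt_settings.get("model")
    model_cfg = (gpt_settings.get("model_parameters") or {}).get(model, {})
    per_model = model_cfg.get("max_completion_tokens") or model_cfg.get("max_tokens")
    # 3. global default
    return per_model or gpt_settings.get("max_tokens")

gpt = {"model": "gpt-4o-mini", "max_tokens": 1000,
       "model_parameters": {"gpt-4o-mini": {"max_completion_tokens": 500}}}
print(resolve_max_tokens({}, gpt))  # 500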
+local function normalize_cache_key_input(input_string)
+ if type(input_string) == 'userdata' then
+ return input_string:str()
+ end
+ return tostring(input_string)
+end
+
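`normalize_cache_key_input` exists because the input may arrive as an rspamd text userdata rather than a plain string. The analogous defensive conversion in Python terms (bytes standing in for the userdata case):

```python
def normalize_cache_key_input(value) -> str:
    # Accept both raw bytes (analogous to rspamd text userdata) and strings
    if isinstance(value, (bytes, bytearray)):
        return value.decode("utf-8", errors="replace")
    return str(value)

print(normalize_cache_key_input(b"Subject: hello"))  # Subject: hello
```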
local function extract_embedding(llm_type, parsed)
if llm_type == 'openai' then
-- { data = [ { embedding = [...] } ] }
end
end
- local input_tbl = select_text(task)
+ local input_tbl = select_text(task, llm)
if not input_tbl then
rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip')
cont(nil)
input_string = input_string .. "\nSubject: " .. input_tbl.subject
end
+ local input_key = normalize_cache_key_input(input_string)
rspamd_logger.debugm(N, task, 'llm embedding request: model=%s url=%s len=%s', tostring(llm.model), tostring(llm.url),
- tostring(#tostring(input_string)))
+ tostring(#input_key))
local body
if llm.type == 'openai' then
}, N)
-- Use raw key and allow cache module to hash/shorten it per context
- local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', input_string)
+ local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', input_key)
local function finish_with_vec(vec)
if type(vec) == 'table' and #vec > 0 then
- local meta = { name = pcfg.name or 'llm', type = 'llm', dim = #vec, weight = ctx.weight or 1.0 }
+ local meta = {
+ name = pcfg.name or 'llm',
+ type = 'llm',
+ dim = #vec,
+ weight = ctx.weight or 1.0,
+ model = llm.model,
+ provider = llm.type,
+ }
rspamd_logger.debugm(N, task, 'llm embedding result: dim=%s', #vec)
cont(vec, meta)
else
end
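Note that the raw cache key embeds both backend type and model, so switching models can never serve a stale vector. The cache module hashes the key per context; the shape is roughly as below (the hashing shown is illustrative, not rspamd's):

```python
import hashlib

def embedding_cache_key(llm_type: str, model: str, input_text: str) -> str:
    raw = f"{llm_type}:{model or 'model'}:{input_text}"
    # The raw key can be arbitrarily long; hash it down to a fixed-size redis key
    return hashlib.blake2b(raw.encode(), digest_size=16).hexdigest()

print(embedding_cache_key("openai", "text-embedding-3-small", "Subject: hello"))
```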
redis.call('HDEL', KEYS[1], 'lock')
redis.call('HDEL', KEYS[7], 'lock')
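+-- mark the profile obsolete so other replicas stop refilling it before it expires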
+redis.call('HSET', KEYS[7], 'obsolete', '1')
+redis.call('EXPIRE', KEYS[7], 600)
redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
-- expire in 10m so that other rspamd replicas do not refill the deleted keys (avoids a race)
redis.call('EXPIRE', KEYS[7] .. '_spam_set', 600)
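Equivalent client-side operations, useful when inspecting or reproducing the teardown by hand (assumes a local Redis and the `redis` Python package; the key name is hypothetical):

```python
import redis

r = redis.Redis()  # assumes a local Redis instance
profile_key = "rn_default_example"  # hypothetical profile hash key

# Mark the profile obsolete and let the key expire in 10 minutes,
# matching the script above
r.hset(profile_key, "obsolete", "1")
r.expire(profile_key, 600)

# A training candidate checks the marker before taking the lock
if r.hget(profile_key, "obsolete"):
    print("obsolete")
```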
--- returns nspam,nham (or nil if locked)
+-- returns nspam,nham, nil if locked, or 'obsolete' if the profile was invalidated
local prefix = KEYS[1]
+local obsolete = redis.call('HGET', prefix, 'obsolete')
+if obsolete then
+ return 'obsolete'
+end
local locked = redis.call('HGET', prefix, 'lock')
if locked then
local host = redis.call('HGET', prefix, 'hostname') or 'unknown'
local E = {}
local N = 'neural'
+local function get_request_header(task, name)
+ local hdr = task:get_request_header(name)
+ if type(hdr) == 'table' then
+ hdr = hdr[1]
+ end
+ if hdr then
+ return tostring(hdr)
+ end
+ return nil
+end
+
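`task:get_request_header` may hand back a table when a header is repeated, hence the table check before `tostring`. The same normalization expressed in Python (the `headers` mapping is hypothetical):

```python
def get_request_header(headers: dict, name: str):
    # Repeated headers may arrive as a list; normalize to the first value as a string
    value = headers.get(name)
    if isinstance(value, list):
        value = value[0] if value else None
    return str(value) if value is not None else None

print(get_request_header({"ANN-Train": ["spam", "ham"]}, "ANN-Train"))  # spam
```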
-- Controller neural plugin
local learn_request_schema = T.table({
return
end
- local cls = task:get_request_header('ANN-Train') or task:get_request_header('Class')
+ local cls = get_request_header(task, 'ANN-Train') or get_request_header(task, 'Class')
if not cls then
conn:send_error(400, 'missing class header (ANN-Train or Class)')
return
return
end
- local rule_name = task:get_request_header('Rule') or 'default'
+ local rule_name = get_request_header(task, 'Rule') or 'default'
local rule = neural_common.settings.rules[rule_name]
if not rule then
conn:send_error(400, 'unknown rule')
version = version,
digest = set.digest,
distance = 0,
- providers_digest = neural_common.providers_config_digest(rule.providers),
+ providers_digest = neural_common.providers_config_digest(rule.providers, rule),
}
local profile_serialized = ucl.to_format(profile, 'json-compact', true)
version = version,
digest = set.digest,
distance = 0, -- Since we are using our own profile
- providers_digest = neural_common.providers_config_digest(rule.providers),
+ providers_digest = neural_common.providers_config_digest(rule.providers, rule),
}
local ucl = require "ucl"
local spam_threshold = 0
if rule.spam_score_threshold then
spam_threshold = rule.spam_score_threshold
- elseif rule.roc_enabled and not set.ann.roc_thresholds then
+ elseif rule.roc_enabled and set.ann.roc_thresholds then
spam_threshold = set.ann.roc_thresholds[1]
end
local ham_threshold = 0
if rule.ham_score_threshold then
ham_threshold = rule.ham_score_threshold
- elseif rule.roc_enabled and not set.ann.roc_thresholds then
+ elseif rule.roc_enabled and set.ann.roc_thresholds then
ham_threshold = set.ann.roc_thresholds[2]
end
end
end
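With the inverted condition fixed, the precedence is: explicit per-rule threshold, then ROC-derived threshold, then 0. As a compact sketch (thresholds ordered `[spam, ham]` as in the code above):

```python
def pick_thresholds(rule: dict, roc_thresholds):
    # Explicit per-rule thresholds win; otherwise fall back to ROC-derived ones
    spam = rule.get("spam_score_threshold")
    if spam is None and rule.get("roc_enabled") and roc_thresholds:
        spam = roc_thresholds[0]
    ham = rule.get("ham_score_threshold")
    if ham is None and rule.get("roc_enabled") and roc_thresholds:
        ham = roc_thresholds[1]
    return spam or 0, ham or 0

print(pick_thresholds({"roc_enabled": True}, (0.85, -0.2)))  # (0.85, -0.2)
```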
+local function get_ann_train_header(task)
+ local hdr = task:get_request_header('ANN-Train')
+ if type(hdr) == 'table' then
+ hdr = hdr[1]
+ end
+ if hdr then
+ return tostring(hdr):lower()
+ end
+ return nil
+end
+
local function ann_push_task_result(rule, task, verdict, score, set)
local train_opts = rule.train
local learn_spam, learn_ham
-- First, honor explicit manual training header if present
do
- local hdr = task:get_request_header('ANN-Train')
- if hdr then
- local hv = tostring(hdr):lower()
- lua_util.debugm(N, task, 'found ANN-Train header, enable manual train mode', hv)
+ local hv = get_ann_train_header(task)
+ if hv then
lua_util.debugm(N, task, 'found ANN-Train header, enabling manual train mode: %s', hv)
if hv == 'spam' then
learn_spam = true
manual_train = true