]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add language-based model/URL selection for LLM embeddings
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 19 Jan 2026 09:29:44 +0000 (09:29 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 19 Jan 2026 09:29:44 +0000 (09:29 +0000)
Support language-specific embedding models via language_models config:
- Shorthand: language_models = { ru = "model-name" }
- Full config: language_models = { ru = { model, url, api_key } }

Uses get_displayed_text_part() for language detection.
Include language in cache key for proper separation.

lualib/plugins/neural/providers/llm.lua

index 60dea419f5c8f8c6ac01e5c76254e20bbd8a049c..9709c231edebe9ccc76978aa9c78adf2f67b5ad8 100644 (file)
@@ -10,6 +10,7 @@ local ucl = require "ucl"
 local neural_common = require "plugins/neural"
 local lua_cache = require "lua_cache"
 local llm_common = require "llm_common"
+local lua_mime = require "lua_mime"
 
 local N = "neural.llm"
 
@@ -17,7 +18,19 @@ local function select_text(task, opts)
   return llm_common.build_llm_input(task, opts)
 end
 
-local function compose_llm_settings(pcfg)
+-- Detect primary language from the displayed text part
+local function detect_language(task)
+  local part = lua_mime.get_displayed_text_part(task)
+  if part then
+    local lang = part:get_language()
+    if lang and lang ~= '' then
+      return lang
+    end
+  end
+  return nil
+end
+
+local function compose_llm_settings(pcfg, language)
   local gpt_settings = rspamd_config:get_all_opt('gpt') or {}
   -- Provider identity is pcfg.type=='llm'; backend type is specified via one of these keys
   local llm_type = pcfg.llm_type or pcfg.api or pcfg.backend or gpt_settings.type or 'openai'
@@ -32,6 +45,31 @@ local function compose_llm_settings(pcfg)
   local url = pcfg.url
   local api_key = pcfg.api_key or gpt_settings.api_key
 
+  -- Language-specific model/URL selection
+  -- Config format: language_models = { en = { model = "...", url = "..." }, ru = { model = "..." }, ... }
+  -- Or shorthand: language_models = { en = "model-name", ru = "model-name", ... }
+  local language_models = pcfg.language_models
+  if language and language_models then
+    local lang_cfg = language_models[language]
+    if lang_cfg then
+      if type(lang_cfg) == 'string' then
+        -- Shorthand: just model name
+        model = lang_cfg
+      elseif type(lang_cfg) == 'table' then
+        -- Full config: { model = "...", url = "...", api_key = "..." }
+        if lang_cfg.model then
+          model = lang_cfg.model
+        end
+        if lang_cfg.url then
+          url = lang_cfg.url
+        end
+        if lang_cfg.api_key then
+          api_key = lang_cfg.api_key
+        end
+      end
+    end
+  end
+
   if not url then
     if llm_type == 'openai' then
       url = 'https://api.openai.com/v1/embeddings'
@@ -85,7 +123,10 @@ end
 neural_common.register_provider('llm', {
   collect_async = function(task, ctx, cont)
     local pcfg = ctx.config or {}
-    local llm = compose_llm_settings(pcfg)
+
+    -- Detect language from displayed text part for model/URL selection
+    local language = detect_language(task)
+    local llm = compose_llm_settings(pcfg, language)
 
     if not llm.model then
       rspamd_logger.debugm(N, task, 'llm provider missing model; skip')
@@ -117,8 +158,8 @@ neural_common.register_provider('llm', {
     end
 
     local input_key = normalize_cache_key_input(input_string)
-    rspamd_logger.debugm(N, task, 'llm embedding request: model=%s url=%s len=%s', tostring(llm.model), tostring(llm.url),
-      tostring(#input_key))
+    rspamd_logger.debugm(N, task, 'llm embedding request: model=%s url=%s lang=%s len=%s',
+      tostring(llm.model), tostring(llm.url), tostring(language or 'unknown'), tostring(#input_key))
 
     local body
     if llm.type == 'openai' then
@@ -141,7 +182,8 @@ neural_common.register_provider('llm', {
     }, N)
 
     -- Use raw key and allow cache module to hash/shorten it per context
-    local key = string.format('%s:%s:%s', llm.type, llm.model or 'model', input_key)
+    -- Include language in cache key for proper separation
+    local key = string.format('%s:%s:%s:%s', llm.type, llm.model or 'model', language or 'unk', input_key)
 
     local function finish_with_vec(vec)
       if type(vec) == 'table' and #vec > 0 then
@@ -152,8 +194,9 @@ neural_common.register_provider('llm', {
           weight = ctx.weight or 1.0,
           model = llm.model,
           provider = llm.type,
+          language = language,
         }
-        rspamd_logger.debugm(N, task, 'llm embedding result: dim=%s', #vec)
+        rspamd_logger.debugm(N, task, 'llm embedding result: dim=%s lang=%s', #vec, language or 'unknown')
         cont(vec, meta)
       else
         rspamd_logger.debugm(N, task, 'llm embedding result: empty')