]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
Refactor model parameters and response handling
authorhunter-nl <junobox@gmail.com>
Wed, 10 Sep 2025 14:27:41 +0000 (16:27 +0200)
committerGitHub <noreply@github.com>
Wed, 10 Sep 2025 14:27:41 +0000 (16:27 +0200)
Model parameters now applied to all models for both AI types.

src/plugins/lua/gpt.lua

index b9d8f0d53aabe49f4258ffeb75e05d45d73e0d2b..fcd6ab5179c9269e3903e0b5785bd6f5e4d71c5d 100644 (file)
@@ -29,19 +29,19 @@ if confighelp then
   api_key = "xxx";
   # Model name
   model = "gpt-5-mini"; # or parallel model requests [ "gpt-5-mini", "gpt-4o-mini" ];
-       # Per-model parameters
-       model_parameters = {
-               "gpt-5-mini" = {
-                       max_completion_tokens = 1000,
-               },
-               "gpt-5-nano" = {
-                       max_completion_tokens = 1000,
-               },
-               "gpt-4o-mini" = {
-                       max_tokens = 1000,
-                       temperature = 0.0,
-               }
-       };
+  # Per-model parameters
+  model_parameters = {
+    "gpt-5-mini" = {
+      max_completion_tokens = 1000,
+    },
+    "gpt-5-nano" = {
+      max_completion_tokens = 1000,
+    },
+    "gpt-4o-mini" = {
+      max_tokens = 1000,
+      temperature = 0.0,
+    }
+  };
   # Timeout for requests
   timeout = 10s;
   # Prompt for the model (use default if not set)
@@ -211,7 +211,9 @@ local function default_condition(task)
   end
 
   -- Unified LLM input building (subject/from/urls/body one-line)
-  local input_tbl, sel_part = llm_common.build_llm_input(task, { max_tokens = settings.max_tokens })
+  local model_cfg = settings.model_parameters[settings.model] or {}
+  local max_tokens = model_cfg.max_completion_tokens or model_cfg.max_tokens or 1000
+  local input_tbl, sel_part = llm_common.build_llm_input(task, { max_tokens = max_tokens })
   if not sel_part then
     return false, 'no text part found'
   end
@@ -343,7 +345,7 @@ local function clean_reply_line(line)
   return lua_util.str_trim(line):gsub("^%d%.%s+", "")
 end
 
--- Assume that we have 3 lines: probability, reason, additional symbols
+-- Assume that we have 3 lines: probability, reason, additional categories
 local function default_openai_plain_conversion(task, input)
   local parser = ucl.parser()
   local res, err = parser:parse_string(input)
@@ -488,7 +490,7 @@ local function default_ollama_json_conversion(task, input)
       rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
     end
 
-    return spam_score, reply.reason
+    return spam_score, reply.reason, {}
   end
 
   rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
@@ -659,7 +661,7 @@ local function openai_check(task, content, sel_part)
         return
       end
 
-      local reply, reason = settings.reply_conversion(task, body)
+      local reply, reason, categories = settings.reply_conversion(task, body)
 
       results[i].model = model
 
@@ -667,6 +669,10 @@ local function openai_check(task, content, sel_part)
         results[i].success = true
         results[i].probability = reply
         results[i].reason = reason
+
+        if categories then
+          results[i].categories = categories
+        end
       end
 
       check_consensus_and_insert_results(task, results, sel_part)
@@ -691,8 +697,6 @@ local function openai_check(task, content, sel_part)
 
   local body_base = {
     stream = false,
-    max_tokens = settings.max_tokens,
-    temperature = settings.temperature,
     messages = {
       {
         role = 'system',
@@ -706,7 +710,6 @@ local function openai_check(task, content, sel_part)
     settings.model = { settings.model }
   end
 
-  upstream = settings.upstreams:get_upstream_round_robin()
   for idx, model in ipairs(settings.model) do
     results[idx] = {
       success = false,
@@ -729,7 +732,8 @@ local function openai_check(task, content, sel_part)
     end
 
     body.model = model
-
+    
+    upstream = settings.upstreams:get_upstream_round_robin()
     local http_params = {
       url = settings.url,
       mime_type = 'application/json',
@@ -758,9 +762,9 @@ local function ollama_check(task, content, sel_part)
   local upstream
   local results = {}
 
-  local function gen_reply_closure(model, idx)
+  local function gen_reply_closure(model, i)
     return function(err, code, body)
-      results[idx].checked = true
+      results[i].checked = true
       if err then
         rspamd_logger.errx(task, '%s: request failed: %s', model, err)
         upstream:fail()
@@ -775,14 +779,17 @@ local function ollama_check(task, content, sel_part)
         return
       end
 
-      local reply, reason = settings.reply_conversion(task, body)
+      local reply, reason, categories = settings.reply_conversion(task, body)
 
-      results[idx].model = model
+      results[i].model = model
 
       if reply then
-        results[idx].success = true
-        results[idx].probability = reply
-        results[idx].reason = reason
+        results[i].success = true
+        results[i].probability = reply
+        results[i].reason = reason
+        if categories then
+          results[i].categories = categories
+        end
       end
 
       check_consensus_and_insert_results(task, results, sel_part)
@@ -808,11 +815,12 @@ local function ollama_check(task, content, sel_part)
     settings.model = { settings.model }
   end
 
-  local body = {
+  local body_base = {
     stream = false,
     model = settings.model,
-    max_tokens = settings.max_tokens,
-    temperature = settings.temperature,
+    -- should not in body_base
+    -- max_tokens = settings.max_tokens,
+    -- temperature = settings.temperature,
     messages = {
       {
         role = 'system',
@@ -822,16 +830,27 @@ local function ollama_check(task, content, sel_part)
     }
   }
 
-  for i, model in ipairs(settings.model) do
+  for idx, model in ipairs(settings.model) do
+    results[idx] = {
+      success = false,
+      checked = false
+    }
+    -- Fresh body for each model
+    local body = lua_util.deepcopy(body_base)
+    
+    -- Merge model-specific parameters into body
+    local params = settings.model_parameters[model]
+    if params then
+      for k, v in pairs(params) do
+        body[k] = v
+      end
+    end
+
     -- Conditionally add response_format
     if settings.include_response_format then
       body.response_format = { type = "json_object" }
     end
-
-    results[i] = {
-      success = false,
-      checked = false
-    }
+   
     body.model = model
 
     upstream = settings.upstreams:get_upstream_round_robin()
@@ -840,7 +859,7 @@ local function ollama_check(task, content, sel_part)
       mime_type = 'application/json',
       timeout = settings.timeout,
       log_obj = task,
-      callback = gen_reply_closure(model, i),
+      callback = gen_reply_closure(model, idx),
       keepalive = true,
       body = ucl.to_format(body, 'json-compact', true),
       task = task,
@@ -848,7 +867,9 @@ local function ollama_check(task, content, sel_part)
       use_gzip = true,
     }
 
-    rspamd_http.request(http_params)
+    if not rspamd_http.request(http_params) then
+      results[idx].checked = true
+    end
   end
 end
 
@@ -982,7 +1003,7 @@ if opts then
           "Output ONLY 3 lines:\n" ..
           "1. Numeric score (0.00-1.00)\n" ..
           "2. One-sentence reason citing whether it is spam, the strongest red flag, or why it is ham\n" ..
-          "3. Primary concern category if found from the list: " .. table.concat(lua_util.keys(categories_map), ', ')
+          "3. Empty line or mention ONLY the primary concern category if found from the list: " .. table.concat(lua_util.keys(categories_map), ', ')
     else
       settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
           "FROM and url domains. Evaluate spam probability (0-1). " ..