From: hunter-nl
Date: Wed, 10 Sep 2025 14:27:41 +0000 (+0200)
Subject: Refactor model parameters and response handling
X-Git-Tag: 3.13.0~3^2~2
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=43dc63255d3e296c857801baee3022252a54e2e9;p=thirdparty%2Frspamd.git

Refactor model parameters and response handling

Model parameters are now applied to all models for both AI types.
---

diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index b9d8f0d53a..fcd6ab5179 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -29,19 +29,19 @@ if confighelp then
   api_key = "xxx";
   # Model name
   model = "gpt-5-mini"; # or parallel model requests [ "gpt-5-mini", "gpt-4o-mini" ];
-  # Per-model parameters
-  model_parameters = {
-    "gpt-5-mini" = {
-      max_completion_tokens = 1000,
-    },
-    "gpt-5-nano" = {
-      max_completion_tokens = 1000,
-    },
-    "gpt-4o-mini" = {
-      max_tokens = 1000,
-      temperature = 0.0,
-    }
-  };
+  # Per-model parameters
+  model_parameters = {
+    "gpt-5-mini" = {
+      max_completion_tokens = 1000,
+    },
+    "gpt-5-nano" = {
+      max_completion_tokens = 1000,
+    },
+    "gpt-4o-mini" = {
+      max_tokens = 1000,
+      temperature = 0.0,
+    }
+  };
   # Timeout for requests
   timeout = 10s;
   # Prompt for the model (use default if not set)
@@ -211,7 +211,9 @@ local function default_condition(task)
   end
 
   -- Unified LLM input building (subject/from/urls/body one-line)
-  local input_tbl, sel_part = llm_common.build_llm_input(task, { max_tokens = settings.max_tokens })
+  local model_cfg = settings.model_parameters[settings.model] or {}
+  local max_tokens = model_cfg.max_completion_tokens or model_cfg.max_tokens or 1000
+  local input_tbl, sel_part = llm_common.build_llm_input(task, { max_tokens = max_tokens })
   if not sel_part then
     return false, 'no text part found'
   end
@@ -343,7 +345,7 @@ local function clean_reply_line(line)
   return lua_util.str_trim(line):gsub("^%d%.%s+", "")
 end
 
--- Assume that we have 3 lines: probability, reason, additional symbols
+-- Assume that we have 3 lines: probability, reason, additional categories
 local function default_openai_plain_conversion(task, input)
   local parser = ucl.parser()
   local res, err = parser:parse_string(input)
@@ -488,7 +490,7 @@ local function default_ollama_json_conversion(task, input)
       rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
     end
 
-    return spam_score, reply.reason
+    return spam_score, reply.reason, {}
   end
 
   rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
@@ -659,7 +661,7 @@ local function openai_check(task, content, sel_part)
         return
       end
 
-      local reply, reason = settings.reply_conversion(task, body)
+      local reply, reason, categories = settings.reply_conversion(task, body)
 
       results[i].model = model
 
@@ -667,6 +669,10 @@
         results[i].success = true
         results[i].probability = reply
         results[i].reason = reason
+
+        if categories then
+          results[i].categories = categories
+        end
       end
 
       check_consensus_and_insert_results(task, results, sel_part)
@@ -691,8 +697,6 @@
 
   local body_base = {
     stream = false,
-    max_tokens = settings.max_tokens,
-    temperature = settings.temperature,
     messages = {
       {
         role = 'system',
@@ -706,7 +710,6 @@
     settings.model = { settings.model }
   end
-  upstream = settings.upstreams:get_upstream_round_robin()
   for idx, model in ipairs(settings.model) do
     results[idx] = {
       success = false,
       checked = false
@@ -729,7 +732,8 @@ local function openai_check(task, content, sel_part)
     end
 
     body.model = model
-
+
+    upstream = settings.upstreams:get_upstream_round_robin()
     local http_params = {
       url = settings.url,
       mime_type = 'application/json',
@@ -758,9 +762,9 @@ local function ollama_check(task, content, sel_part)
   local upstream
   local results = {}
 
-  local function gen_reply_closure(model, idx)
+  local function gen_reply_closure(model, i)
     return function(err, code, body)
-      results[idx].checked = true
+      results[i].checked = true
       if err then
         rspamd_logger.errx(task, '%s: request failed: %s', model, err)
         upstream:fail()
@@ -775,14 +779,17 @@
         return
       end
 
-      local reply, reason = settings.reply_conversion(task, body)
+      local reply, reason, categories = settings.reply_conversion(task, body)
 
-      results[idx].model = model
+      results[i].model = model
 
       if reply then
-        results[idx].success = true
-        results[idx].probability = reply
-        results[idx].reason = reason
+        results[i].success = true
+        results[i].probability = reply
+        results[i].reason = reason
+        if categories then
+          results[i].categories = categories
+        end
       end
 
       check_consensus_and_insert_results(task, results, sel_part)
@@ -808,11 +815,12 @@
     settings.model = { settings.model }
   end
 
-  local body = {
+  local body_base = {
     stream = false,
     model = settings.model,
-    max_tokens = settings.max_tokens,
-    temperature = settings.temperature,
+    -- should not be in body_base
+    -- max_tokens = settings.max_tokens,
+    -- temperature = settings.temperature,
     messages = {
       {
         role = 'system',
@@ -822,16 +830,27 @@
     }
   }
 
-  for i, model in ipairs(settings.model) do
+  for idx, model in ipairs(settings.model) do
+    results[idx] = {
+      success = false,
+      checked = false
+    }
+    -- Fresh body for each model
+    local body = lua_util.deepcopy(body_base)
+
+    -- Merge model-specific parameters into body
+    local params = settings.model_parameters[model]
+    if params then
+      for k, v in pairs(params) do
+        body[k] = v
+      end
+    end
+
     -- Conditionally add response_format
     if settings.include_response_format then
       body.response_format = { type = "json_object" }
     end
-
-    results[i] = {
-      success = false,
-      checked = false
-    }
+    body.model = model
 
     upstream = settings.upstreams:get_upstream_round_robin()
 
@@ -840,7 +859,7 @@
       mime_type = 'application/json',
       timeout = settings.timeout,
       log_obj = task,
-      callback = gen_reply_closure(model, i),
+      callback = gen_reply_closure(model, idx),
      keepalive = true,
       body = ucl.to_format(body, 'json-compact', true),
       task = task,
@@ -848,7 +867,9 @@
       use_gzip = true,
     }
 
-    rspamd_http.request(http_params)
+    if not rspamd_http.request(http_params) then
+      results[idx].checked = true
+    end
   end
 end
 
@@ -982,7 +1003,7 @@ if opts then
         "Output ONLY 3 lines:\n" ..
         "1. Numeric score (0.00-1.00)\n" ..
         "2. One-sentence reason citing whether it is spam, the strongest red flag, or why it is ham\n" ..
-        "3. Primary concern category if found from the list: " .. table.concat(lua_util.keys(categories_map), ', ')
+        "3. Empty line or mention ONLY the primary concern category if found from the list: " .. table.concat(lua_util.keys(categories_map), ', ')
   else
     settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
       "FROM and url domains. Evaluate spam probability (0-1). " ..
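
The pattern this commit introduces in both openai_check() and ollama_check() is: keep a shared body_base, deep-copy it for each model, and overlay that model's entry from model_parameters before serializing the request. Below is a minimal standalone sketch of that merge in plain Lua 5.x. It is illustrative only: the local deepcopy helper is a hypothetical stand-in for rspamd's lua_util.deepcopy, and the model names and parameter values just mirror the confighelp example above rather than any real configuration.

-- Standalone sketch of the per-model parameter merge; not rspamd code.

-- Hypothetical stand-in for lua_util.deepcopy (no metatable or cycle handling).
local function deepcopy(t)
  local copy = {}
  for k, v in pairs(t) do
    copy[k] = type(v) == 'table' and deepcopy(v) or v
  end
  return copy
end

-- Illustrative per-model parameters, mirroring the confighelp example.
local model_parameters = {
  ['gpt-5-mini'] = { max_completion_tokens = 1000 },
  ['gpt-4o-mini'] = { max_tokens = 1000, temperature = 0.0 },
}

-- Shared request skeleton: token/temperature knobs no longer live here.
local body_base = {
  stream = false,
  messages = { { role = 'system', content = 'example system prompt' } },
}

for _, model in ipairs({ 'gpt-5-mini', 'gpt-4o-mini' }) do
  local body = deepcopy(body_base)              -- fresh body per model
  for k, v in pairs(model_parameters[model] or {}) do
    body[k] = v                                 -- overlay model-specific keys
  end
  body.model = model
  -- body is now ready to be serialized into one HTTP request per model
  print(model, body.max_completion_tokens or body.max_tokens, body.temperature)
end

Copying before the overlay is what keeps parallel model requests independent: mutating one shared body would let a later model inherit an earlier model's max_completion_tokens, and the deep copy additionally prevents the nested messages table from being shared between requests.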