]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Support LLM models consensus 5320/head
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 27 Jan 2025 19:19:28 +0000 (19:19 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 27 Jan 2025 19:19:28 +0000 (19:19 +0000)
src/plugins/lua/gpt.lua

index e4a77c6dd873ffa389ef31423fb920ae63794c38..4888eaa191c00e830ac9038fb8a98b303dae6a52 100644 (file)
@@ -319,6 +319,47 @@ local function ollama_conversion(task, input)
   return
 end
 
+local function check_consensus(task, results)
+  for _, result in ipairs(results) do
+    if not result.checked then
+      return
+    end
+  end
+
+  local nspam, nham = 0, 0
+  local max_spam_prob, max_ham_prob = 0, 0
+
+  for _, result in ipairs(results) do
+    if result.success then
+      if result.probability > 0.5 then
+        nspam = nspam + 1
+        max_spam_prob = math.max(max_spam_prob, result.probability)
+        lua_util.debugm(N, task, "model: %s; spam: %s", result.model, result.probability)
+      else
+        nham = nham + 1
+        max_ham_prob = math.min(max_ham_prob, result.probability)
+        lua_util.debugm(N, task, "model: %s; ham: %s", result.model, result.probability)
+      end
+    end
+  end
+
+  if nspam > nham and max_spam_prob > 0.75 then
+    task:insert_result('GPT_SPAM', (max_spam_prob - 0.75) * 4, tostring(max_spam_prob))
+    if settings.autolearn then
+      task:set_flag("learn_spam")
+    end
+  elseif nham > nspam and max_ham_prob < 0.25 then
+    task:insert_result('GPT_HAM', (0.25 - max_ham_prob) * 4, tostring(max_ham_prob))
+    if settings.autolearn then
+      task:set_flag("learn_ham")
+    end
+  else
+    -- No consensus
+    lua_util.debugm(N, task, "no consensus")
+  end
+
+end
+
 local function get_meta_llm_content(task)
   local url_content = "Url domains: no urls found"
   if task:has_urls() then
@@ -353,40 +394,36 @@ local function default_llm_check(task)
 
   local upstream
 
-  local function on_reply(err, code, body)
+  local results = {}
 
-    if err then
-      rspamd_logger.errx(task, 'request failed: %s', err)
-      upstream:fail()
-      return
-    end
+  local function gen_reply_closure(model, idx)
+    return function(err, code, body)
+      results[idx].checked = true
+      if err then
+        rspamd_logger.errx(task, '%s: request failed: %s', model, err)
+        upstream:fail()
+        check_consensus(task, results)
+        return
+      end
 
-    upstream:ok()
-    lua_util.debugm(N, task, "got reply: %s", body)
-    if code ~= 200 then
-      rspamd_logger.errx(task, 'bad reply: %s', body)
-      return
-    end
+      upstream:ok()
+      lua_util.debugm(N, task, "%s: got reply: %s", model, body)
+      if code ~= 200 then
+        rspamd_logger.errx(task, 'bad reply: %s', body)
+        return
+      end
 
-    local reply = settings.reply_conversion(task, body)
-    if not reply then
-      return
-    end
+      local reply = settings.reply_conversion(task, body)
 
-    if reply > 0.75 then
-      task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_spam")
-      end
-    elseif reply < 0.25 then
-      task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_ham")
+      results[idx].model = model
+
+      if reply then
+        results[idx].success = true
+        results[idx].probability = reply
       end
-    else
-      lua_util.debugm(N, task, "uncertain result: %s", reply)
-    end
 
+      check_consensus(task, results)
+    end
   end
 
   local from_content, url_content = get_meta_llm_content(task)
@@ -424,24 +461,38 @@ local function default_llm_check(task)
     body.response_format = { type = "json_object" }
   end
 
+  if type(settings.model) == 'string' then
+    settings.model = { settings.model }
+  end
+
   upstream = settings.upstreams:get_upstream_round_robin()
-  local http_params = {
-    url = settings.url,
-    mime_type = 'application/json',
-    timeout = settings.timeout,
-    log_obj = task,
-    callback = on_reply,
-    headers = {
-      ['Authorization'] = 'Bearer ' .. settings.api_key,
-    },
-    keepalive = true,
-    body = ucl.to_format(body, 'json-compact', true),
-    task = task,
-    upstream = upstream,
-    use_gzip = true,
-  }
+  for idx, model in ipairs(settings.model) do
+    results[idx] = {
+      success = false,
+      checked = false
+    }
+    body.model = model
+    local http_params = {
+      url = settings.url,
+      mime_type = 'application/json',
+      timeout = settings.timeout,
+      log_obj = task,
+      callback = gen_reply_closure(model, idx),
+      headers = {
+        ['Authorization'] = 'Bearer ' .. settings.api_key,
+      },
+      keepalive = true,
+      body = ucl.to_format(body, 'json-compact', true),
+      task = task,
+      upstream = upstream,
+      use_gzip = true,
+    }
 
-  rspamd_http.request(http_params)
+    if not rspamd_http.request(http_params) then
+      results[idx].checked = true
+    end
+
+  end
 end
 
 local function ollama_check(task)
@@ -460,45 +511,44 @@ local function ollama_check(task)
   lua_util.debugm(N, task, "sending content to gpt: %s", content)
 
   local upstream
+  local results = {}
+
+  local function gen_reply_closure(model, idx)
+    return function(err, code, body)
+      results[idx].checked = true
+      if err then
+        rspamd_logger.errx(task, '%s: request failed: %s', model, err)
+        upstream:fail()
+        check_consensus(task, results)
+        return
+      end
 
-  local function on_reply(err, code, body)
-
-    if err then
-      rspamd_logger.errx(task, 'request failed: %s', err)
-      upstream:fail()
-      return
-    end
+      upstream:ok()
+      lua_util.debugm(N, task, "%s: got reply: %s", model, body)
+      if code ~= 200 then
+        rspamd_logger.errx(task, 'bad reply: %s', body)
+        return
+      end
 
-    upstream:ok()
-    lua_util.debugm(N, task, "got reply: %s", body)
-    if code ~= 200 then
-      rspamd_logger.errx(task, 'bad reply: %s', body)
-      return
-    end
+      local reply = settings.reply_conversion(task, body)
 
-    local reply = settings.reply_conversion(task, body)
-    if not reply then
-      return
-    end
+      results[idx].model = model
 
-    if reply > 0.75 then
-      task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_spam")
-      end
-    elseif reply < 0.25 then
-      task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_ham")
+      if reply then
+        results[idx].success = true
+        results[idx].probability = reply
       end
-    else
-      lua_util.debugm(N, task, "uncertain result: %s", reply)
-    end
 
+      check_consensus(task, results)
+    end
   end
 
   local from_content, url_content = get_meta_llm_content(task)
 
+  if type(settings.model) == 'string' then
+    settings.model = { settings.model }
+  end
+
   local body = {
     stream = false,
     model = settings.model,
@@ -528,26 +578,30 @@ local function ollama_check(task)
     }
   }
 
-  -- Conditionally add response_format
-  if settings.include_response_format then
-    body.response_format = { type = "json_object" }
-  end
+  for i, model in ipairs(settings.model) do
+    -- Conditionally add response_format
+    if settings.include_response_format then
+      body.response_format = { type = "json_object" }
+    end
 
-  upstream = settings.upstreams:get_upstream_round_robin()
-  local http_params = {
-    url = settings.url,
-    mime_type = 'application/json',
-    timeout = settings.timeout,
-    log_obj = task,
-    callback = on_reply,
-    keepalive = true,
-    body = ucl.to_format(body, 'json-compact', true),
-    task = task,
-    upstream = upstream,
-    use_gzip = true,
-  }
+    body.model = model
+
+    upstream = settings.upstreams:get_upstream_round_robin()
+    local http_params = {
+      url = settings.url,
+      mime_type = 'application/json',
+      timeout = settings.timeout,
+      log_obj = task,
+      callback = gen_reply_closure(model, i),
+      keepalive = true,
+      body = ucl.to_format(body, 'json-compact', true),
+      task = task,
+      upstream = upstream,
+      use_gzip = true,
+    }
 
-  rspamd_http.request(http_params)
+    rspamd_http.request(http_params)
+  end
 end
 
 local function gpt_check(task)