]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] GPT: Add ollama support 5262/head
authorVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 18 Dec 2024 16:32:37 +0000 (16:32 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 18 Dec 2024 16:32:37 +0000 (16:32 +0000)
src/plugins/lua/gpt.lua

index 36938c0d187c50dfdbcb9e94eabc1944724b01d3..feccae73ff29f4ef2895eca9438c27748b414480 100644 (file)
@@ -22,7 +22,7 @@ if confighelp then
       "Performs postfiltering using GPT model",
       [[
 gpt {
-  # Supported types: openai
+  # Supported types: openai, ollama
   type = "openai";
   # Your key to access the API
   api_key = "xxx";
@@ -155,13 +155,17 @@ end
 
 local function maybe_extract_json(str)
   -- Find the first opening brace
-  local startPos = str:find("{")
+  local startPos, endPos = str:find('json%s*{')
+  if not startPos then
+    startPos, endPos = str:find('{')
+  end
   if not startPos then
     return nil
   end
 
+  startPos = endPos - 1
   local openBraces = 0
-  local endPos = startPos
+  endPos = startPos
   local len = #str
 
   -- Iterate through the string to find matching braces
@@ -225,6 +229,7 @@ local function default_conversion(task, input)
   reply = parser:get_object()
 
   if type(reply) == 'table' and reply.probability then
+    lua_util.debugm(N, task, 'extracted probability: %s', reply.probability)
     local spam_score = tonumber(reply.probability)
 
     if not spam_score then
@@ -249,7 +254,87 @@ local function default_conversion(task, input)
   return
 end
 
-local function openai_gpt_check(task)
+local function ollama_conversion(task, input)
+  local parser = ucl.parser()
+  local res, err = parser:parse_string(input)
+  if not res then
+    rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+    return
+  end
+  local reply = parser:get_object()
+  if not reply then
+    rspamd_logger.errx(task, 'cannot get object from reply')
+    return
+  end
+
+  if type(reply.message) ~= 'table' then
+    rspamd_logger.errx(task, 'bad message in reply')
+    return
+  end
+
+  local first_message = reply.message.content
+
+  if not first_message then
+    rspamd_logger.errx(task, 'no content in the first message')
+    return
+  end
+
+  -- Apply heuristic to extract JSON
+  first_message = maybe_extract_json(first_message) or first_message
+
+  parser = ucl.parser()
+  res, err = parser:parse_string(first_message)
+  if not res then
+    rspamd_logger.errx(task, 'cannot parse JSON gpt reply: %s', err)
+    return
+  end
+
+  reply = parser:get_object()
+
+  if type(reply) == 'table' and reply.probability then
+    lua_util.debugm(N, task, 'extracted probability: %s', reply.probability)
+    local spam_score = tonumber(reply.probability)
+
+    if not spam_score then
+      -- Maybe we need GPT to convert GPT reply here?
+      if reply.probability == "high" then
+        spam_score = 0.9
+      elseif reply.probability == "low" then
+        spam_score = 0.1
+      else
+        rspamd_logger.infox("cannot convert to spam probability: %s", reply.probability)
+      end
+    end
+
+    if type(reply.usage) == 'table' then
+      rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+    end
+
+    return spam_score
+  end
+
+  rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
+  return
+end
+
+local function get_meta_llm_content(task)
+  local url_content = "Url domains: no urls found"
+  if task:has_urls() then
+    local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 }
+    url_content = "Url domains: " .. table.concat(fun.totable(fun.map(function(u)
+      return u:get_tld() or ''
+    end, urls or {})), ', ')
+  end
+
+  local from_or_empty = ((task:get_from('mime') or E)[1] or E)
+  local from_content = string.format('From: %s <%s>', from_or_empty.name, from_or_empty.addr)
+  lua_util.debugm(N, task, "gpt urls: %s", url_content)
+  lua_util.debugm(N, task, "gpt from: %s", from_content)
+
+  return url_content, from_content
+end
+
+local function default_llm_check(task)
   local ret, content = settings.condition(task)
 
   if not ret then
@@ -302,18 +387,7 @@ local function openai_gpt_check(task)
 
   end
 
-  local url_content = "Url domains: no urls found"
-  if task:has_urls() then
-    local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 }
-    url_content = "Url domains: " .. table.concat(fun.totable(fun.map(function(u)
-      return u:get_tld() or ''
-    end, urls or {})), ', ')
-  end
-
-  local from_or_empty = ((task:get_from('mime') or E)[1] or E)
-  local from_content = string.format('From: %s <%s>', from_or_empty.name, from_or_empty.addr)
-  lua_util.debugm(N, task, "gpt urls: %s", url_content)
-  lua_util.debugm(N, task, "gpt from: %s", from_content)
+  local from_content, url_content = get_meta_llm_content(task)
 
   local body = {
     model = settings.model,
@@ -364,43 +438,161 @@ local function openai_gpt_check(task)
   rspamd_http.request(http_params)
 end
 
+local function ollama_check(task)
+  local ret, content = settings.condition(task)
+
+  if not ret then
+    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
+    return
+  end
+
+  if not content then
+    lua_util.debugm(N, task, "no content to send to gpt classification")
+    return
+  end
+
+  lua_util.debugm(N, task, "sending content to gpt: %s", content)
+
+  local upstream
+
+  local function on_reply(err, code, body)
+
+    if err then
+      rspamd_logger.errx(task, 'request failed: %s', err)
+      upstream:fail()
+      return
+    end
+
+    upstream:ok()
+    lua_util.debugm(N, task, "got reply: %s", body)
+    if code ~= 200 then
+      rspamd_logger.errx(task, 'bad reply: %s', body)
+      return
+    end
+
+    local reply = settings.reply_conversion(task, body)
+    if not reply then
+      return
+    end
+
+    if reply > 0.75 then
+      task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
+      if settings.autolearn then
+        task:set_flag("learn_spam")
+      end
+    elseif reply < 0.25 then
+      task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
+      if settings.autolearn then
+        task:set_flag("learn_ham")
+      end
+    else
+      lua_util.debugm(N, task, "uncertain result: %s", reply)
+    end
+
+  end
+
+  local from_content, url_content = get_meta_llm_content(task)
+
+  local body = {
+    stream = false,
+    model = settings.model,
+    max_tokens = settings.max_tokens,
+    temperature = settings.temperature,
+    response_format = { type = "json_object" },
+    messages = {
+      {
+        role = 'system',
+        content = settings.prompt
+      },
+      {
+        role = 'user',
+        content = 'Subject: ' .. task:get_subject() or '',
+      },
+      {
+        role = 'user',
+        content = from_content,
+      },
+      {
+        role = 'user',
+        content = url_content,
+      },
+      {
+        role = 'user',
+        content = content
+      }
+    }
+  }
+
+  upstream = settings.upstreams:get_upstream_round_robin()
+  local http_params = {
+    url = settings.url,
+    mime_type = 'application/json',
+    timeout = settings.timeout,
+    log_obj = task,
+    callback = on_reply,
+    keepalive = true,
+    body = ucl.to_format(body, 'json-compact', true),
+    task = task,
+    upstream = upstream,
+    use_gzip = true,
+  }
+
+  rspamd_http.request(http_params)
+end
+
 local function gpt_check(task)
   return settings.specific_check(task)
 end
 
+local types_map = {
+  openai = {
+    check = default_llm_check,
+    condition = default_condition,
+    conversion = default_conversion,
+    require_passkey = true,
+  },
+  ollama = {
+    check = ollama_check,
+    condition = default_condition,
+    conversion = ollama_conversion,
+    require_passkey = false,
+  },
+}
+
 local opts = rspamd_config:get_all_opt('gpt')
 if opts then
   settings = lua_util.override_defaults(settings, opts)
 
-  if not settings.api_key then
-    rspamd_logger.warnx(rspamd_config, 'no api_key is specified, disabling module')
-    lua_util.disable_module(N, "config")
+  if not settings.prompt then
+    settings.prompt = "You will be provided with the email message, subject, from and url domains, " ..
+        "and your task is to evaluate the probability to be spam as number from 0 to 1, " ..
+        "output result as JSON with 'probability' field."
+  end
 
+  local llm_type = types_map[settings.type]
+  if not llm_type then
+    rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type)
+    lua_util.disable_module(N, "config")
     return
   end
+  settings.specific_check = llm_type.check
+
   if settings.condition then
     settings.condition = load(settings.condition)()
   else
-    settings.condition = default_condition
+    settings.condition = llm_type.condition
   end
 
   if settings.reply_conversion then
     settings.reply_conversion = load(settings.reply_conversion)()
   else
-    settings.reply_conversion = default_conversion
-  end
-
-  if not settings.prompt then
-    settings.prompt = "You will be provided with the email message, subject, from and url domains, " ..
-        "and your task is to evaluate the probability to be spam as number from 0 to 1, " ..
-        "output result as JSON with 'probability' field."
+    settings.reply_conversion = llm_type.conversion
   end
 
-  if settings.type == 'openai' then
-    settings.specific_check = openai_gpt_check
-  else
-    rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type)
+  if not settings.api_key and llm_type.require_passkey then
+    rspamd_logger.warnx(rspamd_config, 'no api_key is specified for LLM type %s, disabling module', settings.type)
     lua_util.disable_module(N, "config")
+
     return
   end