]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Cache LLM replies
authorVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 25 Feb 2025 11:38:28 +0000 (11:38 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 25 Feb 2025 11:38:28 +0000 (11:38 +0000)
src/plugins/lua/gpt.lua

index 270d0fdfc5d45ecb0c6c044c9221db356d79da52..6c6d8d685106d20acc30d0e3acbcaf86c7a88a73 100644 (file)
@@ -15,6 +15,7 @@ limitations under the License.
 ]] --
 
 local N = "gpt"
+local REDIS_PREFIX = "rsllm_"
 local E = {}
 
 if confighelp then
@@ -61,6 +62,7 @@ local lua_util = require "lua_util"
 local rspamd_http = require "rspamd_http"
 local rspamd_logger = require "rspamd_logger"
 local lua_mime = require "lua_mime"
+local lua_redis = require "lua_redis"
 local ucl = require "ucl"
 local fun = require "fun"
 
@@ -92,7 +94,9 @@ local settings = {
   allow_passthrough = false,
   allow_ham = false,
   json = false,
+  redis_cache_expire = 3600 * 24,
 }
+local redis_params
 
 local function default_condition(task)
   -- Check result
@@ -176,10 +180,10 @@ local function default_condition(task)
     local words = sel_part:get_words('norm')
     nwords = #words
     if nwords > settings.max_tokens then
-      return true, table.concat(words, ' ', 1, settings.max_tokens)
+      return true, table.concat(words, ' ', 1, settings.max_tokens), sel_part
     end
   end
-  return true, sel_part:get_content_oneline()
+  return true, sel_part:get_content_oneline(), sel_part
 end
 
 local function maybe_extract_json(str)
@@ -431,7 +435,48 @@ local function default_ollama_json_conversion(task, input)
   return
 end
 
-local function check_consensus(task, results)
+local function maybe_save_cache(task, result, sel_part)
+  if not sel_part or not redis_params then
+    lua_util.debugm(N, task, 'cannot save cache: no part or no redis')
+    return -- cannot save
+  end
+
+  local digest = sel_part:get_mimepart():get_digest()
+  local cache_key = REDIS_PREFIX .. digest
+  lua_util.debugm(N, task, 'saving cache for %s', cache_key)
+  local result_json = ucl.to_format(result, 'json-compact')
+  lua_redis.redis_make_request(task, redis_params, cache_key, false, function(err, _)
+    if err then
+      rspamd_logger.errx(task, 'cannot save cache: %s', err)
+    end
+  end,
+      'SETEX', { cache_key, tostring(settings.redis_cache_expire), result_json })
+end
+
+local function insert_results(task, result, sel_part)
+  if not result.probability then
+    rspamd_logger.errx(task, 'no probability in result')
+    return
+  end
+  if result.probability > 0.5 then
+    task:insert_result('GPT_SPAM', (result.probability - 0.5) * 2, tostring(result.probability))
+    if settings.autolearn then
+      task:set_flag("learn_spam")
+    end
+  else
+    if result.reason and settings.reason_header then
+      lua_mime.modify_headers(task,
+          { add = { [settings.reason_header] = { value = 'value', order = 1 } } })
+    end
+    task:insert_result('GPT_HAM', (0.5 - result.probability) * 2, tostring(result.probability))
+    if settings.autolearn then
+      task:set_flag("learn_ham")
+    end
+  end
+  maybe_save_cache(task, result, sel_part)
+end
+
+local function check_consensus_and_insert_results(task, results, sel_part)
   for _, result in ipairs(results) do
     if not result.checked then
       return
@@ -466,24 +511,17 @@ local function check_consensus(task, results)
   local reason = reasons[1] or nil
 
   if nspam > nham and max_spam_prob > 0.75 then
-    task:insert_result('GPT_SPAM', (max_spam_prob - 0.75) * 4, tostring(max_spam_prob))
-    if settings.autolearn then
-      task:set_flag("learn_spam")
-    end
-
-    if reason and settings.reason_header then
-      lua_mime.modify_headers(task,
-          { add = { [settings.reason_header] = { value = 'value', order = 1 } } })
-    end
+    insert_results(task, {
+      probability = max_spam_prob,
+      reason = reason,
+    },
+        sel_part)
   elseif nham > nspam and max_ham_prob < 0.25 then
-    task:insert_result('GPT_HAM', (0.25 - max_ham_prob) * 4, tostring(max_ham_prob))
-    if settings.autolearn then
-      task:set_flag("learn_ham")
-    end
-    if reason and settings.reason_header then
-      lua_mime.modify_headers(task,
-          { add = { [settings.reason_header] = { value = 'value', order = 1 } } })
-    end
+    insert_results(task, {
+      probability = max_ham_prob,
+      reason = reason,
+    },
+        sel_part)
   else
     -- No consensus
     lua_util.debugm(N, task, "no consensus")
@@ -508,19 +546,43 @@ local function get_meta_llm_content(task)
   return url_content, from_content
 end
 
-local function default_llm_check(task)
-  local ret, content = settings.condition(task)
+local function check_llm_uncached(task, content, sel_part)
+  return settings.specific_check(task, content, sel_part)
+end
 
-  if not ret then
-    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
-    return
-  end
+local function check_llm_cached(task, content, sel_part)
+  local digest = sel_part:get_mimepart():get_digest()
+  local cache_key = REDIS_PREFIX .. digest
 
-  if not content then
-    lua_util.debugm(N, task, "no content to send to gpt classification")
-    return
+  local ret = lua_redis.redis_make_request(task, redis_params, cache_key, false, function(_, err, data)
+    if err then
+      rspamd_logger.errx(task, 'cannot check cache: %s', err)
+      check_llm_uncached(task, content, sel_part)
+    end
+
+    if data then
+      local parser = ucl.parser()
+      local res, parse_err = parser:parse_string(data)
+      if not res then
+        rspamd_logger.errx(task, 'Cannot parse cached response: %s', parse_err)
+        check_llm_uncached(task, content, sel_part)
+      else
+        rspamd_logger.infox(task, 'found cached response')
+        insert_results(task, parser:get_object())
+      end
+    else
+      check_llm_uncached(task, content, sel_part)
+    end
+  end,
+      'GET', { cache_key })
+
+  if not ret then
+    rspamd_logger.errx(task, 'cannot query cache for request')
+    check_llm_uncached(task, content, sel_part)
   end
+end
 
+local function openai_check(task, content, sel_part)
   lua_util.debugm(N, task, "sending content to gpt: %s", content)
 
   local upstream
@@ -533,7 +595,7 @@ local function default_llm_check(task)
       if err then
         rspamd_logger.errx(task, '%s: request failed: %s', model, err)
         upstream:fail()
-        check_consensus(task, results)
+        check_consensus_and_insert_results(task, results, sel_part)
         return
       end
 
@@ -554,7 +616,7 @@ local function default_llm_check(task)
         results[idx].reason = reason
       end
 
-      check_consensus(task, results)
+      check_consensus_and_insert_results(task, results, sel_part)
     end
   end
 
@@ -627,19 +689,7 @@ local function default_llm_check(task)
   end
 end
 
-local function ollama_check(task)
-  local ret, content = settings.condition(task)
-
-  if not ret then
-    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
-    return
-  end
-
-  if not content then
-    lua_util.debugm(N, task, "no content to send to gpt classification")
-    return
-  end
-
+local function ollama_check(task, content, sel_part)
   lua_util.debugm(N, task, "sending content to gpt: %s", content)
 
   local upstream
@@ -651,7 +701,7 @@ local function ollama_check(task)
       if err then
         rspamd_logger.errx(task, '%s: request failed: %s', model, err)
         upstream:fail()
-        check_consensus(task, results)
+        check_consensus_and_insert_results(task, results, sel_part)
         return
       end
 
@@ -672,7 +722,7 @@ local function ollama_check(task)
         results[idx].reason = reason
       end
 
-      check_consensus(task, results)
+      check_consensus_and_insert_results(task, results, sel_part)
     end
   end
 
@@ -738,12 +788,29 @@ local function ollama_check(task)
 end
 
 local function gpt_check(task)
-  return settings.specific_check(task)
+  local ret, content, sel_part = settings.condition(task)
+
+  if not ret then
+    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
+    return
+  end
+
+  if not content then
+    lua_util.debugm(N, task, "no content to send to gpt classification")
+    return
+  end
+
+  if sel_part then
+    -- Check digest
+    check_llm_cached(task, content, sel_part)
+  else
+    check_llm_uncached(task, content)
+  end
 end
 
 local types_map = {
   openai = {
-    check = default_llm_check,
+    check = openai_check,
     condition = default_condition,
     conversion = function(is_json)
       return is_json and default_openai_json_conversion or default_openai_plain_conversion
@@ -760,10 +827,18 @@ local types_map = {
   },
 }
 
-local opts = rspamd_config:get_all_opt('gpt')
+local opts = rspamd_config:get_all_opt(N)
 if opts then
+  redis_params = lua_redis.parse_redis_server(N, opts)
   settings = lua_util.override_defaults(settings, opts)
 
+  if redis_params then
+    lua_redis.register_prefix(REDIS_PREFIX .. '*', N,
+        'Cache of LLM requests', {
+          type = 'string',
+        })
+  end
+
   if not settings.prompt then
     settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
         "FROM and url domains. Evaluate spam probability (0-1). " ..