]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Rework] GPT: Use cache framework
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 6 Mar 2025 12:39:22 +0000 (12:39 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 6 Mar 2025 12:39:22 +0000 (12:39 +0000)
src/plugins/lua/gpt.lua

index 625450fd94018fa89e5708dec97aaf1f9607467c..3ac95a1c9f4dcdf9c15f4e89b0a8ebdef0722cfe 100644 (file)
@@ -15,7 +15,7 @@ limitations under the License.
 ]] --
 
 local N = "gpt"
-local REDIS_PREFIX = "rsllm_"
+local REDIS_PREFIX = "rsllm"
 local E = {}
 
 if confighelp then
@@ -65,6 +65,7 @@ local lua_mime = require "lua_mime"
 local lua_redis = require "lua_redis"
 local ucl = require "ucl"
 local fun = require "fun"
+local lua_cache = require "lua_cache"
 
 -- Exclude checks if one of those is found
 local default_symbols_to_except = {
@@ -120,10 +121,11 @@ local settings = {
   allow_passthrough = false,
   allow_ham = false,
   json = false,
-  redis_cache_expire = 3600 * 24,
   extra_symbols = nil,
+  cache_prefix = REDIS_PREFIX,
 }
 local redis_params
+local cache_context
 
 local function default_condition(task)
   -- Check result
@@ -474,27 +476,10 @@ local function redis_cache_key(sel_part)
     digest:update(settings.url)
     env_digest = digest:hex():sub(1, 4)
   end
-  return string.format('%s%s_%s', REDIS_PREFIX, env_digest,
+  return string.format('%s_%s', env_digest,
       sel_part:get_mimepart():get_digest():sub(1, 24))
 end
 
-local function maybe_save_cache(task, result, sel_part)
-  if not sel_part or not redis_params then
-    lua_util.debugm(N, task, 'cannot save cache: no part or no redis')
-    return -- cannot save intentionally
-  end
-
-  local cache_key = redis_cache_key(sel_part)
-  lua_util.debugm(N, task, 'saving cache for %s', cache_key)
-  local result_json = ucl.to_format(result, 'json-compact')
-  lua_redis.redis_make_request(task, redis_params, cache_key, false, function(err, _)
-    if err then
-      rspamd_logger.errx(task, 'cannot save cache: %s', err)
-    end
-  end,
-      'SETEX', { cache_key, tostring(settings.redis_cache_expire), result_json })
-end
-
 local function process_categories(task, categories)
   for _, category in ipairs(categories) do
     local sym = categories_map[category:lower()]
@@ -531,7 +516,10 @@ local function insert_results(task, result, sel_part)
       process_categories(task, result.categories)
     end
   end
-  maybe_save_cache(task, result, sel_part)
+
+  if cache_context then
+    lua_cache.cache_set(task, redis_cache_key(sel_part), result, cache_context)
+  end
 end
 
 local function check_consensus_and_insert_results(task, results, sel_part)
@@ -613,32 +601,21 @@ end
 local function check_llm_cached(task, content, sel_part)
   local cache_key = redis_cache_key(sel_part)
 
-  local ret = lua_redis.redis_make_request(task, redis_params, cache_key, false, function(err, data)
+  lua_cache.cache_get(task, cache_key, cache_context, settings.timeout * 1.5, function()
+    check_llm_uncached(task, content, sel_part)
+  end, function(err, data)
     if err then
-      rspamd_logger.errx(task, 'cannot check cache: %s', err)
+      rspamd_logger.errx(task, 'cannot get cache: %s', err)
       check_llm_uncached(task, content, sel_part)
     end
 
-    if type(data) == 'string' then
-      local parser = ucl.parser()
-      local res, parse_err = parser:parse_string(data)
-      if not res then
-        rspamd_logger.errx(task, 'Cannot parse cached response: %s', parse_err)
-        check_llm_uncached(task, content, sel_part)
-      else
-        rspamd_logger.infox(task, 'found cached response %s', cache_key)
-        insert_results(task, parser:get_object())
-      end
+    if data then
+      rspamd_logger.infox(task, 'found cached response %s', cache_key)
+      insert_results(task, data, sel_part)
     else
       check_llm_uncached(task, content, sel_part)
     end
-  end,
-      'GET', { cache_key })
-
-  if not ret then
-    rspamd_logger.errx(task, 'cannot query cache for request')
-    check_llm_uncached(task, content, sel_part)
-  end
+  end)
 end
 
 local function openai_check(task, content, sel_part)
@@ -896,10 +873,7 @@ if opts then
   settings = lua_util.override_defaults(settings, opts)
 
   if redis_params then
-    lua_redis.register_prefix(REDIS_PREFIX .. '*', N,
-        'Cache of LLM requests', {
-          type = 'string',
-        })
+    cache_context = lua_cache.create_cache_context(redis_params, settings, N)
   end
 
   if not settings.symbols_to_except then