]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Minor changes to contextal plugin 5360/head
authorAndrew Lewis <nerf@judo.za.org>
Mon, 24 Mar 2025 11:53:15 +0000 (13:53 +0200)
committerAndrew Lewis <nerf@judo.za.org>
Mon, 24 Mar 2025 11:54:51 +0000 (13:54 +0200)
 - Abandon prefilter
 - Abandon postfilter in favour of wait if we have request_ttl
 - Reformat table

src/plugins/lua/contextal.lua

index 5cebff238fcab78fbc7e29c198bd3feb1d3d354c..f6202781a822b63438b58d2ad838c8d96890d974 100644 (file)
@@ -38,11 +38,11 @@ local ucl = require "ucl"
 local cache_context, redis_params
 
 local contextal_actions = {
-  'ALERT',
-  'ALLOW',
-  'BLOCK',
-  'QUARANTINE',
-  'SPAM',
+  ['ALERT'] = true,
+  ['ALLOW'] = true,
+  ['BLOCK'] = true,
+  ['QUARANTINE'] = true,
+  ['SPAM'] = true,
 }
 
 local config_schema = lua_redis.enrich_schema {
@@ -75,6 +75,7 @@ local settings = {
 }
 
 local static_boundary = rspamd_util.random_hex(32)
+local use_request_ttl = true
 
 local function maybe_defer(task, obj)
   if settings.defer_if_no_result and not ((obj or E)[1] or E).actions then
@@ -132,6 +133,46 @@ local function process_cached(task, obj)
   end
 end
 
+local function action_cb(task)
+  local work_id = task:get_mempool():get_variable('contextal_work_id', 'string')
+  if not work_id then
+    rspamd_logger.err(task, 'no work id found in mempool')
+    return
+  end
+
+  local function http_callback(err, code, body, hdrs)
+    if err then
+      rspamd_logger.err(task, 'http error: %s', err)
+      maybe_defer(task)
+      return
+    end
+    if code ~= 200 then
+      rspamd_logger.err(task, 'bad http code: %s', code)
+      maybe_defer(task)
+      return
+    end
+    local parser = ucl.parser()
+    local _, parse_err = parser:parse_string(body)
+    if parse_err then
+      rspamd_logger.err(task, 'cannot parse JSON: %s', err)
+      maybe_defer(task)
+      return
+    end
+    local obj = parser:get_object()
+    return process_actions(task, obj, false)
+  end
+
+  rspamd_http.request({
+      task = task,
+      url = settings.actions_url .. work_id,
+      callback = http_callback,
+      timeout = settings.http_timeout,
+      gzip = settings.gzip,
+      keepalive = settings.keepalive,
+      no_ssl_verify = settings.no_ssl_verify,
+  })
+end
+
 local function submit(task)
 
   local function http_callback(err, code, body, hdrs)
@@ -159,6 +200,7 @@ local function submit(task)
     end
     task:insert_result(settings.submission_symbol, 1.0,
         string.format('work_id=%s', work_id or 'nil'))
+    task:add_timer(settings.request_ttl, action_cb)
   end
 
   local req = {
@@ -207,46 +249,6 @@ local function submit_cb(task)
   end
 end
 
-local function action_cb(task)
-  local work_id = task:get_mempool():get_variable('contextal_work_id', 'string')
-  if not work_id then
-    rspamd_logger.err(task, 'no work id found in mempool')
-    return
-  end
-
-  local function http_callback(err, code, body, hdrs)
-    if err then
-      rspamd_logger.err(task, 'http error: %s', err)
-      maybe_defer(task)
-      return
-    end
-    if code ~= 200 then
-      rspamd_logger.err(task, 'bad http code: %s', code)
-      maybe_defer(task)
-      return
-    end
-    local parser = ucl.parser()
-    local _, parse_err = parser:parse_string(body)
-    if parse_err then
-      rspamd_logger.err(task, 'cannot parse JSON: %s', err)
-      maybe_defer(task)
-      return
-    end
-    local obj = parser:get_object()
-    return process_actions(task, obj, false)
-  end
-
-  rspamd_http.request({
-      task = task,
-      url = settings.actions_url .. work_id,
-      callback = http_callback,
-      timeout = settings.http_timeout,
-      gzip = settings.gzip,
-      keepalive = settings.keepalive,
-      no_ssl_verify = settings.no_ssl_verify,
-  })
-end
-
 local function set_url_path(base, path)
   local slash = base:sub(#base) == '/' and '' or '/'
   return base .. slash .. path
@@ -263,7 +265,6 @@ if not res then
   return
 end
 
-contextal_actions = lua_util.list_to_hash(contextal_actions)
 for _, k in ipairs(settings.custom_actions) do
   contextal_actions[k] = true
 end
@@ -293,26 +294,38 @@ if redis_params then
   })
 end
 
-rspamd_config:register_symbol({
+local submission_id = rspamd_config:register_symbol({
   name = settings.submission_symbol,
-  priority = lua_util.symbols_priorities.top,
-  type = 'prefilter',
+  type = 'normal',
   group = N,
   callback = submit_cb
 })
 
-local id = rspamd_config:register_symbol({
-  name = settings.action_symbol_prefix,
-  type = 'postfilter',
-  priority = lua_util.symbols_priorities.high - 1,
-  group = N,
-  callback = action_cb
-})
+local top_options = rspamd_config:get_all_opt('options')
+if settings.request_ttl and settings.request_ttl >= (top_options.task_timeout * 0.8) then
+  rspamd_logger.warn(rspamd_config, [[request ttl is >= 80% of task timeout, won't wait on processing]])
+  use_request_ttl = false
+elseif not settings.request_ttl then
+  use_request_ttl = false
+end
+
+local parent_id
+if use_request_ttl then
+  parent_id = submission_id
+else
+  parent_id = rspamd_config:register_symbol({
+    name = settings.action_symbol_prefix,
+    type = 'postfilter',
+    priority = lua_util.symbols_priorities.high - 1,
+    group = N,
+    callback = action_cb
+  })
+end
 
 for k in pairs(contextal_actions) do
   rspamd_config:register_symbol({
     name = settings.action_symbol_prefix .. '_' .. k,
-    parent = id,
+    parent = parent_id,
     type = 'virtual',
     group = N,
   })