git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Plugin to integrate with Contextal platform
author Andrew Lewis <nerf@judo.za.org>
Thu, 27 Feb 2025 20:23:32 +0000 (22:23 +0200)
committer Andrew Lewis <nerf@judo.za.org>
Mon, 24 Mar 2025 11:54:51 +0000 (13:54 +0200)
lualib/lua_scanners/cloudmark.lua
lualib/lua_util.lua
src/plugins/lua/contextal.lua [new file with mode: 0644]

index 26a3bf9c46bc8ba880f33addd4c740a57edbb3d7..12a60abf1d5356fe2d4abd8660a51b35639ea2b3 100644 (file)
@@ -173,53 +173,6 @@ local function cloudmark_config(opts)
   return nil
 end
 
--- Converts a key-value map to the table representing multipart body, with the following values:
--- `data`: data of the part
--- `filename`: optional filename
--- `content-type`: content type of the element (optional)
--- `content-transfer-encoding`: optional CTE header
-local function table_to_multipart_body(tbl, boundary)
-  local seen_data = false
-  local out = {}
-
-  for k, v in pairs(tbl) do
-    if v.data then
-      seen_data = true
-      table.insert(out, string.format('--%s\r\n', boundary))
-      if v.filename then
-        table.insert(out,
-            string.format('Content-Disposition: form-data; name="%s"; filename="%s"\r\n',
-                k, v.filename))
-      else
-        table.insert(out,
-            string.format('Content-Disposition: form-data; name="%s"\r\n', k))
-      end
-      if v['content-type'] then
-        table.insert(out,
-            string.format('Content-Type: %s\r\n', v['content-type']))
-      else
-        table.insert(out, 'Content-Type: text/plain\r\n')
-      end
-      if v['content-transfer-encoding'] then
-        table.insert(out,
-            string.format('Content-Transfer-Encoding: %s\r\n',
-                v['content-transfer-encoding']))
-      else
-        table.insert(out, 'Content-Transfer-Encoding: binary\r\n')
-      end
-      table.insert(out, '\r\n')
-      table.insert(out, v.data)
-      table.insert(out, '\r\n')
-    end
-  end
-
-  if seen_data then
-    table.insert(out, string.format('--%s--\r\n', boundary))
-  end
-
-  return out
-end
-
 local function get_specific_symbol(scores_symbols, score)
   local selected
   local sel_thr = -1
@@ -359,7 +312,7 @@ local function cloudmark_check(task, content, digest, rule, maybe_part)
     local request_data = {
       task = task,
       url = url,
-      body = table_to_multipart_body(request, static_boundary),
+      body = lua_util.table_to_multipart_body(request, static_boundary),
       headers = {
         ['Content-Type'] = string.format('multipart/form-data; boundary="%s"', static_boundary)
       },
index 62b38c87e37734dad4aaabb94483751cd5245052..636212b1fe0d827394578c16aafb8aa21db7910c 100644 (file)
@@ -1805,4 +1805,55 @@ exports.symbols_priorities = {
   low = 0,
 }
 
+---[[[
+-- @function lua_util.table_to_multipart_body(tbl, boundary)
+-- Serialises a map of form fields into the pieces of a multipart/form-data body.
+-- Each value of `tbl` is a table that may carry:
+-- `data`: payload of the part (entries without it are skipped)
+-- `filename`: optional filename added to Content-Disposition
+-- `content-type`: optional content type (`text/plain` when absent)
+-- `content-transfer-encoding`: optional CTE header (`binary` when absent)
+-- Returns an array of body fragments; it is empty when no entry had `data`.
+-- Parts are emitted in `pairs` order, which is unspecified.
+local function table_to_multipart_body(tbl, boundary)
+  local chunks = {}
+
+  local function push(chunk)
+    chunks[#chunks + 1] = chunk
+  end
+
+  for name, part in pairs(tbl) do
+    local payload = part.data
+    if payload then
+      local disposition
+      if part.filename then
+        disposition = string.format(
+            'Content-Disposition: form-data; name="%s"; filename="%s"\r\n',
+            name, part.filename)
+      else
+        disposition = string.format('Content-Disposition: form-data; name="%s"\r\n', name)
+      end
+
+      local ctype = part['content-type']
+      local cte = part['content-transfer-encoding']
+
+      push(string.format('--%s\r\n', boundary))
+      push(disposition)
+      push(ctype and string.format('Content-Type: %s\r\n', ctype)
+          or 'Content-Type: text/plain\r\n')
+      push(cte and string.format('Content-Transfer-Encoding: %s\r\n', cte)
+          or 'Content-Transfer-Encoding: binary\r\n')
+      push('\r\n')
+      -- payload is appended as-is, never run through string.format
+      push(payload)
+      push('\r\n')
+    end
+  end
+
+  -- The closing delimiter is only written when at least one part was emitted
+  if #chunks > 0 then
+    push(string.format('--%s--\r\n', boundary))
+  end
+
+  return chunks
+end
+
+-- Publish the helper on the module table as lua_util.table_to_multipart_body
+exports.table_to_multipart_body = table_to_multipart_body
+
 return exports
diff --git a/src/plugins/lua/contextal.lua b/src/plugins/lua/contextal.lua
new file mode 100644 (file)
index 0000000..341b7a1
--- /dev/null
@@ -0,0 +1,302 @@
+--[[
+Copyright (c) 2025, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+-- Shared empty table: lets `(x or E).field` be evaluated safely when `x` is nil
+local E = {}
+-- Module name; passed to the config lookup, redis setup and symbol registration
+local N = 'contextal'
+
+-- Skip the rest of the plugin when the `confighelp` global is set
+if confighelp then
+  return
+end
+
+-- Stop here when the configuration has no options for this module
+local opts = rspamd_config:get_all_opt(N)
+if not opts then
+  return
+end
+
+local lua_redis = require "lua_redis"
+local lua_util = require "lua_util"
+local rspamd_http = require "rspamd_http"
+local rspamd_logger = require "rspamd_logger"
+local rspamd_util = require "rspamd_util"
+local ucl = require "ucl"
+
+-- Actions recognised out of the box. Further down the file this list is
+-- converted to a lookup set and extended with `settings.custom_actions`.
+local contextal_actions = {
+  'ALERT',
+  'ALLOW',
+  'BLOCK',
+  'QUARANTINE',
+  'SPAM',
+}
+
+-- Defaults; values from the module's configuration section override them
+local settings = {
+  action_symbol_prefix = 'CONTEXTAL_ACTION',
+  base_url = 'http://localhost:8080',
+  cache_ttl = 3600,
+  custom_actions = {},
+  http_timeout = 2,
+  key_prefix = 'CXAL',
+  request_ttl = 4,
+  submission_symbol = 'CONTEXTAL_SUBMIT',
+}
+
+-- Redis connection parameters, assigned once the configuration is parsed.
+-- Forward-declared so that the functions below and the later assignment all
+-- refer to one file-local variable rather than creating a global.
+local redis_params
+
+-- Multipart boundary, generated once at load time and reused for every request
+local static_boundary = rspamd_util.random_hex(32)
+
+-- Redis key for a task: the configured key prefix and the task digest,
+-- joined with an underscore.
+local function cache_key(task)
+  local prefix, digest = settings.key_prefix, task:get_digest()
+  return string.format('%s_%s', prefix, digest)
+end
+
+-- Inserts one symbol per action found in a decoded Contextal result and,
+-- when redis is configured and the result is fresh, writes it to the cache.
+-- @param task rspamd task
+-- @param obj decoded result; `obj[1]` is expected to hold `actions` or `work_id`
+-- @param is_cached true when `obj` was read from the cache (no write-back)
+local function process_actions(task, obj, is_cached)
+  local first = obj[1] or E
+
+  -- `or E` keeps ipairs from raising when the result carries no actions list
+  for _, match in ipairs(first.actions or E) do
+    local act = match.action
+    local scenario = match.scenario
+    if not (act and scenario) then
+      rspamd_logger.err(task, 'bad result: %s', match)
+    elseif contextal_actions[act] then
+      task:insert_result(settings.action_symbol_prefix .. '_' .. act, 1.0, scenario)
+    else
+      rspamd_logger.err(task, 'unknown action: %s', act)
+    end
+  end
+
+  if not redis_params or is_cached then return end
+
+  local cache_obj
+  if first.actions then
+    cache_obj = {[1] = {["actions"] = first.actions}}
+  elseif first.work_id then
+    cache_obj = {[1] = {["work_id"] = first.work_id}}
+  else
+    rspamd_logger.err(task, 'bad result: %s', obj)
+    return
+  end
+
+  -- The write needs its own callback: the GET callback is local to submit_cb
+  -- and is not visible from this function.
+  local function redis_set_cb(err)
+    if err then
+      rspamd_logger.err(task, 'cannot cache result in redis: %s', err)
+    end
+  end
+
+  local key = cache_key(task)
+  local payload = ucl.to_format(cache_obj, 'json-compact')
+
+  -- Honour settings.cache_ttl so entries expire; fall back to a plain SET
+  -- when the TTL is not a positive number.
+  local ttl = tonumber(settings.cache_ttl)
+  local command, args
+  if ttl and ttl >= 1 then
+    command, args = 'SETEX', { key, tostring(math.floor(ttl)), payload }
+  else
+    command, args = 'SET', { key, payload }
+  end
+
+  local ret = lua_redis.redis_make_request(task,
+      redis_params, -- connect params
+      key, -- hash key
+      true, -- is write
+      redis_set_cb, --callback
+      command, -- command
+      args -- arguments
+  )
+
+  if not ret then
+    rspamd_logger.err(task, 'cannot make redis request to cache result')
+  end
+end
+
+-- Handles a cache hit. `txt` is the JSON stored by process_actions, shaped as
+-- `[{"actions": ...}]` or `[{"work_id": ...}]`.
+-- With actions: disables the action symbol and inserts the cached actions.
+-- With a work id: stores it in the task mempool as `contextal_work_id`.
+-- @param task rspamd task
+-- @param txt cached JSON text
+local function process_cached(task, txt)
+  local parser = ucl.parser()
+  local _, err = parser:parse_string(txt)
+  if err then
+    rspamd_logger.err(task, 'cannot parse JSON (cached): %s', err)
+    return
+  end
+  local obj = parser:get_object()
+
+  -- Only index the first element when the decoded value has the expected shape
+  local first = E
+  if type(obj) == 'table' and type(obj[1]) == 'table' then
+    first = obj[1]
+  end
+
+  if first.actions then
+    task:disable_symbol(settings.action_symbol_prefix)
+    return process_actions(task, obj, true)
+  elseif first.work_id then
+    -- The id lives in the first array element, matching what is written to the cache
+    task:get_mempool():set_variable('contextal_work_id', first.work_id)
+  else
+    rspamd_logger.err(task, 'bad result (cached): %s', obj)
+  end
+end
+
+-- Posts the raw message to `settings.submit_url` as multipart/form-data.
+-- On a 201 reply the returned `work_id` (if any) is stored in the task
+-- mempool as `contextal_work_id`, and the submission symbol is inserted with
+-- a `work_id=<id>` option (`work_id=nil` when the reply had none).
+-- @param task rspamd task
+local function submit(task)
+
+  local function http_callback(err, code, body, hdrs)
+    if err then
+      rspamd_logger.err(task, 'http error: %s', err)
+      return
+    end
+    if code ~= 201 then
+      rspamd_logger.err(task, 'bad http code: %s', code)
+      return
+    end
+    local parser = ucl.parser()
+    -- Distinct name so the callback's `err` argument is not shadowed
+    local _, parse_err = parser:parse_string(body)
+    if parse_err then
+      rspamd_logger.err(task, 'cannot parse JSON: %s', parse_err)
+      return
+    end
+    local obj = parser:get_object()
+    -- Valid JSON need not decode to a table; only index it when it does
+    local work_id = type(obj) == 'table' and obj.work_id or nil
+    if work_id then
+      task:get_mempool():set_variable('contextal_work_id', work_id)
+    end
+    task:insert_result(settings.submission_symbol, 1.0,
+        string.format('work_id=%s', work_id or 'nil'))
+  end
+
+  -- Form fields: the message itself plus optional ttl / maxrec parameters
+  local req = {
+    object_data = {['data'] = task:get_content()},
+  }
+  if settings.request_ttl then
+    req.ttl = {['data'] = tostring(settings.request_ttl)}
+  end
+  if settings.max_recursion then
+    req.maxrec = {['data'] = tostring(settings.max_recursion)}
+  end
+  rspamd_http.request({
+      task = task,
+      url = settings.submit_url,
+      body = lua_util.table_to_multipart_body(req, static_boundary),
+      callback = http_callback,
+      headers = {
+        ['Content-Type'] = string.format('multipart/form-data; boundary="%s"', static_boundary)
+      },
+      timeout = settings.http_timeout,
+      gzip = settings.gzip,
+      keepalive = settings.keepalive,
+      no_ssl_verify = settings.no_ssl_verify,
+  })
+end
+
+-- Callback of the submission symbol. Without redis the message is submitted
+-- right away; with redis the cache is queried first and the message is only
+-- submitted when the lookup returns a userdata value.
+-- @param task rspamd task
+local function submit_cb(task)
+  if not redis_params then
+    return submit(task)
+  end
+
+  local function on_lookup(err, data)
+    if err then
+      rspamd_logger.err(task, 'error querying redis: %s', err)
+      return
+    elseif type(data) == 'userdata' then
+      -- NOTE(review): userdata here is presumably the redis nil reply
+      -- (nothing cached for this key) - confirm against lua_redis
+      return submit(task)
+    end
+    process_cached(task, data)
+  end
+
+  local key = cache_key(task)
+  local scheduled = lua_redis.redis_make_request(
+      task,
+      redis_params, -- connect params
+      key, -- hash key
+      false, -- is write
+      on_lookup, -- callback
+      'GET', -- command
+      { key } -- arguments
+  )
+
+  if not scheduled then
+    rspamd_logger.err(task, 'cannot make redis request to check results')
+  end
+end
+
+-- Callback of the action symbol. Reads `contextal_work_id` from the task
+-- mempool, requests `settings.actions_url .. work_id` and, on a 200 reply
+-- whose first element carries `actions`, hands the result to process_actions.
+-- @param task rspamd task
+local function action_cb(task)
+  local work_id = task:get_mempool():get_variable('contextal_work_id', 'string')
+  if not work_id then
+    rspamd_logger.err(task, 'no work id found in mempool')
+    return
+  end
+
+  local function http_callback(err, code, body, hdrs)
+    if err then
+      rspamd_logger.err(task, 'http error: %s', err)
+      return
+    end
+    if code ~= 200 then
+      rspamd_logger.err(task, 'bad http code: %s', code)
+      return
+    end
+    local parser = ucl.parser()
+    -- Distinct name so the callback's `err` argument is not shadowed
+    local _, parse_err = parser:parse_string(body)
+    if parse_err then
+      rspamd_logger.err(task, 'cannot parse JSON: %s', parse_err)
+      return
+    end
+    local obj = parser:get_object()
+    -- Valid JSON need not decode to an array of tables; check before indexing
+    if type(obj) == 'table' and type(obj[1]) == 'table' and obj[1].actions then
+      return process_actions(task, obj, false)
+    end
+  end
+
+  rspamd_http.request({
+      task = task,
+      url = settings.actions_url .. work_id,
+      callback = http_callback,
+      timeout = settings.http_timeout,
+      gzip = settings.gzip,
+      keepalive = settings.keepalive,
+      no_ssl_verify = settings.no_ssl_verify,
+  })
+end
+
+-- Joins `base` and `path` with a single slash, adding one only when `base`
+-- does not already end in '/'.
+local function set_url_path(base, path)
+  if base:sub(-1) == '/' then
+    return base .. path
+  end
+  return base .. '/' .. path
+end
+
+-- `opts` was fetched and checked at the top of the file, so it is reused here
+-- rather than fetched a second time. User options override the defaults.
+settings = lua_util.override_defaults(settings, opts)
+
+-- Turn the list of known actions into a lookup set and add custom actions;
+-- `or E` tolerates a configuration that unsets custom_actions
+contextal_actions = lua_util.list_to_hash(contextal_actions)
+for _, k in ipairs(settings.custom_actions or E) do
+  contextal_actions[k] = true
+end
+
+-- Endpoints: explicit submit_url / actions_url win, otherwise they are
+-- derived from base_url. With neither available the module is disabled.
+if not settings.base_url then
+  if not (settings.submit_url and settings.actions_url) then
+    rspamd_logger.err(rspamd_config, 'no URL configured for contextal')
+    lua_util.disable_module(N, 'config')
+    return
+  end
+else
+  if not settings.submit_url then
+    settings.submit_url = set_url_path(settings.base_url, 'api/v1/submit')
+  end
+  if not settings.actions_url then
+    settings.actions_url = set_url_path(settings.base_url, 'api/v1/actions/')
+  end
+end
+
+-- Not declared `local` at this point: process_actions and submit_cb, defined
+-- above, read this name, so a new local here would be invisible to them.
+redis_params = lua_redis.parse_redis_server(N)
+if redis_params then
+  lua_redis.register_prefix(settings.key_prefix .. '_*', N,
+      'Cache for contextal plugin')
+end
+
+-- Prefilter symbol: runs submit_cb
+rspamd_config:register_symbol({
+  callback = submit_cb,
+  group = N,
+  name = settings.submission_symbol,
+  priority = lua_util.symbols_priorities.top,
+  type = 'prefilter',
+})
+
+-- Postfilter symbol: runs action_cb; its id is the parent of the virtual
+-- per-action symbols registered below
+local action_symbol_id = rspamd_config:register_symbol({
+  callback = action_cb,
+  group = N,
+  name = settings.action_symbol_prefix,
+  priority = lua_util.symbols_priorities.high - 1,
+  type = 'postfilter',
+})
+
+-- One virtual symbol per known action, named `<action_symbol_prefix>_<ACTION>`
+for action_name in pairs(contextal_actions) do
+  local symbol_name = settings.action_symbol_prefix .. '_' .. action_name
+  rspamd_config:register_symbol({
+    group = N,
+    name = symbol_name,
+    parent = action_symbol_id,
+    type = 'virtual',
+  })
+end