]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Add clustering logic
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 26 Sep 2018 17:21:17 +0000 (18:21 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 26 Sep 2018 17:21:52 +0000 (18:21 +0100)
src/plugins/lua/clustering.lua

index 6a7fd2c8de00acaa2505767bdd7036ff14433017..d6c78ef79663afb323949b1dfd4f467f4c316966 100644 (file)
@@ -20,17 +20,15 @@ end
 
 -- Plugin for finding patterns in email flows
 
-local E = {}
 local N = 'clustering'
 
 local rspamd_logger = require "rspamd_logger"
 local lua_util = require "lua_util"
 local lua_redis = require "lua_redis"
-local fun = require "fun"
 local lua_selectors = require "lua_selectors"
 local ts = require("tableshape").types
 
-local redis_params = nil
+local redis_params
 
 local rules = {} -- Rules placement
 
@@ -128,11 +126,116 @@ local update_cluster_id
 -- Callbacks and logic
 
 local function clusterting_filter_cb(task, rule)
+  local source_selector = rule.source_selector(task)
+  local cluster_selector
 
+  if source_selector then
+    cluster_selector = rule.cluster_selector(task)
+  end
+
+  if not cluster_selector or not source_selector then
+    rspamd_logger.debugm(N, task, 'skip rule %s, selectors: source="%s", cluster="%s"',
+        rule.name, source_selector, cluster_selector)
+    return
+  end
+
+  local function combine_scores(cur_elts, total_score, element_score)
+    local final_score
+
+    local size_score = cur_elts * rule.size_mult
+    local cluster_score = total_score * rule.score_mult
+
+    if element_score > 0 then
+      -- We have seen this element mostly in junk/spam
+      final_score = math.min(1.0, size_score + cluster_score)
+    else
+      -- We have seen this element in ham mostly, so subtract average it from the size score
+      final_score = math.min(1.0, size_score - cluster_score / cur_elts)
+    end
+    rspamd_logger.debugm(N, task, 'processed rule %s, selectors: source="%s", cluster="%s"; data: %s elts, %s score, %s elt score',
+        rule.name, source_selector, cluster_selector, cur_elts, total_score, element_score)
+    if final_score > 0.1 then
+      task:insert_result(rule.symbol, final_score, {source_selector,
+                                                    tostring(size_score),
+                                                    tostring(cluster_score)})
+    end
+  end
+
+  local function redis_get_cb(err, data)
+    if data then
+      if type(data) == 'table' then
+        combine_scores(tonumber(data[1]), tonumber(data[2]), tonumber(data[3]))
+      else
+        rspamd_logger.errx(task, 'invalid type while getting clustering keys %s: %s',
+            source_selector, type(data))
+      end
+
+    elseif err then
+      rspamd_logger.errx(task, 'got error while getting clustering keys %s: %s',
+          source_selector, err)
+    else
+      rspamd_logger.errx(task, 'got error while getting clustering keys %s: %s',
+          source_selector, "unknown error")
+    end
+  end
+
+  lua_redis.exec_redis_script(query_cluster_id,
+      {task = task, is_write = false, key = source_selector},
+      redis_get_cb,
+      {source_selector, cluster_selector})
 end
 
 local function clusterting_idempotent_cb(task, rule)
+  local action = task:get_action()
+  local score
+
+  if action == 'no action' then
+    score = rule.ham_mult
+  elseif action == 'reject' then
+    score = rule.spam_mult
+  elseif action == 'add header' or action == 'rewrite subject' then
+    score = rule.junk_mult
+  else
+    rspamd_logger.debugm(N, task, 'skip rule %s, action=%s',
+        rule.name, action)
+    return
+  end
+
+  local source_selector = rule.source_selector(task)
+  local cluster_selector
+
+  if source_selector then
+    cluster_selector = rule.cluster_selector(task)
+  end
 
+  if not cluster_selector or not source_selector then
+    rspamd_logger.debugm(N, task, 'skip rule %s, selectors: source="%s", cluster="%s"',
+        rule.name, source_selector, cluster_selector)
+    return
+  end
+
+  local function redis_set_cb(err, data)
+    if err then
+      rspamd_logger.errx(task, 'got error while getting clustering keys %s: %s',
+          source_selector, err)
+    else
+      rspamd_logger.debugm(task, 'set clustering key for %s: %s{%s} = %s',
+          source_selector, "unknown error")
+    end
+  end
+
+  lua_redis.exec_redis_script(update_cluster_id,
+      {task = task, is_write = true, key = source_selector},
+      redis_set_cb,
+      {
+        source_selector,
+        cluster_selector,
+        tostring(score),
+        tostring(rule.max_elts),
+        tostring(rule.expire),
+        tostring(rule.expire_overflow)
+      }
+  )
 end
 -- Init part
 redis_params = lua_redis.parse_redis_server('clustering')
@@ -168,6 +271,7 @@ if opts['rules'] then
       rule.cluster_selector =  lua_selectors.create_selector_closure(rspamd_config,
           rule.cluster_selector, '')
       if rule.source_selector and rule.cluster_selector then
+        rule.name = k
         table.insert(rules, rule)
       end
     end