]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
More changes to ipscore module.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 21 Jul 2015 09:07:01 +0000 (10:07 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 21 Jul 2015 09:07:01 +0000 (10:07 +0100)
src/plugins/lua/ip_score.lua

index b978963b8437c3ca39f55c5d79a07d1d4ac4e858..b1d116fc4781cd8f9cb541969a97017d3e3b3e94 100644 (file)
@@ -29,6 +29,7 @@ local rspamd_logger = require "rspamd_logger"
 local rspamd_redis = require "rspamd_redis"
 local upstream_list = require "rspamd_upstream_list"
 local rspamd_regexp = require "rspamd_regexp"
+local rspamd_util = require "rspamd_util"
 local _ = require "fun"
 
 -- Default settings
@@ -37,19 +38,27 @@ local upstreams = nil
 local whitelist = nil
 
 local options = {
-  asn_provider = 'origin.asn.cymru.com',
-  scores = {
+  asn_provider = 'origin.asn.cymru.com', -- provider for ASN data
+  actions = { -- how each action is treated in scoring
     ['reject'] = 1.0,
     ['add header'] = 0.25,
     ['rewrite subject'] = 0.25,
     ['no action'] = -1.0
   },
-  symbol = 'IP_SCORE',
-  hash = 'ip_score',
-  asn_prefix = 'a:',
-  country_prefix = 'c:',
-  ipnet_prefix = 'n:',
-  servers = '',
+  scores = { -- how each component is evaluated
+    ['asn'] = 0.5,
+    ['country'] = 0.1,
+    ['ipnet'] = 0.8,
+    ['ip'] = 1.0
+  },
+  symbol = 'IP_SCORE', -- symbol to be inserted
+  hash = 'ip_score', -- hash table in redis used for storing scores
+  asn_prefix = 'a:', -- prefix for ASN hashes
+  country_prefix = 'c:', -- prefix for country hashes
+  ipnet_prefix = 'n:', -- prefix for ipnet hashes
+  servers = '', -- list of servers
+  lower_bound = 10, -- minimum number of messages to be scored
+  metric = 'default'
 }
 
 local asn_re = rspamd_regexp.create_cached("[\\|\\s]")
@@ -106,7 +115,7 @@ local ip_score_set = function(task)
     end
   end
 
-  local action = task:get_metric_action(metric)
+  local action = task:get_metric_action(options['metric'])
   local ip = task:get_from_ip()
   if not ip or not ip:is_valid() then
     return
@@ -119,8 +128,15 @@ local ip_score_set = function(task)
       return
     end
   end
-
+  
+  local pool = task:get_mempool()
   local asn, country, ipnet = ip_score_get_task_vars(task)
+  local asn_score,total_asn,
+        country_score,total_country,
+        ipnet_score,total_ipnet,
+        ip_score, total_ip = pool:get_variable('ip_score', 
+        'double,double,double,double,double,double,double,double')
+  
 
   rspamd_logger.infox('%1', action)
   local score = 0
@@ -164,39 +180,77 @@ local ip_score_check = function(task)
 
   local ip_score_redis_cb = function(task, err, data)
     local function calculate_score(score)
-      -- Normalize
-      local nscore
-      if score > 0 and score > normalize_score then
-        nscore = 1
-      elseif score < 0 and score < -normalize_score then
-        nscore = -1
-      else
-        nscore = score / normalize_score
+      local parts = asn_re:split(score)
+      local rep = tonumber(parts[1])
+      local total = tonumber(parts[2])
+      
+      return rep, total
+    end
+    
+    local function normalize_score(sc, total, mult)
+      if total < options['lower_bound'] then
+        return 0
       end
       
-      return nscore
+      -- -mult to mult
+      return mult * rspamd_util.tanh(2.718 * sc / total)
     end
     
     if err then
       -- Key is not found or error occurred
       return
     elseif data then
+      -- Scores and total number of messages per bucket
+      local asn_score,total_asn,
+        country_score,total_country,
+        ipnet_score,total_ipnet,
+        ip_score, total_ip = 0, -1, 0, -1, 0, -1, 0, -1
       if data[1] and type(data[1]) ~= 'userdata' then
-        local asn_score = calculate_score(tonumber(data[1]))
-        task:insert_result(symbol, asn_score, 'asn: ' .. asn)
+        asn_score,total_asn = calculate_score(data[1])
       end
       if data[2] and type(data[2]) ~= 'userdata' then
-        local country_score = calculate_score(tonumber(data[2]))
-        task:insert_result(symbol, country_score, 'country: ' .. country)
+        country_score,total_country = calculate_score(data[2])
       end
       if data[3] and type(data[3]) ~= 'userdata' then
-        local ipnet_score = calculate_score(tonumber(data[3]))
-        task:insert_result(symbol, ipnet_score, 'ipnet: ' .. ipnet)
+        ipnet_score,total_ipnet = calculate_score(data[3])
       end
       if data[4] and type(data[4]) ~= 'userdata' then
-        local ip_score = calculate_score(tonumber(data[4]))
-        task:insert_result(symbol, ip_score, 'ip')
+        ip_score,total_ip = calculate_score(data[4])
       end
+      -- Save everything for the post filter
+      task:get_mempool():set_variable('ip_score', asn_score,total_asn,
+        country_score,total_country,
+        ipnet_score,total_ipnet,
+        ip_score, total_ip)
+      
+      asn_score = normalize_score(asn_score, total_asn, options['scores']['asn'])
+      country_score = normalize_score(country_score, total_country,
+        options['scores']['country'])
+      ipnet_score = normalize_score(ipnet_score, total_ipnet,
+        options['scores']['ipnet'])
+      ip_score = normalize_score(ip_score, total_ip, options['scores']['ip'])
+      
+      local total_score = 0.0
+      local description = ''
+      if ip_score ~= 0 then
+        total_score = total_score + ip_score
+        description = description .. 'ip,'
+      end
+      if asn_score ~= 0 then
+        total_score = total_score + asn_score
+        description = description .. 'asn:' .. asn .. ','
+      end
+      if country_score ~= 0 then
+        total_score = total_score + country_score
+        description = description .. 'country:' .. country .. ','
+      end
+      if ipnet_score ~= 0 then
+        total_score = total_score + ipnet_score
+        description = description .. 'ipnet:' .. ipnet .. ','
+      end
+      
+      if total_score ~= 0 then
+        task:insert_result(options['symbol'], total_score, description)
     end
   end