]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Apply changes to bayes_expiry plugin
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 28 Jul 2025 18:22:56 +0000 (19:22 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 28 Jul 2025 18:22:56 +0000 (19:22 +0100)
src/plugins/lua/bayes_expiry.lua

index 44ff9dafaaeb2ae51d2477c9799773b3fd62ef8e..0d78f227250ce6cb309a66f3f7fce543522a10a3 100644 (file)
@@ -41,32 +41,38 @@ local template = {}
 local function check_redis_classifier(cls, cfg)
   -- Skip old classifiers
   if cls.new_schema then
-    local symbol_spam, symbol_ham
+    local class_symbols = {}
+    local class_labels = {}
     local expiry = (cls.expiry or cls.expire)
     if type(expiry) == 'table' then
       expiry = expiry[1]
     end
 
-    -- Load symbols from statfiles
+    -- Extract class_labels mapping from classifier config
+    if cls.class_labels then
+      class_labels = cls.class_labels
+    end
 
+    -- Load symbols from statfiles for multi-class support
     local function check_statfile_table(tbl, def_sym)
       local symbol = tbl.symbol or def_sym
-
-      local spam
-      if tbl.spam then
-        spam = tbl.spam
-      else
-        if string.match(symbol:upper(), 'SPAM') then
-          spam = true
+      local class_name = tbl.class
+
+      -- Handle legacy spam/ham detection for backward compatibility
+      if not class_name then
+        if tbl.spam ~= nil then
+          class_name = tbl.spam and 'spam' or 'ham'
+        elseif string.match(tostring(symbol):upper(), 'SPAM') then
+          class_name = 'spam'
+        elseif string.match(tostring(symbol):upper(), 'HAM') then
+          class_name = 'ham'
         else
-          spam = false
+          class_name = def_sym
         end
       end
 
-      if spam then
-        symbol_spam = symbol
-      else
-        symbol_ham = symbol
+      if class_name then
+        class_symbols[class_name] = symbol
       end
     end
 
@@ -87,10 +93,9 @@ local function check_redis_classifier(cls, cfg)
       end
     end
 
-    if not symbol_spam or not symbol_ham or type(expiry) ~= 'number' then
+    if next(class_symbols) == nil or type(expiry) ~= 'number' then
       logger.debugm(N, rspamd_config,
-          'disable expiry for classifier %s: no expiry %s',
-          symbol_spam, cls)
+          'disable expiry for classifier: no class symbols or expiry configured')
       return
     end
     -- Now try to load redis_params if needed
@@ -108,17 +113,16 @@ local function check_redis_classifier(cls, cfg)
     end
 
     if redis_params['read_only'] then
-      logger.infox(rspamd_config, 'disable expiry for classifier %s: read only redis configuration',
-          symbol_spam)
+      logger.infox(rspamd_config, 'disable expiry for classifier: read only redis configuration')
       return
     end
 
-    logger.debugm(N, rspamd_config, "enabled expiry for %s/%s -> %s expiry",
-        symbol_spam, symbol_ham, expiry)
+    logger.debugm(N, rspamd_config, "enabled expiry for classes %s -> %s expiry",
+        table.concat(lutil.keys(class_symbols), ', '), expiry)
 
     table.insert(settings.classifiers, {
-      symbol_spam = symbol_spam,
-      symbol_ham = symbol_ham,
+      class_symbols = class_symbols,
+      class_labels = class_labels,
       redis_params = redis_params,
       expiry = expiry
     })
@@ -249,12 +253,11 @@ local expiry_script = [[
   local keys = ret[2]
   local tokens = {}
 
-  -- Tokens occurrences distribution counters
+  -- Dynamic occurrence tracking for all classes
   local occur = {
-    ham = {},
-    spam = {},
     total = {}
   }
+  local classes_found = {}
 
   -- Expiry step statistics counters
   local nelts, extended, discriminated, sum, sum_squares, common, significant,
@@ -264,24 +267,44 @@ local expiry_script = [[
   for _,key in ipairs(keys) do
     local t = redis.call('TYPE', key)["ok"]
     if t == 'hash' then
-      local values = redis.call('HMGET', key, 'H', 'S')
-      local ham = tonumber(values[1]) or 0
-      local spam = tonumber(values[2]) or 0
+      -- Get all hash fields to support multi-class
+      local hash_data = redis.call('HGETALL', key)
+      local class_counts = {}
+      local total = 0
       local ttl = redis.call('TTL', key)
+
+      -- Parse hash data into class counts
+      for i = 1, #hash_data, 2 do
+        local class_label = hash_data[i]
+        local count = tonumber(hash_data[i + 1]) or 0
+        class_counts[class_label] = count
+        total = total + count
+
+        -- Track classes we've seen
+        if not classes_found[class_label] then
+          classes_found[class_label] = true
+          occur[class_label] = {}
+        end
+      end
+
       tokens[key] = {
-        ham,
-        spam,
-        ttl
+        class_counts = class_counts,
+        total = total,
+        ttl = ttl
       }
-      local total = spam + ham
+
       sum = sum + total
       sum_squares = sum_squares + total * total
       nelts = nelts + 1
 
-      for k,v in pairs({['ham']=ham, ['spam']=spam, ['total']=total}) do
-        if tonumber(v) > 19 then v = 20 end
-        occur[k][v] = occur[k][v] and occur[k][v] + 1 or 1
+      -- Update occurrence counters for all classes and total
+      for class_label, count in pairs(class_counts) do
+        local bucket = count > 19 and 20 or count
+        occur[class_label][bucket] = (occur[class_label][bucket] or 0) + 1
       end
+
+      local total_bucket = total > 19 and 20 or total
+      occur.total[total_bucket] = (occur.total[total_bucket] or 0) + 1
     end
   end
 
@@ -293,9 +316,10 @@ local expiry_script = [[
   end
 
   for key,token in pairs(tokens) do
-    local ham, spam, ttl = token[1], token[2], tonumber(token[3])
+    local class_counts = token.class_counts
+    local total = token.total
+    local ttl = tonumber(token.ttl)
     local threshold = mean
-    local total = spam + ham
 
     local function set_ttl()
       if expire < 0 then
@@ -310,14 +334,39 @@ local expiry_script = [[
       return 0
     end
 
-    if total == 0 or math.abs(ham - spam) <= total * ${epsilon_common} then
+    -- Check if token is common (balanced across classes)
+    local is_common = false
+    if total == 0 then
+      is_common = true
+    else
+      -- For multi-class, check if any class dominates significantly
+      local max_count = 0
+      for _, count in pairs(class_counts) do
+        if count > max_count then
+          max_count = count
+        end
+      end
+      -- Token is common if no class has more than (1 - epsilon) of total
+      is_common = (max_count / total) <= (1 - ${epsilon_common})
+    end
+
+    if is_common then
       common = common + 1
       if ttl > ${common_ttl} then
         discriminated = discriminated + 1
         redis.call('EXPIRE', key, ${common_ttl})
       end
     elseif total >= threshold and total > 0 then
-      if ham / total > ${significant_factor} or spam / total > ${significant_factor} then
+      -- Check if any class is significant
+      local is_significant = false
+      for _, count in pairs(class_counts) do
+        if count / total > ${significant_factor} then
+          is_significant = true
+          break
+        end
+      end
+
+      if is_significant then
         significant = significant + 1
         if ttl ~= -1 then
           redis.call('PERSIST', key)
@@ -361,33 +410,50 @@ local expiry_script = [[
   redis.call('DEL', lock_key)
 
   local occ_distr = {}
-  for _,cl in pairs({'ham', 'spam', 'total'}) do
+
+  -- Process all classes found plus total
+  local all_classes = {'total'}
+  for class_label in pairs(classes_found) do
+    table.insert(all_classes, class_label)
+  end
+
+  for _, cl in ipairs(all_classes) do
     local occur_key = pattern_sha1 .. '_occurrence_' .. cl
 
     if cursor ~= 0 then
-      local n
-      for i,v in ipairs(redis.call('HGETALL', occur_key)) do
-        if i % 2 == 1 then
-          n = tonumber(v)
-        else
-          occur[cl][n] = occur[cl][n] and occur[cl][n] + v or v
+      local existing_data = redis.call('HGETALL', occur_key)
+      if #existing_data > 0 then
+        for i = 1, #existing_data, 2 do
+          local bucket = tonumber(existing_data[i])
+          local count = tonumber(existing_data[i + 1])
+          if occur[cl] and occur[cl][bucket] then
+            occur[cl][bucket] = occur[cl][bucket] + count
+          elseif occur[cl] then
+            occur[cl][bucket] = count
+          end
         end
       end
 
-      local str = ''
-      if occur[cl][0] ~= nil then
-        str = '0:' .. occur[cl][0] .. ','
-      end
-      for k,v in ipairs(occur[cl]) do
-        if k == 20 then k = '>19' end
-        str = str .. k .. ':' .. v .. ','
+      if occur[cl] and next(occur[cl]) then
+        local str = ''
+        if occur[cl][0] then
+          str = '0:' .. occur[cl][0] .. ','
+        end
+        for k = 1, 20 do
+          if occur[cl][k] then
+            local label = k == 20 and '>19' or tostring(k)
+            str = str .. label .. ':' .. occur[cl][k] .. ','
+          end
+        end
+        table.insert(occ_distr, cl .. '=' .. str)
+      else
+        table.insert(occ_distr, cl .. '=no_data')
       end
-      table.insert(occ_distr, str)
     else
       redis.call('DEL', occur_key)
     end
 
-    if next(occur[cl]) ~= nil then
+    if occur[cl] and next(occur[cl]) then
       redis.call('HMSET', occur_key, unpack_function(hash2list(occur[cl])))
     end
   end
@@ -446,8 +512,8 @@ local function expire_step(cls, ev_base, worker)
                 '%s infrequent (%s %s), %s mean, %s std',
             lutil.unpack(d))
         if cycle then
-          for i, cl in ipairs({ 'in ham', 'in spam', 'total' }) do
-            logger.infox(rspamd_config, 'tokens occurrences, %s: {%s}', cl, occ_distr[i])
+          for _, distr_info in ipairs(occ_distr) do
+            logger.infox(rspamd_config, 'tokens occurrences: {%s}', distr_info)
           end
         end
       end