]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add shard migration and multi-class support to statistics_dump
authorVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 27 Feb 2026 14:25:25 +0000 (14:25 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 27 Feb 2026 14:25:25 +0000 (14:25 +0000)
Add `rspamadm statistics_dump migrate` subcommand for migrating per-user
Bayes data between Redis shards after the Jump Hash to Ketama transition.
The tool scans all shards, identifies misplaced prefixes via
get_upstream_by_hash, and moves them in batches using Redis Lua scripts.

Also fix multi-class Bayes support: dump/restore now handles arbitrary
statfile classes (not just binary spam/ham) by collecting all symbols
from classifier config with proper class label mapping. The dump command
now iterates all shards via all_upstreams() for complete data export.

lualib/rspamadm/statistics_dump.lua

index e727c739f754296759fa08305851491a84171879..ac7f7d0796ba39b98956e1be6fb3815d40d0f1f4 100644 (file)
@@ -90,6 +90,19 @@ restore:option "-m --mode"
 restore:flag "-n --no-operation"
        :description "Only show redis commands to be issued"
 
+-- Migrate
+local migrate = parser:command "migrate m"
+                      :description "Migrate bayes data between shards (after hash algorithm change)"
+migrate:flag "-n --dry-run"
+       :description "Only show what would be migrated, without writing"
+migrate:flag "--no-delete"
+       :description "Copy keys to target shard without deleting from source"
+migrate:option "-b --batch-size"
+       :description "Number of entries to process per SCAN batch"
+       :argname("<elts>")
+       :convert(tonumber)
+       :default(1000)
+
 local function load_config(opts)
   local _r, err = rspamd_config:load_ucl(opts['config'])
 
@@ -109,27 +122,60 @@ local function check_redis_classifier(cls, cfg)
   -- Skip old classifiers
   if cls.new_schema then
     local symbol_spam, symbol_ham
+    local symbols = {}
     -- Load symbols from statfiles
 
+    local function get_class_label(class_name)
+      -- Check class_labels mapping in classifier config
+      if cls.class_labels and class_name then
+        local label = cls.class_labels[class_name]
+        if label then
+          return label
+        end
+      end
+      -- Default mapping: spam→S, ham→H, custom→class_name
+      if class_name == 'spam' then
+        return 'S'
+      elseif class_name == 'ham' then
+        return 'H'
+      end
+      return class_name
+    end
+
     local function check_statfile_table(tbl, def_sym)
       local symbol = tbl.symbol or def_sym
 
-      local spam
-      if tbl.spam then
-        spam = tbl.spam
+      -- Determine class_name by priority:
+      -- 1. Explicit tbl.class
+      -- 2. Legacy tbl.spam boolean
+      -- 3. Heuristic from symbol name
+      local class_name
+      if tbl.class then
+        class_name = tbl.class
+      elseif tbl.spam then
+        class_name = 'spam'
       else
         if string.match(symbol:upper(), 'SPAM') then
-          spam = true
+          class_name = 'spam'
         else
-          spam = false
+          class_name = 'ham'
         end
       end
 
-      if spam then
+      local label = get_class_label(class_name)
+
+      -- Backward compat for binary classifiers
+      if class_name == 'spam' then
         symbol_spam = symbol
-      else
+      elseif class_name == 'ham' then
         symbol_ham = symbol
       end
+
+      table.insert(symbols, {
+        symbol = symbol,
+        class_name = class_name,
+        label = label,
+      })
     end
 
     local statfiles = cls.statfile
@@ -174,6 +220,7 @@ local function check_redis_classifier(cls, cfg)
     table.insert(classifiers, {
       symbol_spam = symbol_spam,
       symbol_ham = symbol_ham,
+      symbols = symbols,
       redis_params = redis_params,
     })
   end
@@ -196,6 +243,54 @@ local clear_fcn = table.clear or function(tbl)
   end
 end
 
+local function connect_to_upstream(up, redis_params)
+  local rspamd_redis = require "rspamd_redis"
+  local ret, conn = rspamd_redis.connect_sync({
+    host = up:get_addr(),
+    timeout = redis_params.timeout,
+    config = rspamd_config,
+    ev_base = rspamadm_ev_base,
+    session = rspamadm_session,
+  })
+
+  if not ret or not conn then
+    rspamd_logger.errx("cannot connect to redis %s: %s", up:get_name(), conn)
+    return false, nil
+  end
+
+  local need_exec = false
+  if redis_params.username then
+    if redis_params.password then
+      conn:add_cmd('AUTH', { redis_params.username, redis_params.password })
+      need_exec = true
+    else
+      rspamd_logger.errx("redis requires a password when username is supplied")
+      return false, nil
+    end
+  elseif redis_params.password then
+    conn:add_cmd('AUTH', { redis_params.password })
+    need_exec = true
+  end
+
+  if redis_params.db then
+    conn:add_cmd('SELECT', { tostring(redis_params.db) })
+    need_exec = true
+  elseif redis_params.dbname then
+    conn:add_cmd('SELECT', { tostring(redis_params.dbname) })
+    need_exec = true
+  end
+
+  if need_exec then
+    local exec_ret, res = conn:exec()
+    if not exec_ret then
+      rspamd_logger.errx("cannot authenticate/select db on %s: %s", up:get_name(), res)
+      return false, nil
+    end
+  end
+
+  return true, conn
+end
+
 local compress_ctx
 
 local function dump_out(out, opts, last)
@@ -216,14 +311,26 @@ local function dump_out(out, opts, last)
   end
 end
 
-local function dump_cdb(out, opts, last, pattern)
+local function dump_cdb(out, opts, last, pattern, class_labels)
   local results = out[pattern]
 
   if not out.cdb_builder then
     -- First invocation
     out.cdb_builder = rspamd_cdb.build(string.format('%s.cdb', pattern))
-    out.cdb_builder:add('_lrnspam', rspamd_i64.fromstring(results.learns_spam or '0'))
-    out.cdb_builder:add('_lrnham_', rspamd_i64.fromstring(results.learns_ham or '0'))
+    -- Write learned counts for all class labels
+    for _, lbl in ipairs(class_labels or { 'S', 'H' }) do
+      local learned_key
+      if lbl == 'S' then
+        learned_key = 'learns_spam'
+      elseif lbl == 'H' then
+        learned_key = 'learns_ham'
+      else
+        learned_key = 'learns_' .. lbl
+      end
+      -- Pad CDB key to 8 bytes for consistent lookup
+      local cdb_key = string.format('_lrn%-4s', lbl)
+      out.cdb_builder:add(cdb_key, rspamd_i64.fromstring(results[learned_key] or '0'))
+    end
   end
 
   for _, o in ipairs(results.elts) do
@@ -236,9 +343,15 @@ local function dump_cdb(out, opts, last, pattern)
   end
 end
 
-local function dump_pattern(conn, pattern, opts, out, key)
+local function dump_pattern(conn, pattern, opts, out, key, class_labels)
   local cursor = 0
 
+  -- Build CDB pack format string from class labels
+  local cdb_fmt
+  if opts.cdb then
+    cdb_fmt = string.rep('f', #class_labels)
+  end
+
   repeat
     conn:add_cmd('SCAN', { tostring(cursor),
                            'MATCH', pattern,
@@ -277,10 +390,14 @@ local function dump_pattern(conn, pattern, opts, out, key)
     for i, d in ipairs(tokens) do
       if cursor == 0 and i == #tokens or not opts.json then
         if opts.cdb then
+          -- Pack all class label values dynamically
+          local values = {}
+          for _, lbl in ipairs(class_labels) do
+            values[#values + 1] = tonumber(d.data[lbl] or '0') or 0
+          end
           table.insert(out[key].elts, {
             key = rspamd_i64.fromstring(string.match(d.key, '%d+')),
-            value = rspamd_util.pack('ff', tonumber(d.data["S"] or '0') or 0,
-                tonumber(d.data["H"] or '0'))
+            value = rspamd_util.pack(cdb_fmt, lua_util.unpack(values))
           })
         else
           out[#out + 1] = rspamd_logger.slog('"%s": %s\n', d.key,
@@ -300,14 +417,14 @@ local function dump_pattern(conn, pattern, opts, out, key)
     -- Do not write the last chunk of out as it will be processed afterwards
     if cursor ~= 0 then
       if opts.cdb then
-        dump_cdb(out, opts, false, key)
+        dump_cdb(out, opts, false, key, class_labels)
         out[key].elts = {}
       else
         dump_out(out, opts, false)
         clear_fcn(out)
       end
     elseif opts.cdb then
-      dump_cdb(out, opts, true, key)
+      dump_cdb(out, opts, true, key, class_labels)
     end
 
   until cursor == 0
@@ -316,22 +433,52 @@ end
 local function dump_handler(opts)
   local patterns_seen = {}
   for _, cls in ipairs(classifiers) do
-    local res, conn = lua_redis.redis_connect_sync(cls.redis_params, false)
+    -- Collect class labels for CDB packing
+    local class_labels = {}
+    for _, s in ipairs(cls.symbols) do
+      class_labels[#class_labels + 1] = s.label
+    end
 
-    if not res then
-      rspamd_logger.errx("cannot connect to redis server: %s", cls.redis_params)
-      os.exit(1)
+    -- Connect to all shards to ensure complete dump
+    local connections = {}
+    local read_servers = cls.redis_params.read_servers
+    if read_servers then
+      local all_ups = read_servers:all_upstreams()
+      if all_ups and #all_ups > 0 then
+        for _, up in ipairs(all_ups) do
+          local res, conn = connect_to_upstream(up, cls.redis_params)
+          if res then
+            connections[#connections + 1] = { up = up, conn = conn }
+          else
+            rspamd_logger.errx("cannot connect to redis shard %s", up:get_name())
+          end
+        end
+      end
+    end
+
+    -- Fallback: single connection via round-robin
+    if #connections == 0 then
+      local res, conn = lua_redis.redis_connect_sync(cls.redis_params, false)
+      if not res then
+        rspamd_logger.errx("cannot connect to redis server: %s", cls.redis_params)
+        os.exit(1)
+      end
+      connections[#connections + 1] = { conn = conn }
     end
 
     local out = {}
-    local function check_keys(sym)
+    local function check_keys(conn, sym)
       local sym_keys_pattern = string.format("%s_keys", sym)
       conn:add_cmd('SMEMBERS', { sym_keys_pattern })
       local ret, keys = conn:exec()
 
       if not ret then
         rspamd_logger.errx("cannot execute command to get keys: %s", keys)
-        os.exit(1)
+        return
+      end
+
+      if not keys or #keys == 0 then
+        return
       end
 
       if not opts.json then
@@ -355,15 +502,18 @@ local function dump_handler(opts)
               out[#out + 1] = string.format('"%s": %s\n', k,
                   ucl.to_format(redis_map_zip(additional_keys), 'json-compact'))
             end
-            dump_pattern(conn, pat, opts, out, k)
+            dump_pattern(conn, pat, opts, out, k, class_labels)
             patterns_seen[pat] = true
           end
         end
       end
     end
 
-    check_keys(cls.symbol_spam)
-    check_keys(cls.symbol_ham)
+    for _, c in ipairs(connections) do
+      for _, s in ipairs(cls.symbols) do
+        check_keys(c.conn, s.symbol)
+      end
+    end
 
     if #out > 0 then
       dump_out(out, opts, true)
@@ -493,6 +643,220 @@ local function restore_handler(opts)
   end
 end
 
+-- Redis Lua scripts for migration
+local export_script = [[
+local result = redis.call('SCAN', ARGV[1], 'MATCH', ARGV[2], 'COUNT', ARGV[3])
+local cursor = result[1]
+local keys = result[2]
+local data = {}
+local key_names = {}
+for i, k in ipairs(keys) do
+  data[i] = {k, redis.call('HGETALL', k)}
+  key_names[i] = k
+end
+return {cursor, cmsgpack.pack(data), cmsgpack.pack(key_names)}
+]]
+
+local import_script = [[
+local data = cmsgpack.unpack(ARGV[1])
+for _, entry in ipairs(data) do
+  if #entry[2] > 0 then
+    redis.call('HMSET', entry[1], unpack(entry[2]))
+  end
+end
+return #data
+]]
+
+local delete_script = [[
+local keys = cmsgpack.unpack(ARGV[1])
+for _, k in ipairs(keys) do
+  redis.call('DEL', k)
+end
+return #keys
+]]
+
+local function migrate_handler(opts)
+  local stats = {
+    checked = 0,
+    correct = 0,
+    migrated = 0,
+    tokens = 0,
+    errors = 0,
+  }
+
+  for _, cls in ipairs(classifiers) do
+    local write_servers = cls.redis_params.write_servers
+    if not write_servers then
+      rspamd_logger.errx("no write servers configured, cannot migrate")
+      os.exit(1)
+    end
+
+    local all_ups = write_servers:all_upstreams()
+    if not all_ups or #all_ups <= 1 then
+      rspamd_logger.messagex("only %s shard(s) configured, nothing to migrate",
+          all_ups and #all_ups or 0)
+      return
+    end
+
+    rspamd_logger.messagex("found %s shards to check for migration", #all_ups)
+
+    -- Connect to every shard
+    local shard_map = {}
+    for _, up in ipairs(all_ups) do
+      local res, conn = connect_to_upstream(up, cls.redis_params)
+      if not res then
+        rspamd_logger.errx("cannot connect to shard %s, aborting", up:get_name())
+        os.exit(1)
+      end
+      shard_map[#shard_map + 1] = {
+        name = up:get_name(),
+        up = up,
+        conn = conn,
+      }
+    end
+
+    -- Migrate each symbol's keys
+    for _, s in ipairs(cls.symbols) do
+      local sym = s.symbol
+      rspamd_logger.messagex("processing symbol: %s", sym)
+      local sym_keys = string.format("%s_keys", sym)
+
+      for shard_idx, shard in ipairs(shard_map) do
+        shard.conn:add_cmd('SMEMBERS', { sym_keys })
+        local ret, prefixes = shard.conn:exec()
+
+        if not ret then
+          rspamd_logger.errx("cannot get %s from shard %s: %s",
+              sym_keys, shard.name, prefixes)
+          stats.errors = stats.errors + 1
+        elseif prefixes and #prefixes > 0 then
+          rspamd_logger.messagex("  shard %s [%s/%s]: %s prefix key(s) for %s",
+              shard.name, shard_idx, #shard_map, #prefixes, sym)
+
+          for _, prefix in ipairs(prefixes) do
+            stats.checked = stats.checked + 1
+
+            -- Determine which shard this prefix should live on
+            local target_up = write_servers:get_upstream_by_hash(prefix)
+            local target_name = target_up:get_name()
+
+            if target_name == shard.name then
+              -- Already on the correct shard
+              stats.correct = stats.correct + 1
+            else
+              -- Find target connection
+              local target_conn
+              for _, ts in ipairs(shard_map) do
+                if ts.name == target_name then
+                  target_conn = ts.conn
+                  break
+                end
+              end
+
+              if not target_conn then
+                rspamd_logger.errx("    cannot find connection for target shard %s", target_name)
+                stats.errors = stats.errors + 1
+              else
+                rspamd_logger.messagex("    migrating prefix '%s': %s -> %s",
+                    prefix, shard.name, target_name)
+
+                if opts.dry_run then
+                  stats.migrated = stats.migrated + 1
+                else
+                  -- 1. Copy the prefix metadata hash
+                  shard.conn:add_cmd('HGETALL', { prefix })
+                  local hret, hdata = shard.conn:exec()
+
+                  if hret and hdata and #hdata > 0 then
+                    local hmset_args = { prefix }
+                    for _, v in ipairs(hdata) do
+                      hmset_args[#hmset_args + 1] = v
+                    end
+                    target_conn:add_cmd('HMSET', hmset_args)
+                    local mret, merr = target_conn:exec()
+                    if not mret then
+                      rspamd_logger.errx("    failed to copy metadata for %s: %s", prefix, merr)
+                      stats.errors = stats.errors + 1
+                    end
+                  end
+
+                  -- 2. Scan and migrate token keys in batches
+                  local scan_pattern = string.format('%s_*', prefix)
+                  local cursor = "0"
+
+                  repeat
+                    shard.conn:add_cmd('EVAL', {
+                      export_script, '0',
+                      cursor, scan_pattern, tostring(opts.batch_size)
+                    })
+                    local eret, eresults = shard.conn:exec()
+
+                    if not eret then
+                      rspamd_logger.errx("    export script failed for %s: %s", prefix, eresults)
+                      stats.errors = stats.errors + 1
+                      break
+                    end
+
+                    cursor = eresults[1]
+                    local packed_data = eresults[2]
+                    local packed_keys = eresults[3]
+
+                    -- Import to target
+                    if packed_data and #packed_data > 0 then
+                      target_conn:add_cmd('EVAL', {
+                        import_script, '0', packed_data
+                      })
+                      local iret, ires = target_conn:exec()
+
+                      if not iret then
+                        rspamd_logger.errx("    import script failed for %s: %s", prefix, ires)
+                        stats.errors = stats.errors + 1
+                      else
+                        stats.tokens = stats.tokens + (tonumber(ires) or 0)
+                      end
+                    end
+
+                    -- Delete from source (unless --no-delete)
+                    if not opts.no_delete and packed_keys and #packed_keys > 0 then
+                      shard.conn:add_cmd('EVAL', {
+                        delete_script, '0', packed_keys
+                      })
+                      local dret, derr = shard.conn:exec()
+                      if not dret then
+                        rspamd_logger.errx("    delete script failed for %s: %s", prefix, derr)
+                        stats.errors = stats.errors + 1
+                      end
+                    end
+                  until cursor == "0"
+
+                  -- 3. Update _keys sets
+                  target_conn:add_cmd('SADD', { sym_keys, prefix })
+                  target_conn:exec()
+
+                  shard.conn:add_cmd('SREM', { sym_keys, prefix })
+                  shard.conn:exec()
+
+                  -- 4. Delete source prefix hash (unless --no-delete)
+                  if not opts.no_delete then
+                    shard.conn:add_cmd('DEL', { prefix })
+                    shard.conn:exec()
+                  end
+
+                  stats.migrated = stats.migrated + 1
+                end
+              end
+            end
+          end
+        end
+      end
+    end
+  end
+
+  rspamd_logger.messagex("migration %s: checked=%s correct=%s migrated=%s tokens=%s errors=%s",
+      opts.dry_run and "dry-run complete" or "complete",
+      stats.checked, stats.correct, stats.migrated, stats.tokens, stats.errors)
+end
+
 local function handler(args)
   local opts = parser:parse(args)
 
@@ -544,6 +908,8 @@ local function handler(args)
     dump_handler(opts)
   elseif command == 'restore' then
     restore_handler(opts)
+  elseif command == 'migrate' then
+    migrate_handler(opts)
   else
     parser:error('command %s is not implemented', command)
   end