From: Vsevolod Stakhov Date: Fri, 27 Feb 2026 14:25:25 +0000 (+0000) Subject: [Feature] Add shard migration and multi-class support to statistics_dump X-Git-Tag: 4.0.0~24^2~8 X-Git-Url: http://git.ipfire.org/gitweb/?a=commitdiff_plain;h=36325c5c573e3fc8ebcdd40eeac5390cd90f2ee3;p=thirdparty%2Frspamd.git [Feature] Add shard migration and multi-class support to statistics_dump Add `rspamadm statistics_dump migrate` subcommand for migrating per-user Bayes data between Redis shards after the Jump Hash to Ketama transition. The tool scans all shards, identifies misplaced prefixes via get_upstream_by_hash, and moves them in batches using Redis Lua scripts. Also fix multi-class Bayes support: dump/restore now handles arbitrary statfile classes (not just binary spam/ham) by collecting all symbols from classifier config with proper class label mapping. The dump command now iterates all shards via all_upstreams() for complete data export. --- diff --git a/lualib/rspamadm/statistics_dump.lua b/lualib/rspamadm/statistics_dump.lua index e727c739f7..ac7f7d0796 100644 --- a/lualib/rspamadm/statistics_dump.lua +++ b/lualib/rspamadm/statistics_dump.lua @@ -90,6 +90,19 @@ restore:option "-m --mode" restore:flag "-n --no-operation" :description "Only show redis commands to be issued" +-- Migrate +local migrate = parser:command "migrate m" + :description "Migrate bayes data between shards (after hash algorithm change)" +migrate:flag "-n --dry-run" + :description "Only show what would be migrated, without writing" +migrate:flag "--no-delete" + :description "Copy keys to target shard without deleting from source" +migrate:option "-b --batch-size" + :description "Number of entries to process per SCAN batch" + :argname("") + :convert(tonumber) + :default(1000) + local function load_config(opts) local _r, err = rspamd_config:load_ucl(opts['config']) @@ -109,27 +122,60 @@ local function check_redis_classifier(cls, cfg) -- Skip old classifiers if cls.new_schema then local symbol_spam, symbol_ham + local symbols = {} -- Load symbols from statfiles + local function get_class_label(class_name) + -- Check class_labels mapping in classifier config + if cls.class_labels and class_name then + local label = cls.class_labels[class_name] + if label then + return label + end + end + -- Default mapping: spam→S, ham→H, custom→class_name + if class_name == 'spam' then + return 'S' + elseif class_name == 'ham' then + return 'H' + end + return class_name + end + local function check_statfile_table(tbl, def_sym) local symbol = tbl.symbol or def_sym - local spam - if tbl.spam then - spam = tbl.spam + -- Determine class_name by priority: + -- 1. Explicit tbl.class + -- 2. Legacy tbl.spam boolean + -- 3. Heuristic from symbol name + local class_name + if tbl.class then + class_name = tbl.class + elseif tbl.spam then + class_name = 'spam' else if string.match(symbol:upper(), 'SPAM') then - spam = true + class_name = 'spam' else - spam = false + class_name = 'ham' end end - if spam then + local label = get_class_label(class_name) + + -- Backward compat for binary classifiers + if class_name == 'spam' then symbol_spam = symbol - else + elseif class_name == 'ham' then symbol_ham = symbol end + + table.insert(symbols, { + symbol = symbol, + class_name = class_name, + label = label, + }) end local statfiles = cls.statfile @@ -174,6 +220,7 @@ local function check_redis_classifier(cls, cfg) table.insert(classifiers, { symbol_spam = symbol_spam, symbol_ham = symbol_ham, + symbols = symbols, redis_params = redis_params, }) end @@ -196,6 +243,54 @@ local clear_fcn = table.clear or function(tbl) end end +local function connect_to_upstream(up, redis_params) + local rspamd_redis = require "rspamd_redis" + local ret, conn = rspamd_redis.connect_sync({ + host = up:get_addr(), + timeout = redis_params.timeout, + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + }) + + if not ret or not conn then + rspamd_logger.errx("cannot connect to redis %s: %s", up:get_name(), conn) + return false, nil + end + + local need_exec = false + if redis_params.username then + if redis_params.password then + conn:add_cmd('AUTH', { redis_params.username, redis_params.password }) + need_exec = true + else + rspamd_logger.errx("redis requires a password when username is supplied") + return false, nil + end + elseif redis_params.password then + conn:add_cmd('AUTH', { redis_params.password }) + need_exec = true + end + + if redis_params.db then + conn:add_cmd('SELECT', { tostring(redis_params.db) }) + need_exec = true + elseif redis_params.dbname then + conn:add_cmd('SELECT', { tostring(redis_params.dbname) }) + need_exec = true + end + + if need_exec then + local exec_ret, res = conn:exec() + if not exec_ret then + rspamd_logger.errx("cannot authenticate/select db on %s: %s", up:get_name(), res) + return false, nil + end + end + + return true, conn +end + local compress_ctx local function dump_out(out, opts, last) @@ -216,14 +311,26 @@ local function dump_out(out, opts, last) end end -local function dump_cdb(out, opts, last, pattern) +local function dump_cdb(out, opts, last, pattern, class_labels) local results = out[pattern] if not out.cdb_builder then -- First invocation out.cdb_builder = rspamd_cdb.build(string.format('%s.cdb', pattern)) - out.cdb_builder:add('_lrnspam', rspamd_i64.fromstring(results.learns_spam or '0')) - out.cdb_builder:add('_lrnham_', rspamd_i64.fromstring(results.learns_ham or '0')) + -- Write learned counts for all class labels + for _, lbl in ipairs(class_labels or { 'S', 'H' }) do + local learned_key + if lbl == 'S' then + learned_key = 'learns_spam' + elseif lbl == 'H' then + learned_key = 'learns_ham' + else + learned_key = 'learns_' .. lbl + end + -- Pad CDB key to 8 bytes for consistent lookup + local cdb_key = string.format('_lrn%-4s', lbl) + out.cdb_builder:add(cdb_key, rspamd_i64.fromstring(results[learned_key] or '0')) + end end for _, o in ipairs(results.elts) do @@ -236,9 +343,15 @@ local function dump_cdb(out, opts, last, pattern) end end -local function dump_pattern(conn, pattern, opts, out, key) +local function dump_pattern(conn, pattern, opts, out, key, class_labels) local cursor = 0 + -- Build CDB pack format string from class labels + local cdb_fmt + if opts.cdb then + cdb_fmt = string.rep('f', #class_labels) + end + repeat conn:add_cmd('SCAN', { tostring(cursor), 'MATCH', pattern, @@ -277,10 +390,14 @@ local function dump_pattern(conn, pattern, opts, out, key) for i, d in ipairs(tokens) do if cursor == 0 and i == #tokens or not opts.json then if opts.cdb then + -- Pack all class label values dynamically + local values = {} + for _, lbl in ipairs(class_labels) do + values[#values + 1] = tonumber(d.data[lbl] or '0') or 0 + end table.insert(out[key].elts, { key = rspamd_i64.fromstring(string.match(d.key, '%d+')), - value = rspamd_util.pack('ff', tonumber(d.data["S"] or '0') or 0, - tonumber(d.data["H"] or '0')) + value = rspamd_util.pack(cdb_fmt, lua_util.unpack(values)) }) else out[#out + 1] = rspamd_logger.slog('"%s": %s\n', d.key, @@ -300,14 +417,14 @@ local function dump_pattern(conn, pattern, opts, out, key) -- Do not write the last chunk of out as it will be processed afterwards if cursor ~= 0 then if opts.cdb then - dump_cdb(out, opts, false, key) + dump_cdb(out, opts, false, key, class_labels) out[key].elts = {} else dump_out(out, opts, false) clear_fcn(out) end elseif opts.cdb then - dump_cdb(out, opts, true, key) + dump_cdb(out, opts, true, key, class_labels) end until cursor == 0 @@ -316,22 +433,52 @@ end local function dump_handler(opts) local patterns_seen = {} for _, cls in ipairs(classifiers) do - local res, conn = lua_redis.redis_connect_sync(cls.redis_params, false) + -- Collect class labels for CDB packing + local class_labels = {} + for _, s in ipairs(cls.symbols) do + class_labels[#class_labels + 1] = s.label + end - if not res then - rspamd_logger.errx("cannot connect to redis server: %s", cls.redis_params) - os.exit(1) + -- Connect to all shards to ensure complete dump + local connections = {} + local read_servers = cls.redis_params.read_servers + if read_servers then + local all_ups = read_servers:all_upstreams() + if all_ups and #all_ups > 0 then + for _, up in ipairs(all_ups) do + local res, conn = connect_to_upstream(up, cls.redis_params) + if res then + connections[#connections + 1] = { up = up, conn = conn } + else + rspamd_logger.errx("cannot connect to redis shard %s", up:get_name()) + end + end + end + end + + -- Fallback: single connection via round-robin + if #connections == 0 then + local res, conn = lua_redis.redis_connect_sync(cls.redis_params, false) + if not res then + rspamd_logger.errx("cannot connect to redis server: %s", cls.redis_params) + os.exit(1) + end + connections[#connections + 1] = { conn = conn } end local out = {} - local function check_keys(sym) + local function check_keys(conn, sym) local sym_keys_pattern = string.format("%s_keys", sym) conn:add_cmd('SMEMBERS', { sym_keys_pattern }) local ret, keys = conn:exec() if not ret then rspamd_logger.errx("cannot execute command to get keys: %s", keys) - os.exit(1) + return + end + + if not keys or #keys == 0 then + return end if not opts.json then @@ -355,15 +502,18 @@ local function dump_handler(opts) out[#out + 1] = string.format('"%s": %s\n', k, ucl.to_format(redis_map_zip(additional_keys), 'json-compact')) end - dump_pattern(conn, pat, opts, out, k) + dump_pattern(conn, pat, opts, out, k, class_labels) patterns_seen[pat] = true end end end end - check_keys(cls.symbol_spam) - check_keys(cls.symbol_ham) + for _, c in ipairs(connections) do + for _, s in ipairs(cls.symbols) do + check_keys(c.conn, s.symbol) + end + end if #out > 0 then dump_out(out, opts, true) @@ -493,6 +643,220 @@ local function restore_handler(opts) end end +-- Redis Lua scripts for migration +local export_script = [[ +local result = redis.call('SCAN', ARGV[1], 'MATCH', ARGV[2], 'COUNT', ARGV[3]) +local cursor = result[1] +local keys = result[2] +local data = {} +local key_names = {} +for i, k in ipairs(keys) do + data[i] = {k, redis.call('HGETALL', k)} + key_names[i] = k +end +return {cursor, cmsgpack.pack(data), cmsgpack.pack(key_names)} +]] + +local import_script = [[ +local data = cmsgpack.unpack(ARGV[1]) +for _, entry in ipairs(data) do + if #entry[2] > 0 then + redis.call('HMSET', entry[1], unpack(entry[2])) + end +end +return #data +]] + +local delete_script = [[ +local keys = cmsgpack.unpack(ARGV[1]) +for _, k in ipairs(keys) do + redis.call('DEL', k) +end +return #keys +]] + +local function migrate_handler(opts) + local stats = { + checked = 0, + correct = 0, + migrated = 0, + tokens = 0, + errors = 0, + } + + for _, cls in ipairs(classifiers) do + local write_servers = cls.redis_params.write_servers + if not write_servers then + rspamd_logger.errx("no write servers configured, cannot migrate") + os.exit(1) + end + + local all_ups = write_servers:all_upstreams() + if not all_ups or #all_ups <= 1 then + rspamd_logger.messagex("only %s shard(s) configured, nothing to migrate", + all_ups and #all_ups or 0) + return + end + + rspamd_logger.messagex("found %s shards to check for migration", #all_ups) + + -- Connect to every shard + local shard_map = {} + for _, up in ipairs(all_ups) do + local res, conn = connect_to_upstream(up, cls.redis_params) + if not res then + rspamd_logger.errx("cannot connect to shard %s, aborting", up:get_name()) + os.exit(1) + end + shard_map[#shard_map + 1] = { + name = up:get_name(), + up = up, + conn = conn, + } + end + + -- Migrate each symbol's keys + for _, s in ipairs(cls.symbols) do + local sym = s.symbol + rspamd_logger.messagex("processing symbol: %s", sym) + local sym_keys = string.format("%s_keys", sym) + + for shard_idx, shard in ipairs(shard_map) do + shard.conn:add_cmd('SMEMBERS', { sym_keys }) + local ret, prefixes = shard.conn:exec() + + if not ret then + rspamd_logger.errx("cannot get %s from shard %s: %s", + sym_keys, shard.name, prefixes) + stats.errors = stats.errors + 1 + elseif prefixes and #prefixes > 0 then + rspamd_logger.messagex(" shard %s [%s/%s]: %s prefix key(s) for %s", + shard.name, shard_idx, #shard_map, #prefixes, sym) + + for _, prefix in ipairs(prefixes) do + stats.checked = stats.checked + 1 + + -- Determine which shard this prefix should live on + local target_up = write_servers:get_upstream_by_hash(prefix) + local target_name = target_up:get_name() + + if target_name == shard.name then + -- Already on the correct shard + stats.correct = stats.correct + 1 + else + -- Find target connection + local target_conn + for _, ts in ipairs(shard_map) do + if ts.name == target_name then + target_conn = ts.conn + break + end + end + + if not target_conn then + rspamd_logger.errx(" cannot find connection for target shard %s", target_name) + stats.errors = stats.errors + 1 + else + rspamd_logger.messagex(" migrating prefix '%s': %s -> %s", + prefix, shard.name, target_name) + + if opts.dry_run then + stats.migrated = stats.migrated + 1 + else + -- 1. Copy the prefix metadata hash + shard.conn:add_cmd('HGETALL', { prefix }) + local hret, hdata = shard.conn:exec() + + if hret and hdata and #hdata > 0 then + local hmset_args = { prefix } + for _, v in ipairs(hdata) do + hmset_args[#hmset_args + 1] = v + end + target_conn:add_cmd('HMSET', hmset_args) + local mret, merr = target_conn:exec() + if not mret then + rspamd_logger.errx(" failed to copy metadata for %s: %s", prefix, merr) + stats.errors = stats.errors + 1 + end + end + + -- 2. Scan and migrate token keys in batches + local scan_pattern = string.format('%s_*', prefix) + local cursor = "0" + + repeat + shard.conn:add_cmd('EVAL', { + export_script, '0', + cursor, scan_pattern, tostring(opts.batch_size) + }) + local eret, eresults = shard.conn:exec() + + if not eret then + rspamd_logger.errx(" export script failed for %s: %s", prefix, eresults) + stats.errors = stats.errors + 1 + break + end + + cursor = eresults[1] + local packed_data = eresults[2] + local packed_keys = eresults[3] + + -- Import to target + if packed_data and #packed_data > 0 then + target_conn:add_cmd('EVAL', { + import_script, '0', packed_data + }) + local iret, ires = target_conn:exec() + + if not iret then + rspamd_logger.errx(" import script failed for %s: %s", prefix, ires) + stats.errors = stats.errors + 1 + else + stats.tokens = stats.tokens + (tonumber(ires) or 0) + end + end + + -- Delete from source (unless --no-delete) + if not opts.no_delete and packed_keys and #packed_keys > 0 then + shard.conn:add_cmd('EVAL', { + delete_script, '0', packed_keys + }) + local dret, derr = shard.conn:exec() + if not dret then + rspamd_logger.errx(" delete script failed for %s: %s", prefix, derr) + stats.errors = stats.errors + 1 + end + end + until cursor == "0" + + -- 3. Update _keys sets + target_conn:add_cmd('SADD', { sym_keys, prefix }) + target_conn:exec() + + shard.conn:add_cmd('SREM', { sym_keys, prefix }) + shard.conn:exec() + + -- 4. Delete source prefix hash (unless --no-delete) + if not opts.no_delete then + shard.conn:add_cmd('DEL', { prefix }) + shard.conn:exec() + end + + stats.migrated = stats.migrated + 1 + end + end + end + end + end + end + end + end + + rspamd_logger.messagex("migration %s: checked=%s correct=%s migrated=%s tokens=%s errors=%s", + opts.dry_run and "dry-run complete" or "complete", + stats.checked, stats.correct, stats.migrated, stats.tokens, stats.errors) +end + local function handler(args) local opts = parser:parse(args) @@ -544,6 +908,8 @@ local function handler(args) dump_handler(opts) elseif command == 'restore' then restore_handler(opts) + elseif command == 'migrate' then + migrate_handler(opts) else parser:error('command %s is not implemented', command) end