]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Speed up shard migration and harden restore batching
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 5 Mar 2026 22:38:26 +0000 (22:38 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 5 Mar 2026 22:38:26 +0000 (22:38 +0000)
lualib/rspamadm/statistics_dump.lua

index fec1adbabbf1f09425ae67c5f681297448f77aec..4068eda3f5e778f54cde2e1cbbd8fa4458557303 100644 (file)
@@ -329,6 +329,9 @@ end
 -- Maximum commands per pipeline exec() to avoid Lua stack overflow
 local pipeline_max = 1000
 
+local append_redis_hash_hmset
+local exec_redis_commands
+
 local function dump_cdb(out, opts, last, pattern, class_labels)
   local results = out[pattern]
 
@@ -620,6 +623,30 @@ local function obj_to_redis_arguments(obj, opts, cmd_pipe)
   return cmd_pipe
 end
 
+local function estimate_redis_commands(obj, opts)
+  local key, value = next(obj)
+
+  if type(key) ~= 'string' or type(value) ~= 'table' then
+    return 0
+  end
+
+  if not value[1] then
+    local n = 0
+
+    if opts.mode == 'replace' then
+      return 1
+    end
+
+    for _ in pairs(value) do
+      n = n + 1
+    end
+
+    return n
+  end
+
+  return #value
+end
+
 local function execute_batch(batch, conns, opts)
   local cmd_pipe = {}
 
@@ -636,25 +663,54 @@ local function execute_batch(batch, conns, opts)
       -- Chunk commands to avoid stack overflow on large datasets
       for i = 1, #cmd_pipe, pipeline_max do
         local chunk_end = math.min(i + pipeline_max - 1, #cmd_pipe)
+        local added = 0
+
         for j = i, chunk_end do
           local is_ok, err = conn:add_cmd(cmd_pipe[j][1], cmd_pipe[j][2])
 
           if not is_ok then
             rspamd_logger.errx("cannot add command: %s with args: %s: %s",
                 cmd_pipe[j][1], cmd_pipe[j][2], err)
+            return false, err
           end
+
+          added = added + 1
         end
 
-        conn:exec()
+        if added > 0 then
+          local ret, err = conn:exec()
+
+          if not ret then
+            rspamd_logger.errx("cannot execute restore batch: %s", err)
+            return false, err
+          end
+        end
       end
     end
   end
+
+  return true
+end
+
+local function flush_restore_batch(batch, conns, opts)
+  if #batch == 0 then
+    return true
+  end
+
+  local ok = execute_batch(batch, conns, opts)
+  if not ok then
+    return false
+  end
+
+  clear_fcn(batch)
+  return true
 end
 
 local function restore_handler(opts)
   local selected = select_classifier(opts)
   local files = opts.file or { '-' }
   local conns = {}
+  local restore_pipeline_limit = math.max(100, math.min(opts.batch_size, pipeline_max))
 
   for _, cls in ipairs(selected) do
     local res, conn = lua_redis.redis_connect_sync(cls.redis_params, true)
@@ -668,6 +724,7 @@ local function restore_handler(opts)
   end
 
   local batch = {}
+  local pending_cmds = 0
 
   for _, f in ipairs(files) do
     local fd
@@ -688,11 +745,15 @@ local function restore_handler(opts)
       end
 
       table.insert(batch, ucl_parser:get_object())
+      pending_cmds = pending_cmds + estimate_redis_commands(batch[#batch], opts)
       cur_line = cur_line + 1
 
-      if cur_line % opts.batch_size == 0 then
-        execute_batch(batch, conns, opts)
-        batch = {}
+      if #batch >= opts.batch_size or pending_cmds >= restore_pipeline_limit then
+        local ok = flush_restore_batch(batch, conns, opts)
+        if not ok then
+          os.exit(1)
+        end
+        pending_cmds = 0
 
         if cur_line % (opts.batch_size * 10) == 0 then
           collectgarbage('collect')
@@ -707,70 +768,222 @@ local function restore_handler(opts)
   end
 
   if #batch > 0 then
-    execute_batch(batch, conns, opts)
+    local ok = flush_restore_batch(batch, conns, opts)
+    if not ok then
+      os.exit(1)
+    end
   end
 end
 
 -- Migrate a single prefix's token keys from source to target using pipelined commands.
 -- SCAN on source, pipeline HGETALL, pipeline HMSET to target, pipeline DEL on source.
 -- Returns number of tokens migrated.
-local function migrate_prefix_tokens(src_conn, dst_conn, prefix, batch_size, no_delete)
-  local scan_pattern = string.format('%s_*', prefix)
-  local cursor = "0"
+local function collect_prefix_token_keys(src_conn, prefixes, batch_size)
+  local keys = {}
+  local seen = {}
+
+  for _, prefix in ipairs(prefixes) do
+    local scan_pattern = string.format('%s_*', prefix)
+    local cursor = "0"
+
+    repeat
+      src_conn:add_cmd('SCAN', { cursor, 'MATCH', scan_pattern,
+                                 'COUNT', tostring(batch_size) })
+      local ret, results = src_conn:exec()
+
+      if not ret then
+        rspamd_logger.errx("SCAN failed for %s: %s", prefix, results)
+        return nil, true
+      end
+
+      cursor = results[1]
+      local scanned = results[2]
+
+      if scanned and #scanned > 0 then
+        for _, k in ipairs(scanned) do
+          if not seen[k] then
+            seen[k] = true
+            keys[#keys + 1] = k
+          end
+        end
+      end
+    until cursor == "0"
+  end
+
+  return keys, false
+end
+
+local function migrate_token_keys(src_conn, dst_conn, keys, no_delete)
   local total_tokens = 0
 
-  repeat
-    src_conn:add_cmd('SCAN', { cursor, 'MATCH', scan_pattern,
-                               'COUNT', tostring(batch_size) })
-    local ret, results = src_conn:exec()
+  for i = 1, #keys, pipeline_max do
+    local chunk_end = math.min(i + pipeline_max - 1, #keys)
 
-    if not ret then
-      rspamd_logger.errx("SCAN failed for %s: %s", prefix, results)
+    for j = i, chunk_end do
+      src_conn:add_cmd('HGETALL', { keys[j] })
+    end
+
+    local all_results = { src_conn:exec() }
+    local dst_cmds = {}
+    local src_del_cmds = {}
+
+    for j = i, chunk_end do
+      local idx = (j - i) * 2 + 1
+      local hret, hdata = all_results[idx], all_results[idx + 1]
+
+      if hret and append_redis_hash_hmset(dst_cmds, keys[j], hdata) then
+        total_tokens = total_tokens + 1
+        if not no_delete then
+          src_del_cmds[#src_del_cmds + 1] = { 'DEL', { keys[j] } }
+        end
+      end
+    end
+
+    all_results = nil
+
+    if not exec_redis_commands(dst_conn, dst_cmds) then
       return total_tokens, true
     end
 
-    cursor = results[1]
-    local keys = results[2]
+    if not no_delete and not exec_redis_commands(src_conn, src_del_cmds) then
+      return total_tokens, true
+    end
+  end
 
-    if keys and #keys > 0 then
-      -- Pipeline HGETALL on source for this batch
-      for _, k in ipairs(keys) do
-        src_conn:add_cmd('HGETALL', { k })
+  return total_tokens, false
+end
+
+append_redis_hash_hmset = function(cmds, key, hash_data)
+  if hash_data and #hash_data > 0 then
+    local args = { key }
+    for _, v in ipairs(hash_data) do
+      args[#args + 1] = v
+    end
+    cmds[#cmds + 1] = { 'HMSET', args }
+    return true
+  end
+
+  return false
+end
+
+exec_redis_commands = function(conn, cmds)
+  if #cmds == 0 then
+    return true
+  end
+
+  for i = 1, #cmds, pipeline_max do
+    local chunk_end = math.min(i + pipeline_max - 1, #cmds)
+
+    for j = i, chunk_end do
+      local is_ok, err = conn:add_cmd(cmds[j][1], cmds[j][2])
+
+      if not is_ok then
+        rspamd_logger.errx("cannot add command: %s with args: %s: %s",
+            cmds[j][1], cmds[j][2], err)
+        return false
       end
-      local all_results = { src_conn:exec() }
+    end
 
-      -- Pipeline HMSET on target
-      local imported = 0
-      for i = 1, #all_results, 2 do
-        local r, hash_data = all_results[i], all_results[i + 1]
-        if r and hash_data and #hash_data > 0 then
-          local args = { keys[(i + 1) / 2] }
-          for _, v in ipairs(hash_data) do
-            args[#args + 1] = v
-          end
-          dst_conn:add_cmd('HMSET', args)
-          imported = imported + 1
+    local ret, err = conn:exec()
+    if not ret then
+      rspamd_logger.errx("cannot execute redis pipeline: %s", err)
+      return false
+    end
+  end
+
+  return true
+end
+
+local function migrate_prefix_group(prefixes, src_conn, dst_conn, sym_keys, batch_size, no_delete)
+  local stats = {
+    migrated = 0,
+    tokens = 0,
+    errors = 0,
+  }
+
+  if #prefixes == 0 then
+    return stats
+  end
+
+  for i = 1, #prefixes, pipeline_max do
+    local chunk_end = math.min(i + pipeline_max - 1, #prefixes)
+
+    for j = i, chunk_end do
+      src_conn:add_cmd('HGETALL', { prefixes[j] })
+    end
+
+    local all_results = { src_conn:exec() }
+    local dst_meta_cmds = {}
+    local dst_keys_cmds = {}
+    local src_keys_cmds = {}
+    local src_meta_del_cmds = {}
+
+    for j = i, chunk_end do
+      local idx = (j - i) * 2 + 1
+      local prefix = prefixes[j]
+      local hret, hdata = all_results[idx], all_results[idx + 1]
+
+      if hret then
+        append_redis_hash_hmset(dst_meta_cmds, prefix, hdata)
+        dst_keys_cmds[#dst_keys_cmds + 1] = { 'SADD', { sym_keys, prefix } }
+        if not no_delete then
+          src_keys_cmds[#src_keys_cmds + 1] = { 'SREM', { sym_keys, prefix } }
+          src_meta_del_cmds[#src_meta_del_cmds + 1] = { 'DEL', { prefix } }
         end
+        stats.migrated = stats.migrated + 1
+      else
+        rspamd_logger.errx("cannot get prefix metadata for %s", prefix)
+        stats.errors = stats.errors + 1
       end
-      all_results = nil -- release memory
+    end
+
+    all_results = nil
+
+    if not exec_redis_commands(dst_conn, dst_meta_cmds) then
+      stats.errors = stats.errors + (chunk_end - i + 1)
+      return stats
+    end
+
+    local chunk_prefixes = {}
+    for j = i, chunk_end do
+      chunk_prefixes[#chunk_prefixes + 1] = prefixes[j]
+    end
+
+    local token_keys, scan_error = collect_prefix_token_keys(src_conn, chunk_prefixes, batch_size)
+    if scan_error then
+      stats.errors = stats.errors + #chunk_prefixes
+      return stats
+    end
 
-      if imported > 0 then
-        dst_conn:exec()
+    if token_keys and #token_keys > 0 then
+      local tok_count, had_error = migrate_token_keys(src_conn, dst_conn, token_keys, no_delete)
+      stats.tokens = stats.tokens + tok_count
+
+      if had_error then
+        stats.errors = stats.errors + #chunk_prefixes
+        return stats
       end
+    end
 
-      -- Pipeline DEL on source
-      if not no_delete then
-        for _, k in ipairs(keys) do
-          src_conn:add_cmd('DEL', { k })
-        end
-        src_conn:exec()
+    if not exec_redis_commands(dst_conn, dst_keys_cmds) then
+      stats.errors = stats.errors + (chunk_end - i + 1)
+      return stats
+    end
+
+    if not no_delete then
+      if not exec_redis_commands(src_conn, src_keys_cmds) then
+        stats.errors = stats.errors + (chunk_end - i + 1)
+        return stats
       end
 
-      total_tokens = total_tokens + imported
+      if not exec_redis_commands(src_conn, src_meta_del_cmds) then
+        stats.errors = stats.errors + (chunk_end - i + 1)
+        return stats
+      end
     end
-  until cursor == "0"
+  end
 
-  return total_tokens, false
+  return stats
 end
 
 local function migrate_handler(opts)
@@ -865,7 +1078,9 @@ local function migrate_handler(opts)
         rspamd_logger.messagex("  %s prefix(es) need migration", #misplaced)
       end
 
-      -- Phase 2: Migrate misplaced prefixes
+      -- Phase 2: Migrate misplaced prefixes grouped by shard pair to reduce round-trips
+      local grouped = {}
+
       for pi, m in ipairs(misplaced) do
         if not m.dst then
           rspamd_logger.errx("    cannot find target shard for prefix '%s'", m.prefix)
@@ -874,53 +1089,41 @@ local function migrate_handler(opts)
           rspamd_logger.messagex("    [%s/%s] '%s': %s -> %s",
               pi, #misplaced, m.prefix, m.src.name, m.dst.name)
 
-          if not opts.dry_run then
-            -- Copy prefix metadata hash
-            m.src.conn:add_cmd('HGETALL', { m.prefix })
-            local hret, hdata = m.src.conn:exec()
-
-            if hret and hdata and #hdata > 0 then
-              local hmset_args = { m.prefix }
-              for _, v in ipairs(hdata) do
-                hmset_args[#hmset_args + 1] = v
-              end
-              m.dst.conn:add_cmd('HMSET', hmset_args)
-              m.dst.conn:exec()
-              hmset_args = nil
-            end
-            hdata = nil
-
-            -- Migrate token keys in pipelined batches
-            local tok_count, had_error = migrate_prefix_tokens(
-                m.src.conn, m.dst.conn, m.prefix, opts.batch_size, opts.no_delete)
-            stats.tokens = stats.tokens + tok_count
+          stats.migrated = stats.migrated + 1
 
-            if had_error then
-              stats.errors = stats.errors + 1
+          if not opts.dry_run then
+            local group_key = string.format('%s\0%s', m.src.name, m.dst.name)
+            local group = grouped[group_key]
+
+            if not group then
+              group = {
+                src = m.src,
+                dst = m.dst,
+                prefixes = {},
+              }
+              grouped[group_key] = group
             end
 
-            -- Update _keys sets
-            m.dst.conn:add_cmd('SADD', { sym_keys, m.prefix })
-            m.dst.conn:exec()
-            m.src.conn:add_cmd('SREM', { sym_keys, m.prefix })
-            m.src.conn:exec()
-
-            -- Delete source prefix hash
-            if not opts.no_delete then
-              m.src.conn:add_cmd('DEL', { m.prefix })
-              m.src.conn:exec()
-            end
+            group.prefixes[#group.prefixes + 1] = m.prefix
           end
-
-          stats.migrated = stats.migrated + 1
         end
 
-        -- Periodic GC to prevent memory bloat
         if pi % 100 == 0 then
           collectgarbage('collect')
         end
       end
 
+      if not opts.dry_run then
+        for _, group in pairs(grouped) do
+          rspamd_logger.messagex("  migrating %s prefix(es): %s -> %s",
+              #group.prefixes, group.src.name, group.dst.name)
+          local group_stats = migrate_prefix_group(group.prefixes,
+              group.src.conn, group.dst.conn, sym_keys, opts.batch_size, opts.no_delete)
+          stats.tokens = stats.tokens + group_stats.tokens
+          stats.errors = stats.errors + group_stats.errors
+        end
+      end
+
       misplaced = nil
       collectgarbage('collect')
     end