]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
Merge branch 'master' into patch-16
authorDmitriy Alekseev <1865999+dragoangel@users.noreply.github.com>
Thu, 22 Jan 2026 21:02:27 +0000 (22:02 +0100)
committerGitHub <noreply@github.com>
Thu, 22 Jan 2026 21:02:27 +0000 (22:02 +0100)
1  2 
src/plugins/lua/neural.lua

index 48696145115303faa77ec5e9597a7b90fc5c6cc0,ad0ef94ce68c2c89bff7cbdafc1b4fe870e6148a..c585d052da76d9057c7afe925fe0f3d521c0f06e
@@@ -764,66 -850,10 +850,69 @@@ local function maybe_train_existing_ann
    if sel_elt then
      -- We have our ANN and that's train vectors, check if we can learn
      local ann_key = sel_elt.redis_key
+     local pending_key = neural_common.pending_train_key(rule, set)
  
 +    lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
 +      ann_key)
 +
 +    -- Create continuation closure
 +    local redis_len_cb_gen = function(cont_cb, what, is_final)
 +      return function(err, data)
 +        if err then
 +          rspamd_logger.errx(rspamd_config,
 +            'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
 +        elseif data and type(data) == 'number' or type(data) == 'string' then
 +          local ntrains = tonumber(data) or 0
 +          lens[what] = ntrains
 +          if is_final then
 +            -- Ensure that we have the following:
 +            -- one class has reached max_trains
 +            -- other class(es) are at least as full as classes_bias
 +            -- e.g. if classes_bias = 0.25 and we have 10 max_trains then
 +            -- one class must have 10 or more trains whilst another should have
 +            -- at least (10 * (1 - 0.25)) = 8 trains
 +
 +            local max_len = math.max(lua_util.unpack(lua_util.values(lens)))
 +            local min_len = math.min(lua_util.unpack(lua_util.values(lens)))
 +
 +            if rule.train.learn_mode == 'balanced' then
 +              local len_bias_check_pred = function(_, l)
 +                return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
 +              end
 +              if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
 +                lua_util.debugm(N, rspamd_config,
 +                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
 +                  ann_key, lens, rule.train.max_trains, what)
 +                cont_cb()
 +              else
 +                lua_util.debugm(N, rspamd_config,
 +                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
 +                  ann_key, what, lens, rule.train.max_trains)
 +              end
 +            else
 +              -- Probabilistic mode, just ensure that at least one vector is okay
 +              if min_len > 0 and max_len >= rule.train.max_trains then
 +                lua_util.debugm(N, rspamd_config,
 +                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
 +                  ann_key, lens, rule.train.max_trains, what)
 +                cont_cb()
 +              else
 +                lua_util.debugm(N, rspamd_config,
 +                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
 +                  ann_key, what, lens, rule.train.max_trains)
 +              end
 +            end
 +          else
 +            lua_util.debugm(N, rspamd_config,
 +              'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
 +              what, ann_key, ntrains, rule.train.max_trains)
 +            cont_cb()
 +          end
 +        end
 +      end
 +    end
+     lua_util.debugm(N, rspamd_config, "check if ANN %s (pending %s) needs to be trained",
+       ann_key, pending_key)
  
      local function initiate_train()
        rspamd_logger.infox(rspamd_config,