]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
fix after master merge 5855/head
authorDmitriy Alekseev <1865999+dragoangel@users.noreply.github.com>
Thu, 22 Jan 2026 21:09:47 +0000 (22:09 +0100)
committerGitHub <noreply@github.com>
Thu, 22 Jan 2026 21:09:47 +0000 (22:09 +0100)
src/plugins/lua/neural.lua

index c585d052da76d9057c7afe925fe0f3d521c0f06e..49aab5af7038d7f170462b233de0b834e3d32233 100644 (file)
@@ -852,65 +852,6 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
     local ann_key = sel_elt.redis_key
     local pending_key = neural_common.pending_train_key(rule, set)
 
-    lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
-      ann_key)
-
-    -- Create continuation closure
-    local redis_len_cb_gen = function(cont_cb, what, is_final)
-      return function(err, data)
-        if err then
-          rspamd_logger.errx(rspamd_config,
-            'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
-        elseif data and type(data) == 'number' or type(data) == 'string' then
-          local ntrains = tonumber(data) or 0
-          lens[what] = ntrains
-          if is_final then
-            -- Ensure that we have the following:
-            -- one class has reached max_trains
-            -- other class(es) are at least as full as classes_bias
-            -- e.g. if classes_bias = 0.25 and we have 10 max_trains then
-            -- one class must have 10 or more trains whilst another should have
-            -- at least (10 * (1 - 0.25)) = 8 trains
-
-            local max_len = math.max(lua_util.unpack(lua_util.values(lens)))
-            local min_len = math.min(lua_util.unpack(lua_util.values(lens)))
-
-            if rule.train.learn_mode == 'balanced' then
-              local len_bias_check_pred = function(_, l)
-                return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
-              end
-              if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
-                lua_util.debugm(N, rspamd_config,
-                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                  ann_key, lens, rule.train.max_trains, what)
-                cont_cb()
-              else
-                lua_util.debugm(N, rspamd_config,
-                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                  ann_key, what, lens, rule.train.max_trains)
-              end
-            else
-              -- Probabilistic mode, just ensure that at least one vector is okay
-              if min_len > 0 and max_len >= rule.train.max_trains then
-                lua_util.debugm(N, rspamd_config,
-                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                  ann_key, lens, rule.train.max_trains, what)
-                cont_cb()
-              else
-                lua_util.debugm(N, rspamd_config,
-                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                  ann_key, what, lens, rule.train.max_trains)
-              end
-            end
-          else
-            lua_util.debugm(N, rspamd_config,
-              'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
-              what, ann_key, ntrains, rule.train.max_trains)
-            cont_cb()
-          end
-        end
-      end
-    end
     lua_util.debugm(N, rspamd_config, "check if ANN %s (pending %s) needs to be trained",
       ann_key, pending_key)
 
@@ -932,7 +873,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
         'final vector count for ANN %s: spam=%s ham=%s (min=%s max=%s required=%s)',
         ann_key, lens.spam, lens.ham, min_len, max_len, rule.train.max_trains)
 
-      if rule.train.learn_type == 'balanced' then
+      if rule.train.learn_mode == 'balanced' then
         local len_bias_check_pred = function(_, l)
           return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
         end