if sel_elt then
-- The ANN and its train vectors are available: check whether we can learn it.
local ann_key = sel_elt.redis_key
local pending_key = neural_common.pending_train_key(rule, set)
-- NOTE(review): a more detailed variant of this trace (including the pending
-- key) is logged again further below; this one looks redundant - confirm.
lua_util.debugm(N, rspamd_config,
    "check if ANN %s needs to be trained", ann_key)
+
-- Create continuation closure
--
-- Builds the callback that receives the number of stored train vectors for
-- one class from Redis and decides whether the check chain may continue.
--
--   cont_cb  - continuation, called with no arguments when the chain may go on
--   what     - class name; used as the key into `lens` and in log messages
--   is_final - truthy when this is the last class to be checked; only then
--              the learn decision is taken
--
-- Returns function(err, data). The returned callback stores the count in
-- `lens[what]` (`lens` is defined outside of this block) and returns nothing.
local redis_len_cb_gen = function(cont_cb, what, is_final)
  return function(err, data)
    if err then
      rspamd_logger.errx(rspamd_config,
          'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
      return
    end

    -- Only a number or a numeric string is a usable reply. The check is
    -- spelled out explicitly because `a and b or c` parses as `(a and b) or c`.
    local data_type = type(data)
    if data_type ~= 'number' and data_type ~= 'string' then
      -- Do not stop the chain silently: report what we have got instead
      rspamd_logger.errx(rspamd_config,
          'cannot get ANN %s trains %s from redis: unexpected reply type %s',
          what, ann_key, data_type)
      return
    end

    local ntrains = tonumber(data) or 0
    lens[what] = ntrains

    if not is_final then
      lua_util.debugm(N, rspamd_config,
          'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
          what, ann_key, ntrains, rule.train.max_trains)
      cont_cb()
      return
    end

    -- Ensure that we have the following:
    -- one class has reached max_trains
    -- other class(es) are at least as full as classes_bias
    -- e.g. if classes_bias = 0.25 and we have 10 max_trains then
    -- one class must have 10 or more trains whilst another should have
    -- at least 10 * (1 - 0.25) = 7.5, i.e. 8 or more trains as counts are
    -- whole numbers
    local class_lens = lua_util.values(lens)
    local max_len = math.max(lua_util.unpack(class_lens))
    local min_len = math.min(lua_util.unpack(class_lens))
    local max_trains = rule.train.max_trains

    local can_learn
    if rule.train.learn_mode == 'balanced' then
      local min_required = max_trains * (1.0 - rule.train.classes_bias)
      local len_bias_check_pred = function(_, l)
        return l >= min_required
      end
      can_learn = max_len >= max_trains and fun.all(len_bias_check_pred, lens)
    else
      -- Probabilistic mode, just ensure that at least one vector is okay
      can_learn = min_len > 0 and max_len >= max_trains
    end

    if can_learn then
      lua_util.debugm(N, rspamd_config,
          'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
          ann_key, lens, max_trains, what)
      cont_cb()
    else
      lua_util.debugm(N, rspamd_config,
          'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
          ann_key, what, lens, max_trains)
    end
  end
end
-- Trace which ANN key, together with its pending key, is going to be checked
lua_util.debugm(N, rspamd_config,
    "check if ANN %s (pending %s) needs to be trained", ann_key, pending_key)
local function initiate_train()
rspamd_logger.infox(rspamd_config,