local ann_key = sel_elt.redis_key
local pending_key = neural_common.pending_train_key(rule, set)
- lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
- ann_key)
-
- -- Build a reply callback that stores the train-vector count for class `what` in `lens` and then decides whether to invoke `cont_cb`
- local redis_len_cb_gen = function(cont_cb, what, is_final) -- is_final enables the gating check below; otherwise cont_cb is always called
- return function(err, data) -- on err the failure is logged and cont_cb is never invoked
- if err then
- rspamd_logger.errx(rspamd_config,
- 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
- elseif data and type(data) == 'number' or type(data) == 'string' then -- parses as (data and number) or string; any other reply type is dropped without logging
- local ntrains = tonumber(data) or 0 -- a string that is not numeric counts as 0
- lens[what] = ntrains
- if is_final then
- -- Ensure that we have the following:
- -- one class has reached max_trains
- -- other class(es) are at least as full as classes_bias
- -- e.g. if classes_bias = 0.25 and we have 10 max_trains then
- -- one class must have 10 or more trains whilst another should have
- -- at least (10 * (1 - 0.25)) = 7.5, i.e. 8 whole trains
-
- local max_len = math.max(lua_util.unpack(lua_util.values(lens))) -- largest count recorded in lens so far
- local min_len = math.min(lua_util.unpack(lua_util.values(lens))) -- smallest count recorded in lens so far
-
- if rule.train.learn_mode == 'balanced' then
- local len_bias_check_pred = function(_, l) -- presumably called by fun.all with (key, count) pairs; only the count is tested
- return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
- end
- if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
- lua_util.debugm(N, rspamd_config,
- 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
- ann_key, lens, rule.train.max_trains, what)
- cont_cb()
- else -- check failed: only a debug message is emitted, cont_cb is not called
- lua_util.debugm(N, rspamd_config,
- 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
- ann_key, what, lens, rule.train.max_trains)
- end
- else
- -- Any other learn_mode (probabilistic): every recorded class needs at least one vector and one class must have reached max_trains
- if min_len > 0 and max_len >= rule.train.max_trains then
- lua_util.debugm(N, rspamd_config,
- 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
- ann_key, lens, rule.train.max_trains, what)
- cont_cb()
- else -- check failed: only a debug message is emitted, cont_cb is not called
- lua_util.debugm(N, rspamd_config,
- 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
- ann_key, what, lens, rule.train.max_trains)
- end
- end
- else
- lua_util.debugm(N, rspamd_config,
- 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
- what, ann_key, ntrains, rule.train.max_trains)
- cont_cb() -- not the final check: continue without applying any threshold
- end
- end
- end
- end
lua_util.debugm(N, rspamd_config, "check if ANN %s (pending %s) needs to be trained",
ann_key, pending_key)
'final vector count for ANN %s: spam=%s ham=%s (min=%s max=%s required=%s)',
ann_key, lens.spam, lens.ham, min_len, max_len, rule.train.max_trains)
- if rule.train.learn_type == 'balanced' then
+ if rule.train.learn_mode == 'balanced' then
local len_bias_check_pred = function(_, l)
return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
end