From: Vsevolod Stakhov Date: Sun, 7 Jul 2019 18:45:08 +0000 (+0100) Subject: [Minor] Neural: Further fixes X-Git-Tag: 2.0~639 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a1af120934a908292544b7848b3e62da0e8b9030;p=thirdparty%2Frspamd.git [Minor] Neural: Further fixes --- diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index d2c0191e7a..b0f307803f 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -280,7 +280,6 @@ local function new_ann_profile(task, rule, set, version) end lua_redis.redis_make_request(task, - rspamd_config, rule.redis, nil, true, -- is write @@ -347,21 +346,28 @@ local function create_ann(n, nlayers) end -local function ann_train_callback(rule, task, score, required_score, set) +local function ann_push_task_result(rule, task, verdict, score, set) local train_opts = rule.train + local learn_spam, learn_ham if train_opts.autotrain then + if verdict == 'passthrough' or verdict == 'uncertain' then + lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)', + verdict, score) + end + if train_opts['spam_score'] then learn_spam = score >= train_opts['spam_score'] else - learn_spam = score >= required_score + learn_spam = verdict == 'spam' or verdict == 'junk' end + if train_opts['ham_score'] then learn_ham = score <= train_opts['ham_score'] else - learn_ham = score < 0 + learn_ham = verdict == 'ham' end else -- Train by request header @@ -408,7 +414,7 @@ local function ann_train_callback(rule, task, score, required_score, set) true, -- is write learn_vec_cb, --callback 'LPUSH', -- command - { set.ann.redis_prefix .. '_' .. learn_type, str} -- arguments + { set.ann.redis_key .. '_' .. learn_type, str} -- arguments ) else if err then @@ -948,6 +954,7 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback) rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s', err) elseif type(data) == 'table' then + lua_util.debugm(N, cfg, 'process element %s:%s', rule.prefix, set.name) process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data)) end end @@ -1000,12 +1007,12 @@ end local function ann_push_vector(task) if task:has_flag('skip') then return end if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end - local scores = task:get_metric_score() + local verdict,score = lua_util.get_task_verdict(task) for _,rule in pairs(settings.rules) do local sid = task:get_settings_id() or -1 if rule.settings[sid] then - ann_train_callback(rule, task, scores[1], scores[2], rule.settings[sid]) + ann_push_task_result(rule, task, verdict, score, rule.settings[sid]) end end @@ -1124,6 +1131,10 @@ local id = rspamd_config:register_symbol({ callback = ann_scores_filter }) +settings = lua_util.override_defaults(settings, module_config) +settings.rules = {} -- Reset unless validated further in the cycle + +-- Check all rules for k,r in pairs(rules) do local rule_elt = lua_util.override_defaults(default_options, r) rule_elt['redis'] = redis_params