From: Vsevolod Stakhov Date: Mon, 12 Sep 2016 15:15:44 +0000 (+0100) Subject: [Feature] Fann scores now uses metadata from a message X-Git-Tag: 1.4.0~438 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f17e1822fac0c2a66d6aa3d4df4eda6cd5a2af91;p=thirdparty%2Frspamd.git [Feature] Fann scores now uses metadata from a message By introducing of extra data, it is now possible to train ANN with metadata of messages improving quality of filtering. --- diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua index 3e4ec2fc5a..b53fed5553 100644 --- a/src/plugins/lua/fann_scores.lua +++ b/src/plugins/lua/fann_scores.lua @@ -25,6 +25,7 @@ local fann_symbol_ham = 'FANN_HAM' require "fun" () local ucl = require "ucl" +local module_log_id = 0x100 -- Module vars -- ANNs indexed by settings id local data = { @@ -34,28 +35,225 @@ local data = { epoch = 0, } } + local fann_file local max_trains = 1000 local max_epoch = 100 local use_settings = false -local opts = rspamd_config:get_all_opt("fann_scores") -if not (opts and type(opts) == 'table') then - rspamd_logger.infox(rspamd_config, 'Module is unconfigured') - return + + +-- Metafunctions +local function fann_size_function(task) + local sizes = { + 100, + 200, + 500, + 1000, + 2000, + 4000, + 10000, + 20000, + 30000, + 100000, + 200000, + 400000, + 800000, + 1000000, + 2000000, + 8000000, + } + + local size = task:get_size() + for i = 1,#sizes do + if sizes[i] >= size then + return {i / #sizes} + end + end + + return {0} +end + +local function fann_images_function(task) + local images = task:get_images() + local ntotal = 0 + local njpg = 0 + local npng = 0 + local nlarge = 0 + local nsmall = 0 + + if images then + for _,img in ipairs(images) do + if img:get_type() == 'png' then + npng = npng + 1 + elseif img:get_type() == 'jpeg' then + njpg = njpg + 1 + end + + local w = img:get_width() + local h = img:get_height() + + if w > 0 and h > 0 then + if w + h > 256 then + nlarge = nlarge + 1 + else + nsmall = nsmall + 1 + end + end + + ntotal = ntotal + 1 + end + end + + return {ntotal,njpg,npng,nlarge,nsmall} +end + +local function fann_nparts_function(task) + local nattachments = 0 + local ntextparts = 0 + + local tp = task:get_text_parts() + if tp then + ntextparts = #tp + end + + local parts = task:get_parts() + + if parts then + for _,p in ipairs(parts) do + if p:get_filename() then + nattachments = nattachments + 1 + end + end + end + + return {ntextparts, nattachments} +end + +local function fann_encoding_function(task) + local nutf = 0 + local nother = 0 + + local tp = task:get_text_parts() + if tp then + for _,p in ipairs(tp) do + if p:is_utf() then + nutf = nutf + 1 + else + nother = nother + 1 + end + end + end + + return {nutf, nother} +end + +local function fann_recipients_function(task) + local nmime = 0 + local nsmtp = 0 + + if task:has_recipients('mime') then + nmime = #(task:get_recipients('mime')) + end + if task:has_recipients('smtp') then + nsmtp = #(task:get_recipients('smtp')) + end + + return {nmime,nsmtp} +end + +local function fann_received_function(task) + return {#(task:get_received_headers())} end -local function symbols_to_fann_vector(syms) +local function fann_urls_function(task) + if task:has_urls() then + return {#(task:get_urls())} + end + + return {0} +end + +local function fann_attachments_function(task) +end + +local metafunctions = { + { + cb = fann_size_function, + ninputs = 1, + }, + { + cb = fann_images_function, + ninputs = 5, + -- 1 - number of images, + -- 2 - number of png images, + -- 3 - number of jpeg images + -- 4 - number of large images (> 128 x 128) + -- 5 - number of small images (< 128 x 128) + }, + { + cb = fann_nparts_function, + ninputs = 2, + -- 1 - number of text parts + -- 2 - number of attachments + }, + { + cb = fann_encoding_function, + ninputs = 2, + -- 1 - number of utf parts + -- 2 - number of non-utf parts + }, + { + cb = fann_recipients_function, + ninputs = 2, + -- 1 - number of mime rcpt + -- 2 - number of smtp rcpt + }, + { + cb = fann_received_function, + ninputs = 1, + }, + { + cb = fann_urls_function, + ninputs = 1, + }, +} + +local function gen_metatokens(task) + local metatokens = {} + for _,mt in ipairs(metafunctions) do + local ct = mt.cb(task) + + for _,tok in ipairs(ct) do + table.insert(metatokens, tok) + end + end + + rspamd_logger.errx(task, "tokens: %s", metatokens) + + return metatokens +end + +local function count_metatokens() + local total = 0 + for _,mt in ipairs(metafunctions) do + total = total + mt.ninputs + end + + return total +end + +local function symbols_to_fann_vector(syms, scores) local learn_data = {} local matched_symbols = {} local n = rspamd_config:get_symbols_count() - each(function(s) - matched_symbols[s + 1] = 1 - end, syms) + each(function(s, score) + matched_symbols[s + 1] = score + end, zip(syms, scores)) for i=1,n do if matched_symbols[i] then - learn_data[i] = 1 + learn_data[i] = math.abs(matched_symbols[i]) else learn_data[i] = 0 end @@ -85,7 +283,7 @@ local function load_fann(id) rspamd_util.unlock_file(fd) -- closes fd if data[id].fann then - local n = rspamd_config:get_symbols_count() + local n = rspamd_config:get_symbols_count() + count_metatokens() if n ~= data[id].fann:get_inputs() then rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' .. @@ -115,7 +313,7 @@ end local function check_fann(id) if data[id].fann then - local n = rspamd_config:get_symbols_count() + local n = rspamd_config:get_symbols_count() + count_metatokens if n ~= data[id].fann:get_inputs() then rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' .. @@ -151,8 +349,13 @@ local function fann_scores_filter(task) check_fann(id) if data[id].fann then - local symbols = task:get_symbols_numeric() - local fann_data = symbols_to_fann_vector(symbols) + local symbols,scores = task:get_symbols_numeric() + local fann_data = symbols_to_fann_vector(symbols, scores) + local mt = gen_metatokens(task) + + for _,tok in ipairs(mt) do + table.insert(fann_data, tok) + end local out = data[id].fann:test(fann_data) local result = rspamd_util.tanh(2 * (out[1] - 0.5)) @@ -177,8 +380,8 @@ local function create_train_fann(n, id) data[id].epoch = 0 end -local function fann_train_callback(score, required_score,results, cf, id, opts) - local n = cf:get_symbols_count() +local function fann_train_callback(score, required_score, results, cf, id, opts, extra) + local n = cf:get_symbols_count() + count_metatokens() local fname = gen_fann_file(id) if not data[id].fann_train then @@ -240,8 +443,11 @@ local function fann_train_callback(score, required_score,results, cf, id, opts) if learn_spam or learn_ham then local learn_data = symbols_to_fann_vector( - map(function(r) return r[1] end, results) + map(function(r) return r[1] end, results), + map(function(r) return r[2] end, results) ) + -- Add filtered meta tokens + each(function(e) table.insert(learn_data, e) end, extra) if learn_spam then data[id].fann_train:train(learn_data, {1.0}) @@ -253,6 +459,14 @@ local function fann_train_callback(score, required_score,results, cf, id, opts) end end +-- Initialization part + +local opts = rspamd_config:get_all_opt("fann_scores") +if not (opts and type(opts) == 'table') then + rspamd_logger.infox(rspamd_config, 'Module is unconfigured') + return +end + if not rspamd_fann.is_enabled() then rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' .. 'module is eventually disabled') @@ -294,15 +508,26 @@ else max_epoch = opts['train']['max_epoch'] end cfg:register_worker_script("log_helper", - function(score, req_score, results, cf, id) + function(score, req_score, results, cf, id, extra) + -- map (snd x) (filter (fst x == module_id) extra) + local extra_fann = map(function(e) return e[2] end, + filter(function(e) return e[1] == module_log_id end, extra)) if use_settings then fann_train_callback(score, req_score, results, cf, - tostring(id), opts['train']) + tostring(id), opts['train'], extra_fann) else - fann_train_callback(score, req_score, results, cf, '0', opts['train']) + fann_train_callback(score, req_score, results, cf, '0', + opts['train'], extra_fann) end end) end) + rspamd_plugins["fann_score"] = { + log_callback = function(task) + return totable(map( + function(tok) return {module_log_id, tok} end, + gen_metatokens(task))) + end + } end end end