]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Implement l1/l2 regularization against the current weights
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 9 Mar 2018 17:05:03 +0000 (17:05 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 9 Mar 2018 17:05:49 +0000 (17:05 +0000)
lualib/rspamadm/rescore.lua

index fb1428694828d98d2bf75a7543afcbbd4311afd8..c8348caa3ff8dadeec3e08d8ac6b021d7895f08a 100644 (file)
@@ -210,7 +210,7 @@ end
 
 -- training function
 local function train(dataset, opt, model, criterion, epoch,
-                     all_symbols, spam_threshold)
+                     all_symbols, spam_threshold, initial_weights)
   -- epoch tracker
   epoch = epoch or 1
 
@@ -270,16 +270,18 @@ local function train(dataset, opt, model, criterion, epoch,
       -- penalties (L1 and L2):
       local l1 = tonumber(opt.l1) or 0
       local l2 = tonumber(opt.l2) or 0
+
       if l1 ~= 0 or l2 ~= 0 then
         -- locals:
         local norm,sign= torch.norm,torch.sign
 
+        local diff = parameters - initial_weights
         -- Loss:
-        f = f + l1 * norm(parameters,1)
-        f = f + l2 * norm(parameters,2)^2/2
+        f = f + l1 * norm(diff,1)
+        f = f + l2 * norm(diff,2)^2/2
 
         -- Gradients:
-        gradParameters:add( sign(parameters):mul(l1) + parameters:clone():mul(l2) )
+        gradParameters:add( sign(diff):mul(l1) + diff:clone():mul(l2) )
       end
 
       -- update confusion
@@ -492,10 +494,12 @@ return function (args, cfg)
   for _,lr in ipairs(learning_rates) do
     for _,wd in ipairs(penalty_weights) do
       linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
+      local initial_weights = linear_module.weight[1]:clone()
       opts.learning_rate = lr
       opts.weight_decay = wd
       for i=1,tonumber(opts.iters) do
-        train(dataset, opts, perceptron, criterion, i, all_symbols, threshold)
+        train(dataset, opts, perceptron, criterion, i, all_symbols, threshold,
+            initial_weights)
       end
 
       local fscore = calculate_fscore_from_weights(cv_logs,