]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Allow to specify number of threads for ANN learning
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 17 Sep 2017 09:04:59 +0000 (10:04 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 17 Sep 2017 09:04:59 +0000 (10:04 +0100)
lualib/lua_nn.lua
src/plugins/lua/fann_redis.lua

index 0e8977d3749cd47662539ccff77d68a8b182a9f7..d0d2d5265d8f6bd0ece43b40faa40de61d7cbe0c 100644 (file)
@@ -22,6 +22,7 @@ local lua_nn_models = {}
 
 if rspamd_config:has_torch() then
   torch = require "torch"
+  torch.setnumthreads(1)
 end
 
 if torch then
index 2751b5d79198e5ea3b68ec116e11f582f3b63c01..ab3da003505e4e7727d04566779d2f511d235b4c 100644 (file)
@@ -47,6 +47,7 @@ local default_options = {
     mse = 0.001,
     autotrain = true,
     train_prob = 1.0,
+    learn_threads = 1,
   },
   use_settings = false,
   per_user = false,
@@ -781,6 +782,9 @@ local function train_fann(rule, _, ev_base, elt, worker)
           dataset.size = function() return #dataset end
 
           local function train_torch()
+            if rule.train.learn_threads > 1 then
+              torch.setnumthreads(rule.train.learn_threads)
+            end
             local criterion = nn.MSECriterion()
             local trainer = nn.StochasticGradient(fanns[elt].fann_train,
               criterion)