From: Vsevolod Stakhov Date: Mon, 1 Jul 2019 12:50:21 +0000 (+0100) Subject: [Test] Add unit test for kann X-Git-Tag: 2.0~686 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f8004be4c94ca214cd399cdb18aa0d085abf0fd7;p=thirdparty%2Frspamd.git [Test] Add unit test for kann --- diff --git a/test/lua/unit/kann.lua b/test/lua/unit/kann.lua new file mode 100644 index 0000000000..bb69302035 --- /dev/null +++ b/test/lua/unit/kann.lua @@ -0,0 +1,43 @@ +-- Simple kann test (xor function vs 2 layer MLP) + +context("Kann test", function() + local kann = require "rspamd_kann" + local k + local inputs = { + {0, 0}, + {0, 1}, + {1, 0}, + {1, 1} + } + + local outputs = { + {0}, + {1}, + {1}, + {0} + } + + local t = kann.layer.input(2) + t = kann.transform.relu(t) + t = kann.transform.tanh(kann.layer.dense(t, 2)); + t = kann.layer.cost(t, 1, kann.cost.mse) + k = kann.new.kann(t) + + local iters = 500 + local niter = k:train1(inputs, outputs, { + lr = 0.01, + max_epoch = iters, + mini_size = 80, + }) + + for i,inp in ipairs(inputs) do + test(string.format("Check XOR MLP %s ^ %s == %s", inp[1], inp[2], outputs[i][1]), + function() + local res = math.floor(k:apply1(inp)[1] + 0.5) + assert_equal(outputs[i][1], res, + tostring(outputs[i][1]) .. " but test returned " .. tostring(res)) + end) + end + + +end) \ No newline at end of file