git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Add training support to kann
author Vsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 1 Jul 2019 12:30:09 +0000 (13:30 +0100)
committer Vsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 1 Jul 2019 12:30:09 +0000 (13:30 +0100)
contrib/kann/kann.c
src/lua/lua_kann.c

index 3fbf139cce17871f3e238ca38804e20938bd1a71..43227bdc66ea4829fceef8d1fee79e87c824a517 100644 (file)
@@ -670,7 +670,8 @@ kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len) { return
 kad_node_t *kann_layer_input(int n1) /* returns a kad_feed node flagged with KANN_F_IN */
 {
        kad_node_t *t;
-       t = kad_feed(2, 1, n1), t->ext_flag |= KANN_F_IN;
+       t = kad_feed(2, 1, n1); /* NOTE(review): presumably a 2-dimensional 1 x n1 feed; confirm against kad_feed */
+       t->ext_flag |= KANN_F_IN; /* OR the input flag in, keeping any flags already set on the node */
        return t;
 }
 
@@ -761,6 +762,7 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type)
        assert(cost_type == KANN_C_CEB || cost_type == KANN_C_CEM || cost_type == KANN_C_CEB_NEG || cost_type == KANN_C_MSE);
        t = kann_layer_dense(t, n_out);
        truth = kad_feed(2, 1, n_out), truth->ext_flag |= KANN_F_TRUTH;
+
        if (cost_type == KANN_C_MSE) {
                cost = kad_mse(t, truth);
        } else if (cost_type == KANN_C_CEB) {
@@ -773,7 +775,13 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type)
                t = kad_softmax(t);
                cost = kad_ce_multi(t, truth);
        }
-       t->ext_flag |= KANN_F_OUT, cost->ext_flag |= KANN_F_COST;
+       else {
+               assert (0);
+       }
+
+       t->ext_flag |= KANN_F_OUT;
+       cost->ext_flag |= KANN_F_COST;
+
        return cost;
 }
 
index a1b31014d60665f969777f5e2dd8e940803380f8..609f05539bc314c83b7f8274aba134216cb2e232 100644 (file)
@@ -295,7 +295,7 @@ void luaopen_kann (lua_State *L)
        int fl = 0; \
        if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \
        else if (lua_type(L, (pos)) == LUA_TNUMBER) { fl = lua_tointeger (L, (pos)); } \
-       (n)->ext_flag = fl; \
+       (n)->ext_flag |= fl; \
 }while(0)
 
 /***
@@ -984,12 +984,168 @@ lua_kann_load (lua_State *L)
        return 1;
 }
 
+struct rspamd_kann_train_cbdata { /* user data that lua_kann_train_cb receives through its void *ud argument */
+       lua_State *L; /* Lua state in which the callback function is called */
+       kann_t *k; /* network passed to the training call; not read by lua_kann_train_cb */
+       gint cbref; /* LUA_REGISTRYINDEX reference of the Lua callback; -1 means no callback */
+};
+
+/*
+ * Training progress hook: forwards (iter, train_cost, val_cost) to the Lua
+ * function stored in the registry under cbref. Does nothing when cbref is -1.
+ * Errors raised by the Lua function are logged and not propagated.
+ */
+static void
+lua_kann_train_cb (int iter, float train_cost, float val_cost, void *ud)
+{
+       struct rspamd_kann_train_cbdata *data = (struct rspamd_kann_train_cbdata *)ud;
+       lua_State *L;
+       gint tb_pos;
+
+       if (data->cbref == -1) {
+               /* No Lua callback has been supplied */
+               return;
+       }
+
+       L = data->L;
+
+       /* The traceback function is pushed first so its slot can serve as the pcall handler */
+       lua_pushcfunction (L, &rspamd_lua_traceback);
+       tb_pos = lua_gettop (L);
+
+       /* Callback followed by its three arguments */
+       lua_rawgeti (L, LUA_REGISTRYINDEX, data->cbref);
+       lua_pushinteger (L, iter);
+       lua_pushnumber (L, train_cost);
+       lua_pushnumber (L, val_cost);
+
+       if (lua_pcall (L, 3, 0, tb_pos) != 0) {
+               msg_err ("cannot run lua train callback: %s",
+                               lua_tostring (L, -1));
+       }
+
+       /* Restore the stack to its height before the traceback function was pushed */
+       lua_settop (L, tb_pos - 1);
+}
+
+/* Frees the first n elements of the pointer array a, then the array itself */
+#define FREE_VEC(a, n) do { \
+       for (int i = 0; i < (n); i ++) { \
+               g_free ((a)[i]); \
+       } \
+       g_free (a); \
+} while(0)
+
 static int
 lua_kann_train1 (lua_State *L)
 {
        kann_t *k = lua_check_kann (L, 1);
 
-       g_assert_not_reached (); /* TODO: implement */
+       /* Default train params */
+       double lr = 0.001;
+       gint64 mini_size = 64;
+       gint64 max_epoch = 25;
+       gint64 max_drop_streak = 10;
+       double frac_val = 0.1;
+       gint cbref = -1;
+
+       if (k && lua_istable (L, 2) && lua_istable (L, 3)) {
+               int n = rspamd_lua_table_size (L, 2);
+               int n_in = kann_dim_in (k);
+               int n_out = kann_dim_out (k);
+
+               if (n_in <= 0) {
+                       return luaL_error (L, "invalid inputs count: %d", n_in);
+               }
+
+               if (n_out <= 0) {
+                       return luaL_error (L, "invalid outputs count: %d", n_in);
+               }
+
+               if (n != rspamd_lua_table_size (L, 3) || n == 0) {
+                       return luaL_error (L, "invalid dimensions: outputs size must be "
+                                                "equal to inputs and non zero");
+               }
+
+               if (lua_istable (L, 4)) {
+                       GError *err = NULL;
+
+                       if (!rspamd_lua_parse_table_arguments (L, 4, &err,
+                                       RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
+                                       "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F",
+                                       &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref)) {
+                               n = luaL_error (L, "invalid params: %s",
+                                               err ? err->message : "unknown error");
+                               g_error_free (err);
+
+                               return n;
+                       }
+               }
+
+               float **x, **y;
+
+               /* Fill vectors */
+               x = (float **)g_malloc (sizeof (float *) * n);
+               y = (float **)g_malloc (sizeof (float *) * n);
+
+               for (int s = 0; s < n; s ++) {
+                       /* Inputs */
+                       lua_rawgeti (L, 2, s + 1);
+                       x[s] = (float *)g_malloc (sizeof (float) * n_in);
+
+                       if (rspamd_lua_table_size (L, -1) != n_in) {
+                               FREE_VEC (x, n);
+                               FREE_VEC (y, n);
+
+                               n = luaL_error (L, "invalid params at pos %d: "
+                                          "bad input dimension %d; %d expected",
+                                               s + 1,
+                                               (int)rspamd_lua_table_size (L, -1),
+                                               n_in);
+
+                               return n;
+                       }
+
+                       for (int i = 0; i < n_in; i ++) {
+                               lua_rawgeti (L, -1, i + 1);
+                               x[s][i] = lua_tonumber (L, -1);
+
+                               lua_pop (L, 1);
+                       }
+
+                       lua_pop (L, 1);
+
+                       /* Outputs */
+                       y[s] = (float *)g_malloc (sizeof (float) * n_out);
+                       lua_rawgeti (L, 3, s + 1);
+
+                       if (rspamd_lua_table_size (L, -1) != n_out) {
+                               FREE_VEC (x, n);
+                               FREE_VEC (y, n);
+
+                               n = luaL_error (L, "invalid params at pos %d: "
+                                          "bad output dimension %d; "
+                                          "%d expected",
+                                               s + 1,
+                                               (int)rspamd_lua_table_size (L, -1),
+                                               n_out);
+
+                               return n;
+                       }
+
+                       for (int i = 0; i < n_out; i ++) {
+                               lua_rawgeti (L, -1, i + 1);
+                               y[s][i] = lua_tonumber (L, -1);
+
+                               lua_pop (L, 1);
+                       }
+
+                       lua_pop (L, 1);
+               }
+
+               struct rspamd_kann_train_cbdata cbd;
+
+               cbd.cbref = cbref;
+               cbd.k = k;
+               cbd.L = L;
+
+               int niters = kann_train_fnn1 (k, lr,
+                               mini_size, max_epoch, max_drop_streak,
+                               frac_val, n, x, y, lua_kann_train_cb, &cbd);
+
+               lua_pushinteger (L, niters);
+
+               FREE_VEC (x, n);
+               FREE_VEC (y, n);
+       }
+       else {
+               return luaL_error (L, "invalid arguments: kann, inputs, outputs and"
+                                                         " optional params are expected");
+       }
+
+       return 1;
 }
 
 static int
@@ -1001,6 +1157,16 @@ lua_kann_apply1 (lua_State *L)
                gsize vec_len = rspamd_lua_table_size (L, 2);
                float *vec = (float *)g_malloc (sizeof (float) * vec_len);
                int i_out;
+               int n_in = kann_dim_in (k);
+
+               if (n_in <= 0) {
+                       return luaL_error (L, "invalid inputs count: %d", n_in);
+               }
+
+               if (n_in != vec_len) {
+                       return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+                                       (int)vec_len, n_in);
+               }
 
                for (gsize i = 0; i < vec_len; i ++) {
                        lua_rawgeti (L, 2, i + 1);