int fl = 0; \
if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \
else if (lua_type(L, (pos)) == LUA_TNUMBER) { fl = lua_tointeger (L, (pos)); } \
- (n)->ext_flag = fl; \
+ (n)->ext_flag |= fl; \
}while(0)
/***
return 1;
}
+struct rspamd_kann_train_cbdata {
+ lua_State *L;
+ kann_t *k;
+ gint cbref;
+};
+
+static void
+lua_kann_train_cb (int iter, float train_cost, float val_cost, void *ud)
+{
+ struct rspamd_kann_train_cbdata *cbd = (struct rspamd_kann_train_cbdata *)ud;
+
+ if (cbd->cbref != -1) {
+ gint err_idx;
+ lua_State *L = cbd->L;
+
+ lua_pushcfunction (L, &rspamd_lua_traceback);
+ err_idx = lua_gettop (L);
+
+ lua_rawgeti (L, LUA_REGISTRYINDEX, cbd->cbref);
+ lua_pushinteger (L, iter);
+ lua_pushnumber (L, train_cost);
+ lua_pushnumber (L, val_cost);
+
+ if (lua_pcall (L, 3, 0, err_idx) != 0) {
+ msg_err ("cannot run lua train callback: %s",
+ lua_tostring (L, -1));
+ }
+
+ lua_settop (L, err_idx - 1);
+ }
+}
+
+#define FREE_VEC(a, n) do { for(int i = 0; i < (n); i ++) g_free((a)[i]); g_free(a); } while(0)
+
static int
lua_kann_train1 (lua_State *L)
{
kann_t *k = lua_check_kann (L, 1);
- g_assert_not_reached (); /* TODO: implement */
+ /* Default train params */
+ double lr = 0.001;
+ gint64 mini_size = 64;
+ gint64 max_epoch = 25;
+ gint64 max_drop_streak = 10;
+ double frac_val = 0.1;
+ gint cbref = -1;
+
+ if (k && lua_istable (L, 2) && lua_istable (L, 3)) {
+ int n = rspamd_lua_table_size (L, 2);
+ int n_in = kann_dim_in (k);
+ int n_out = kann_dim_out (k);
+
+ if (n_in <= 0) {
+ return luaL_error (L, "invalid inputs count: %d", n_in);
+ }
+
+ if (n_out <= 0) {
+ return luaL_error (L, "invalid outputs count: %d", n_in);
+ }
+
+ if (n != rspamd_lua_table_size (L, 3) || n == 0) {
+ return luaL_error (L, "invalid dimensions: outputs size must be "
+ "equal to inputs and non zero");
+ }
+
+ if (lua_istable (L, 4)) {
+ GError *err = NULL;
+
+ if (!rspamd_lua_parse_table_arguments (L, 4, &err,
+ RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
+ "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F",
+ &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref)) {
+ n = luaL_error (L, "invalid params: %s",
+ err ? err->message : "unknown error");
+ g_error_free (err);
+
+ return n;
+ }
+ }
+
+ float **x, **y;
+
+ /* Fill vectors */
+ x = (float **)g_malloc (sizeof (float *) * n);
+ y = (float **)g_malloc (sizeof (float *) * n);
+
+ for (int s = 0; s < n; s ++) {
+ /* Inputs */
+ lua_rawgeti (L, 2, s + 1);
+ x[s] = (float *)g_malloc (sizeof (float) * n_in);
+
+ if (rspamd_lua_table_size (L, -1) != n_in) {
+ FREE_VEC (x, n);
+ FREE_VEC (y, n);
+
+ n = luaL_error (L, "invalid params at pos %d: "
+ "bad input dimension %d; %d expected",
+ s + 1,
+ (int)rspamd_lua_table_size (L, -1),
+ n_in);
+
+ return n;
+ }
+
+ for (int i = 0; i < n_in; i ++) {
+ lua_rawgeti (L, -1, i + 1);
+ x[s][i] = lua_tonumber (L, -1);
+
+ lua_pop (L, 1);
+ }
+
+ lua_pop (L, 1);
+
+ /* Outputs */
+ y[s] = (float *)g_malloc (sizeof (float) * n_out);
+ lua_rawgeti (L, 3, s + 1);
+
+ if (rspamd_lua_table_size (L, -1) != n_out) {
+ FREE_VEC (x, n);
+ FREE_VEC (y, n);
+
+ n = luaL_error (L, "invalid params at pos %d: "
+ "bad output dimension %d; "
+ "%d expected",
+ s + 1,
+ (int)rspamd_lua_table_size (L, -1),
+ n_out);
+
+ return n;
+ }
+
+ for (int i = 0; i < n_out; i ++) {
+ lua_rawgeti (L, -1, i + 1);
+ y[s][i] = lua_tonumber (L, -1);
+
+ lua_pop (L, 1);
+ }
+
+ lua_pop (L, 1);
+ }
+
+ struct rspamd_kann_train_cbdata cbd;
+
+ cbd.cbref = cbref;
+ cbd.k = k;
+ cbd.L = L;
+
+ int niters = kann_train_fnn1 (k, lr,
+ mini_size, max_epoch, max_drop_streak,
+ frac_val, n, x, y, lua_kann_train_cb, &cbd);
+
+ lua_pushinteger (L, niters);
+
+ FREE_VEC (x, n);
+ FREE_VEC (y, n);
+ }
+ else {
+ return luaL_error (L, "invalid arguments: kann, inputs, outputs and"
+ " optional params are expected");
+ }
+
+ return 1;
}
static int
gsize vec_len = rspamd_lua_table_size (L, 2);
float *vec = (float *)g_malloc (sizeof (float) * vec_len);
int i_out;
+ int n_in = kann_dim_in (k);
+
+ if (n_in <= 0) {
+ return luaL_error (L, "invalid inputs count: %d", n_in);
+ }
+
+ if (n_in != vec_len) {
+ return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+ (int)vec_len, n_in);
+ }
for (gsize i = 0; i < vec_len; i ++) {
lua_rawgeti (L, 2, i + 1);