]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
Allow learning from lua_task.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 14 Aug 2014 12:18:31 +0000 (13:18 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 14 Aug 2014 12:18:31 +0000 (13:18 +0100)
src/lua/lua_task.c

index 1b62ee83fdc404891cfafaecc70d6c3bc6fcead4..fd8942ce64cdd9b0705b56ad581c27688620a43d 100644 (file)
@@ -91,7 +91,7 @@ LUA_FUNCTION_DEF (task, get_message_id);
 LUA_FUNCTION_DEF (task, get_timeval);
 LUA_FUNCTION_DEF (task, get_metric_score);
 LUA_FUNCTION_DEF (task, get_metric_action);
-LUA_FUNCTION_DEF (task, learn_statfile);
+LUA_FUNCTION_DEF (task, learn);
 
 static const struct luaL_reg tasklib_f[] = {
        LUA_INTERFACE_DEF (task, create_empty),
@@ -142,7 +142,7 @@ static const struct luaL_reg tasklib_m[] = {
        LUA_INTERFACE_DEF (task, get_timeval),
        LUA_INTERFACE_DEF (task, get_metric_score),
        LUA_INTERFACE_DEF (task, get_metric_action),
-       LUA_INTERFACE_DEF (task, learn_statfile),
+       LUA_INTERFACE_DEF (task, learn),
        {"__tostring", lua_class_tostring},
        {NULL, NULL}
 };
@@ -1295,57 +1295,45 @@ lua_task_get_timeval (lua_State *L)
 
 
 static gint
-lua_task_learn_statfile (lua_State *L)
+lua_task_learn (lua_State *L)
 {
        struct rspamd_task *task = lua_check_task (L);
-       const gchar *symbol;
+       gboolean is_spam = FALSE;
+       const gchar *clname;
        struct rspamd_classifier_config *cl;
-       GTree *tokens;
-       struct rspamd_statfile_config *st;
-       stat_file_t *statfile;
-       struct classifier_ctx *ctx;
+       GError *err = NULL;
+       int ret = 1;
 
-       symbol = luaL_checkstring (L, 2);
+       is_spam = lua_toboolean(L, 2);
+       if (lua_gettop (L) > 2) {
+               clname = luaL_checkstring (L, 3);
+       }
+       else {
+               clname = "bayes";
+       }
 
-       if (task && symbol) {
-               cl = g_hash_table_lookup (task->cfg->classifiers_symbols, symbol);
-               if (cl == NULL) {
-                       msg_warn ("classifier for symbol %s is not found", symbol);
-                       lua_pushboolean (L, FALSE);
-                       return 1;
-               }
-               ctx = cl->classifier->init_func (task->task_pool, cl);
-               if ((tokens =
-                       g_hash_table_lookup (task->tokens, cl->tokenizer)) == NULL) {
-                       msg_warn ("no tokens found learn failed!");
+       cl = rspamd_config_find_classifier (task->cfg, clname);
+
+       if (cl == NULL) {
+               msg_warn ("classifier %s is not found", clname);
+               lua_pushboolean (L, FALSE);
+               lua_pushstring (L, "classifier not found");
+               ret = 2;
+       }
+       else {
+               if (!learn_task_spam (cl, task, is_spam, &err)) {
                        lua_pushboolean (L, FALSE);
-                       return 1;
+                       if (err != NULL) {
+                               lua_pushstring (L, err->message);
+                               ret = 2;
+                       }
                }
-               statfile = get_statfile_by_symbol (task->worker->srv->statfile_pool,
-                               ctx->cfg,
-                               symbol,
-                               &st,
-                               TRUE);
-
-               if (statfile == NULL) {
-                       msg_warn ("opening statfile failed!");
-                       lua_pushboolean (L, FALSE);
-                       return 1;
+               else {
+                       lua_pushboolean (L, TRUE);
                }
-
-               cl->classifier->learn_func (ctx,
-                       task->worker->srv->statfile_pool,
-                       symbol,
-                       tokens,
-                       TRUE,
-                       NULL,
-                       1.,
-                       NULL);
-               maybe_write_binlog (ctx->cfg, st, statfile, tokens);
-               lua_pushboolean (L, TRUE);
        }
 
-       return 1;
+       return ret;
 }
 
 static gint