From: Vsevolod Stakhov Date: Thu, 14 Aug 2014 12:18:31 +0000 (+0100) Subject: Allow learning from lua_task. X-Git-Tag: 0.7.0~181 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=11aaf8a93a5f147522f7e6238418af38a8517d3a;p=thirdparty%2Frspamd.git Allow learning from lua_task. --- diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c index 1b62ee83fd..fd8942ce64 100644 --- a/src/lua/lua_task.c +++ b/src/lua/lua_task.c @@ -91,7 +91,7 @@ LUA_FUNCTION_DEF (task, get_message_id); LUA_FUNCTION_DEF (task, get_timeval); LUA_FUNCTION_DEF (task, get_metric_score); LUA_FUNCTION_DEF (task, get_metric_action); -LUA_FUNCTION_DEF (task, learn_statfile); +LUA_FUNCTION_DEF (task, learn); static const struct luaL_reg tasklib_f[] = { LUA_INTERFACE_DEF (task, create_empty), @@ -142,7 +142,7 @@ static const struct luaL_reg tasklib_m[] = { LUA_INTERFACE_DEF (task, get_timeval), LUA_INTERFACE_DEF (task, get_metric_score), LUA_INTERFACE_DEF (task, get_metric_action), - LUA_INTERFACE_DEF (task, learn_statfile), + LUA_INTERFACE_DEF (task, learn), {"__tostring", lua_class_tostring}, {NULL, NULL} }; @@ -1295,57 +1295,45 @@ lua_task_get_timeval (lua_State *L) static gint -lua_task_learn_statfile (lua_State *L) +lua_task_learn (lua_State *L) { struct rspamd_task *task = lua_check_task (L); - const gchar *symbol; + gboolean is_spam = FALSE; + const gchar *clname; struct rspamd_classifier_config *cl; - GTree *tokens; - struct rspamd_statfile_config *st; - stat_file_t *statfile; - struct classifier_ctx *ctx; + GError *err = NULL; + int ret = 1; - symbol = luaL_checkstring (L, 2); + is_spam = lua_toboolean(L, 2); + if (lua_gettop (L) > 2) { + clname = luaL_checkstring (L, 3); + } + else { + clname = "bayes"; + } - if (task && symbol) { - cl = g_hash_table_lookup (task->cfg->classifiers_symbols, symbol); - if (cl == NULL) { - msg_warn ("classifier for symbol %s is not found", symbol); - lua_pushboolean (L, FALSE); - return 1; - } - ctx = cl->classifier->init_func (task->task_pool, cl); - if ((tokens = - g_hash_table_lookup (task->tokens, cl->tokenizer)) == NULL) { - msg_warn ("no tokens found learn failed!"); + cl = rspamd_config_find_classifier (task->cfg, clname); + + if (cl == NULL) { + msg_warn ("classifier %s is not found", clname); + lua_pushboolean (L, FALSE); + lua_pushstring (L, "classifier not found"); + ret = 2; + } + else { + if (!learn_task_spam (cl, task, is_spam, &err)) { lua_pushboolean (L, FALSE); - return 1; + if (err != NULL) { + lua_pushstring (L, err->message); + ret = 2; + } } - statfile = get_statfile_by_symbol (task->worker->srv->statfile_pool, - ctx->cfg, - symbol, - &st, - TRUE); - - if (statfile == NULL) { - msg_warn ("opening statfile failed!"); - lua_pushboolean (L, FALSE); - return 1; + else { + lua_pushboolean (L, TRUE); } - - cl->classifier->learn_func (ctx, - task->worker->srv->statfile_pool, - symbol, - tokens, - TRUE, - NULL, - 1., - NULL); - maybe_write_binlog (ctx->cfg, st, statfile, tokens); - lua_pushboolean (L, TRUE); } - return 1; + return ret; } static gint