LUA_FUNCTION_DEF (task, get_timeval);
LUA_FUNCTION_DEF (task, get_metric_score);
LUA_FUNCTION_DEF (task, get_metric_action);
-LUA_FUNCTION_DEF (task, learn_statfile);
+LUA_FUNCTION_DEF (task, learn);
static const struct luaL_reg tasklib_f[] = {
LUA_INTERFACE_DEF (task, create_empty),
LUA_INTERFACE_DEF (task, get_timeval),
LUA_INTERFACE_DEF (task, get_metric_score),
LUA_INTERFACE_DEF (task, get_metric_action),
- LUA_INTERFACE_DEF (task, learn_statfile),
+ LUA_INTERFACE_DEF (task, learn),
{"__tostring", lua_class_tostring},
{NULL, NULL}
};
static gint
-lua_task_learn_statfile (lua_State *L)
+lua_task_learn (lua_State *L)
{
struct rspamd_task *task = lua_check_task (L);
- const gchar *symbol;
+ gboolean is_spam = FALSE;
+ const gchar *clname;
struct rspamd_classifier_config *cl;
- GTree *tokens;
- struct rspamd_statfile_config *st;
- stat_file_t *statfile;
- struct classifier_ctx *ctx;
+ GError *err = NULL;
+ int ret = 1;
- symbol = luaL_checkstring (L, 2);
+ is_spam = lua_toboolean(L, 2);
+ if (lua_gettop (L) > 2) {
+ clname = luaL_checkstring (L, 3);
+ }
+ else {
+ clname = "bayes";
+ }
- if (task && symbol) {
- cl = g_hash_table_lookup (task->cfg->classifiers_symbols, symbol);
- if (cl == NULL) {
- msg_warn ("classifier for symbol %s is not found", symbol);
- lua_pushboolean (L, FALSE);
- return 1;
- }
- ctx = cl->classifier->init_func (task->task_pool, cl);
- if ((tokens =
- g_hash_table_lookup (task->tokens, cl->tokenizer)) == NULL) {
- msg_warn ("no tokens found learn failed!");
+ cl = rspamd_config_find_classifier (task->cfg, clname);
+
+ if (cl == NULL) {
+ msg_warn ("classifier %s is not found", clname);
+ lua_pushboolean (L, FALSE);
+ lua_pushstring (L, "classifier not found");
+ ret = 2;
+ }
+ else {
+ if (!learn_task_spam (cl, task, is_spam, &err)) {
lua_pushboolean (L, FALSE);
- return 1;
+ if (err != NULL) {
+ lua_pushstring (L, err->message);
+ ret = 2;
+ }
}
- statfile = get_statfile_by_symbol (task->worker->srv->statfile_pool,
- ctx->cfg,
- symbol,
- &st,
- TRUE);
-
- if (statfile == NULL) {
- msg_warn ("opening statfile failed!");
- lua_pushboolean (L, FALSE);
- return 1;
+ else {
+ lua_pushboolean (L, TRUE);
}
-
- cl->classifier->learn_func (ctx,
- task->worker->srv->statfile_pool,
- symbol,
- tokens,
- TRUE,
- NULL,
- 1.,
- NULL);
- maybe_write_binlog (ctx->cfg, st, statfile, tokens);
- lua_pushboolean (L, TRUE);
}
- return 1;
+ return ret;
}
static gint