]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add learning support for lua classifiers
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 6 Oct 2016 17:52:16 +0000 (18:52 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 6 Oct 2016 17:52:16 +0000 (18:52 +0100)
src/libstat/classifiers/lua_classifier.c

index cd9fc3bd1541cec39e5a58ef62e045a7b45ce6fc..a28b58d84f0af75ed5cfb2a1d415339dc7d59e85 100644 (file)
@@ -183,5 +183,53 @@ lua_classifier_learn_spam (struct rspamd_classifier *cl,
                gboolean unlearn,
                GError **err)
 {
+       struct rspamd_lua_classifier_ctx *ctx;
+       struct rspamd_task **ptask;
+       struct rspamd_classifier_config **pcfg;
+       lua_State *L;
+       rspamd_token_t *tok;
+       guint i;
+       guint64 v;
+
+       ctx = g_hash_table_lookup (lua_classifiers, cl->subrs->name);
+       g_assert (ctx != NULL);
+       L = task->cfg->lua_state;
+
+       lua_rawgeti (L, LUA_REGISTRYINDEX, ctx->learn_ref);
+       ptask = lua_newuserdata (L, sizeof (*ptask));
+       *ptask = task;
+       rspamd_lua_setclass (L, "rspamd{task}", -1);
+       pcfg = lua_newuserdata (L, sizeof (*pcfg));
+       *pcfg = cl->cfg;
+       rspamd_lua_setclass (L, "rspamd{classifier}", -1);
+
+       lua_createtable (L, tokens->len, 0);
+
+       for (i = 0; i < tokens->len; i ++) {
+               tok = g_ptr_array_index (tokens, i);
+               v = 0;
+               memcpy (&v, tok->data, MIN (sizeof (v), tok->datalen));
+               lua_createtable (L, 3, 0);
+               /* High word, low word, order */
+               lua_pushnumber (L, (guint32)(v >> 32));
+               lua_rawseti (L, -2, 1);
+               lua_pushnumber (L, (guint32)(v));
+               lua_rawseti (L, -2, 2);
+               lua_pushnumber (L, tok->window_idx);
+               lua_rawseti (L, -2, 3);
+               lua_rawseti (L, -2, i + 1);
+       }
+
+       lua_pushboolean (L, is_spam);
+       lua_pushboolean (L, unlearn);
+
+       if (lua_pcall (L, 5, 0, 0) != 0) {
+               msg_err_luacl ("error running learn function for %s: %s", ctx->name,
+                               lua_tostring (L, -1));
+               lua_pop (L, 1);
+
+               return FALSE;
+       }
+
        return TRUE;
 }