]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
Implement statistics relearning.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 18 Feb 2015 15:06:41 +0000 (15:06 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 18 Feb 2015 15:06:41 +0000 (15:06 +0000)
src/libstat/backends/backends.h
src/libstat/backends/mmaped_file.c
src/libstat/classifiers/bayes.c
src/libstat/stat_config.c
src/libstat/stat_process.c

index c7c4210fbc80d45ac77f8b56d16c66726654998b..f8a2af72c074a63735adc7a162c837cc66e10bcd 100644 (file)
@@ -49,6 +49,7 @@ struct rspamd_stat_backend {
                        struct rspamd_token_result *res, gpointer ctx);
        gulong (*total_learns)(struct rspamd_statfile_runtime *runtime, gpointer ctx);
        gulong (*inc_learns)(struct rspamd_statfile_runtime *runtime, gpointer ctx);
+       gulong (*dec_learns)(struct rspamd_statfile_runtime *runtime, gpointer ctx);
        ucl_object_t* (*get_stat)(struct rspamd_statfile_runtime *runtime, gpointer ctx);
        gpointer ctx;
 };
@@ -66,6 +67,8 @@ gulong rspamd_mmaped_file_total_learns (struct rspamd_statfile_runtime *runtime,
                gpointer ctx);
 gulong rspamd_mmaped_file_inc_learns (struct rspamd_statfile_runtime *runtime,
                gpointer ctx);
+gulong rspamd_mmaped_file_dec_learns (struct rspamd_statfile_runtime *runtime,
+               gpointer ctx);
 ucl_object_t * rspamd_mmaped_file_get_stat (struct rspamd_statfile_runtime *runtime,
                gpointer ctx);
 
index 0fb386f61c7dcfde916a61247b190f5de550b79f..02ea17c288d993215ac455d09ba7ffa6d28320e0 100644 (file)
@@ -290,6 +290,23 @@ rspamd_mmaped_file_inc_revision (rspamd_mmaped_file_t *file)
        return TRUE;
 }
 
+gboolean
+rspamd_mmaped_file_dec_revision (rspamd_mmaped_file_t *file)
+{
+       struct stat_file_header *header;
+
+       if (file == NULL || file->map == NULL) {
+               return FALSE;
+       }
+
+       header = (struct stat_file_header *)file->map;
+
+       header->revision--;
+
+       return TRUE;
+}
+
+
 gboolean
 rspamd_mmaped_file_get_revision (rspamd_mmaped_file_t *file, guint64 *rev, time_t *time)
 {
@@ -939,11 +956,7 @@ rspamd_mmaped_file_learn_token (rspamd_token_t *tok,
        memcpy (&h2, tok->data + sizeof (h1), sizeof (h2));
        rspamd_mmaped_file_set_block (ctx, mf, h1, h2, res->value);
 
-       if (res->value > 0.0) {
-               return TRUE;
-       }
-
-       return FALSE;
+       return TRUE;
 }
 
 gulong
@@ -977,6 +990,23 @@ rspamd_mmaped_file_inc_learns (struct rspamd_statfile_runtime *runtime,
        return rev;
 }
 
+gulong
+rspamd_mmaped_file_dec_learns (struct rspamd_statfile_runtime *runtime,
+               gpointer ctx)
+{
+       rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *)runtime;
+       guint64 rev = 0;
+       time_t t;
+
+       if (mf != NULL) {
+               rspamd_mmaped_file_dec_revision (mf);
+               rspamd_mmaped_file_get_revision (mf, &rev, &t);
+       }
+
+       return rev;
+}
+
+
 ucl_object_t *
 rspamd_mmaped_file_get_stat (struct rspamd_statfile_runtime *runtime,
                gpointer ctx)
index be6c6f5452942f9295196126c03895cf73011595..7932ceb9e45d49d52a63c63a904fe9616f7ad94a 100644 (file)
@@ -221,6 +221,10 @@ bayes_learn_spam_callback (gpointer key, gpointer value, gpointer data)
                if (res->st_runtime->st->is_spam) {
                        res->value ++;
                }
+               else if (res->value > 0) {
+                       /* Unlearning */
+                       res->value --;
+               }
        }
 
        return FALSE;
@@ -241,6 +245,9 @@ bayes_learn_ham_callback (gpointer key, gpointer value, gpointer data)
                if (!res->st_runtime->st->is_spam) {
                        res->value ++;
                }
+               else if (res->value > 0) {
+                       res->value --;
+               }
        }
 
        return FALSE;
index b8ad6ec30ca9f0059efee47c71edd569dadf594b..17b5c54f5862b48ef87d7493a46afa74c399a80b 100644 (file)
@@ -53,6 +53,7 @@ static struct rspamd_stat_backend stat_backends[] = {
                .learn_token = rspamd_mmaped_file_learn_token,
                .total_learns = rspamd_mmaped_file_total_learns,
                .inc_learns = rspamd_mmaped_file_inc_learns,
+               .dec_learns = rspamd_mmaped_file_dec_learns,
                .get_stat = rspamd_mmaped_file_get_stat
        }
 };
index 8b2a1942919c7f2597ae3702d933b3d59fd860cb..1ce439c51168995025d1bd7ae77c668b459c6dbc 100644 (file)
 #include "lua/lua_common.h"
 #include <utlist.h>
 
+#define RSPAMD_CLASSIFY_OP 0
+#define RSPAMD_LEARN_OP 1
+#define RSPAMD_UNLEARN_OP 2
+
 struct preprocess_cb_data {
        struct rspamd_task *task;
        GList *classifier_runtimes;
        struct rspamd_tokenizer_runtime *tok;
        guint results_count;
        gboolean unlearn;
+       gboolean spam;
 };
 
 static struct rspamd_tokenizer_runtime *
@@ -135,7 +140,7 @@ preprocess_init_stat_token (gpointer k, gpointer v, gpointer d)
 static GList*
 rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
                struct rspamd_task *task, struct rspamd_tokenizer_runtime *tklist,
-               lua_State *L, gboolean learn, gboolean spam, GError **err)
+               lua_State *L, gint op, gboolean spam, GError **err)
 {
        struct rspamd_classifier_config *clcf;
        struct rspamd_statfile_config *stcf;
@@ -186,7 +191,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
                        stcf = (struct rspamd_statfile_config *)curst->data;
 
                        /* On learning skip statfiles that do not belong to class */
-                       if (learn && (spam != stcf->is_spam)) {
+                       if (op == RSPAMD_LEARN_OP && (spam != stcf->is_spam)) {
                                curst = g_list_next (curst);
                                continue;
                        }
@@ -199,7 +204,8 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
                                continue;
                        }
 
-                       backend_runtime = bk->runtime (stcf, learn, bk->ctx);
+                       backend_runtime = bk->runtime (stcf, op != RSPAMD_CLASSIFY_OP,
+                                       bk->ctx);
 
                        st_runtime = rspamd_mempool_alloc0 (task->task_pool,
                                        sizeof (*st_runtime));
@@ -354,7 +360,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err)
 
        /* Initialize classifiers and statfiles runtime */
        if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L,
-                       FALSE, FALSE, err)) == NULL) {
+                       RSPAMD_CLASSIFY_OP, FALSE, err)) == NULL) {
                return RSPAMD_STAT_PROCESS_ERROR;
        }
 
@@ -407,11 +413,12 @@ rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
                        continue;
                }
 
-               res = &g_array_index (t->results, struct rspamd_token_result, i);
 
-               curst = res->cl_runtime->st_runtime;
+
+               curst = cl_runtime->st_runtime;
 
                while (curst) {
+                       res = &g_array_index (t->results, struct rspamd_token_result, i);
                        st_runtime = (struct rspamd_statfile_runtime *)curst->data;
 
                        if (st_runtime->backend->learn_token (t, res,
@@ -432,6 +439,7 @@ rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
                        i ++;
                        curst = g_list_next (curst);
                }
+
                cur = g_list_next (cur);
        }
 
@@ -507,7 +515,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
 
        /* Initialize classifiers and statfiles runtime */
        if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L,
-                       TRUE, spam, err)) == NULL) {
+                       unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP, spam, err)) == NULL) {
                return RSPAMD_STAT_PROCESS_ERROR;
        }
 
@@ -530,6 +538,7 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
                                        cbdata.task = task;
                                        cbdata.tok = cl_run->tok;
                                        cbdata.unlearn = unlearn;
+                                       cbdata.spam = spam;
                                        g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token,
                                                        &cbdata);
 
@@ -538,11 +547,18 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
                                        while (curst) {
                                                st_run = (struct rspamd_statfile_runtime *)curst->data;
 
-                                               nrev = st_run->backend->inc_learns (st_run->backend_runtime,
+                                               if (unlearn && spam != st_run->st->is_spam) {
+                                                       nrev = st_run->backend->dec_learns (st_run->backend_runtime,
+                                                                       st_run->backend->ctx);
+                                                       msg_debug ("unlearned %s, new revision: %ul",
+                                                                       st_run->st->symbol, nrev);
+                                               }
+                                               else {
+                                                       nrev = st_run->backend->inc_learns (st_run->backend_runtime,
                                                                st_run->backend->ctx);
-
-                                               msg_debug ("learned %s, new revision: %ul",
+                                                       msg_debug ("learned %s, new revision: %ul",
                                                                st_run->st->symbol, nrev);
+                                               }
 
                                                curst = g_list_next (curst);
                                        }