]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
* Fixes to winnow learning
authorVsevolod Stakhov <vsevolod@rambler-co.ru>
Thu, 5 Aug 2010 17:29:40 +0000 (21:29 +0400)
committerVsevolod Stakhov <vsevolod@rambler-co.ru>
Thu, 5 Aug 2010 17:29:40 +0000 (21:29 +0400)
src/classifiers/classifiers.h
src/classifiers/winnow.c
src/controller.c
src/filter.c

index 02192d79550e5756ac01554d1ea215564e0a98c7..f69c1284ce7d33592ccf1bacd755cdb3c489b165 100644 (file)
@@ -24,9 +24,10 @@ struct classify_weight {
 struct classifier {
        char *name;
        struct classifier_ctx* (*init_func)(memory_pool_t *pool, struct classifier_config *cf);
-       void (*classify_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
-       void (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, 
-                                                       stat_file_t *file, GTree *input, gboolean in_class, double *sum, double multiplier);
+       gboolean (*classify_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
+       gboolean (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool,
+                                                       const char *symbol, GTree *input, gboolean in_class,
+                                                       double *sum, double multiplier, GError **err);
        GList* (*weights_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
 };
 
@@ -35,9 +36,9 @@ struct classifier* get_classifier (char *name);
 
 /* Winnow algorithm */
 struct classifier_ctx* winnow_init (memory_pool_t *pool, struct classifier_config *cf);
-void winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
-void winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, stat_file_t *file, GTree *input, 
-                               gboolean in_class, double *sum, double multiplier);
+gboolean winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
+gboolean winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input,
+                               gboolean in_class, double *sum, double multiplier, GError **err);
 GList *winnow_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
 
 
index 41cb48e893e04dcaf2195232d1ba17ae2499508e..ab155ff8c2b3e52c6fd10890d68c3c976c4566bb 100644 (file)
 
 #define MAX_LEARN_ITERATIONS 100
 
+G_INLINE_FUNC GQuark
+winnow_error_quark (void)
+{
+       return g_quark_from_static_string ("winnow-error-quark");
+}
+
 struct winnow_callback_data {
        statfile_pool_t                *pool;
        struct classifier_ctx          *ctx;
@@ -53,7 +59,8 @@ struct winnow_callback_data {
        stat_file_t                    *learn_file;
        long double                     sum;
        double                          multiplier;
-       int                             count;
+       guint32                         count;
+       guint32                         new_blocks;
        gboolean                        in_class;
        gboolean                        do_demote;
        gboolean                        fresh_run;
@@ -62,6 +69,8 @@ struct winnow_callback_data {
 
 static const double max_common_weight = MAX_WEIGHT * WINNOW_DEMOTION;
 
+
+
 static                          gboolean
 classify_callback (gpointer key, gpointer value, gpointer data)
 {
@@ -73,10 +82,10 @@ classify_callback (gpointer key, gpointer value, gpointer data)
        v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now);
        if (fabs (v) > ALPHA) {
                cd->sum += v;
-               cd->in_class++;
        }
        else {
                cd->sum += 1.0;
+               cd->new_blocks ++;
        }
 
        cd->count++;
@@ -100,6 +109,7 @@ learn_callback (gpointer key, gpointer value, gpointer data)
                if (cd->file == cd->learn_file) {
                        statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c);
                        node->value = c;
+                       cd->new_blocks ++;
                }
        }
        else {
@@ -181,7 +191,7 @@ winnow_init (memory_pool_t * pool, struct classifier_config *cfg)
        return ctx;
 }
 
-void
+gboolean
 winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * input, struct worker_task *task)
 {
        struct winnow_callback_data     data;
@@ -203,7 +213,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
                nodes = g_tree_nnodes (input) / FEATURE_WINDOW_SIZE;
                if (nodes < minnodes) {
                        msg_info ("do not classify message as it has too few tokens: %d, while %d min", nodes, minnodes);
-                       return;
+                       return FALSE;
                }
        }
 
@@ -224,6 +234,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
                st = cur->data;
                data.sum = 0;
                data.count = 0;
+               data.new_blocks = 0;
                if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
                        if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
                                msg_warn ("cannot open %s, skip it", st->path);
@@ -233,9 +244,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
                }
 
                if (data.file != NULL) {
-                       statfile_pool_lock_file (pool, data.file);
                        g_tree_foreach (input, classify_callback, &data);
-                       statfile_pool_unlock_file (pool, data.file);
                }
 
                if (data.count != 0) {
@@ -263,6 +272,8 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
                cur = g_list_prepend (NULL, sumbuf);
                insert_result (task, sel->symbol, max, cur);
        }
+
+       return TRUE;
 }
 
 GList *
@@ -306,9 +317,7 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu
                }
 
                if (data.file != NULL) {
-                       statfile_pool_lock_file (pool, data.file);
                        g_tree_foreach (input, classify_callback, &data);
-                       statfile_pool_unlock_file (pool, data.file);
                }
 
                w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight));
@@ -333,8 +342,9 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu
 }
 
 
-void
-winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *file, GTree * input, int in_class, double *sum, double multiplier)
+gboolean
+winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, const char *symbol,
+               GTree * input, int in_class, double *sum, double multiplier, GError **err)
 {
        struct winnow_callback_data     data = {
                .file = NULL,
@@ -343,10 +353,11 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
        char                           *value;
        int                             nodes, minnodes, iterations = 0;
        struct statfile                *st, *sel_st;
-       stat_file_t                    *sel = NULL;
+       stat_file_t                    *sel = NULL, *to_learn;
        long double                     res = 0., max = 0.;
-       double                          learn_threshold = 1.0;
+       double                          learn_threshold = 0.0;
        GList                          *cur, *to_demote = NULL;
+       gboolean                        force_learn = FALSE;
 
        g_assert (pool != NULL);
        g_assert (ctx != NULL);
@@ -355,7 +366,7 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
        data.in_class = in_class;
        data.now = time (NULL);
        data.ctx = ctx;
-       data.learn_file = file;
+
 
        if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
                minnodes = strtol (value, NULL, 10);
@@ -363,70 +374,121 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
                if (nodes < minnodes) {
                        msg_info ("do not learn message as it has too few tokens: %d, while %d min", nodes, minnodes);
                        *sum = 0;
-                       return;
+                       g_set_error (err,
+                          winnow_error_quark(),                /* error domain */
+                          1,                                           /* error code */
+                          "message contains too few tokens: %d, while min is %d",
+                          nodes, minnodes);
+                       return FALSE;
                }
        }
        if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "learn_threshold")) != NULL) {
                learn_threshold = strtod (value, NULL);
        }
        
-       if (learn_threshold >= 1.0) {
+       if (learn_threshold <= 1.0 && learn_threshold >= 0) {
                /* Classify message and check target statfile score */
                cur = ctx->cfg->statfiles;
+               while (cur) {
+                       /* Open or create all statfiles inside classifier */
+                       st = cur->data;
+                       if (statfile_pool_is_open (pool, st->path) == NULL) {
+                               if (statfile_pool_open (pool, st->path, st->size, FALSE) == NULL) {
+                                       msg_warn ("cannot open %s", st->path);
+                                       if (statfile_pool_create (pool, st->path, st->size) == -1) {
+                                               msg_err ("cannot create statfile %s", st->path);
+                                               g_set_error (err,
+                                                               winnow_error_quark(),           /* error domain */
+                                                               1,                                      /* error code */
+                                                               "cannot create statfile: %s",
+                                                               st->path);
+                                               return FALSE;
+                                       }
+                                       if (statfile_pool_open (pool, st->path, st->size, FALSE)) {
+                                               g_set_error (err,
+                                                               winnow_error_quark(),           /* error domain */
+                                                               1,                                      /* error code */
+                                                               "open statfile %s after creation",
+                                                               st->path);
+                                               msg_err ("cannot open statfile %s after creation", st->path);
+                                               return FALSE;
+                                       }
+                               }
+                       }
+                       if (strcmp (st->symbol, symbol) == 0) {
+                               sel_st = st;
+
+                       }
+                       cur = g_list_next (cur);
+               }
+               to_learn = statfile_pool_is_open (pool, sel_st->path);
+               if (to_learn == NULL) {
+                       g_set_error (err,
+                                       winnow_error_quark(),           /* error domain */
+                                       1,                                      /* error code */
+                                       "statfile %s is not opened this maybe if your statfile pool is too small to handle all statfiles",
+                                       sel_st->path);
+                       return FALSE;
+               }
                /* Check target statfile */
-               data.file = file;
+               data.file = to_learn;
                data.sum = 0;
                data.count = 0;
-               data.file = file;
-               statfile_pool_lock_file (pool, data.file);
+               data.new_blocks = 0;
                g_tree_foreach (input, classify_callback, &data);
-               statfile_pool_unlock_file (pool, data.file);
                if (data.count > 0) {
                        max = data.sum / (double)data.count;
                }
                else {
                        max = 0;
                }
+               /* If most of blocks are not presented in targeted statfile do forced learn */
+               if ((data.new_blocks > 1 && (double)data.new_blocks / (double)data.count > 0.5) || max < 1 + learn_threshold) {
+                       force_learn = TRUE;
+               }
+               /* Check other statfiles */
                while (cur) {
                        st = cur->data;
                        data.sum = 0;
                        data.count = 0;
                        if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
-                               if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
-                                       msg_warn ("cannot open %s, skip it", st->path);
-                                       cur = g_list_next (cur);
-                                       continue;
-                               }
+                               g_set_error (err,
+                                               winnow_error_quark(),           /* error domain */
+                                               1,                                      /* error code */
+                                               "statfile %s is not opened this maybe if your statfile pool is too small to handle all statfiles",
+                                               st->path);
+                               return FALSE;
                        }
-                       statfile_pool_lock_file (pool, data.file);
                        g_tree_foreach (input, classify_callback, &data);
-                       statfile_pool_unlock_file (pool, data.file);
                        if (data.count != 0) {
                                res = data.sum / data.count;
                        }
                        else {
                                res = 0;
                        }
-                       if (file != data.file && res / max > learn_threshold) {
+                       if (to_learn != data.file && res - max > 1 - learn_threshold) {
                                /* Demote tokens in this statfile */
                                to_demote = g_list_prepend (to_demote, data.file);
                        }
-                       else if (file == data.file) {
-                               sel_st = st;
-                       }
                        cur = g_list_next (cur);
                }
        }
        else {
-               msg_err ("learn threshold is less than 1, so cannot do learn, please check your configuration");
-               return;
+               msg_err ("learn threshold is more than 1 or less than 0, so cannot do learn, please check your configuration");
+               g_set_error (err,
+                               winnow_error_quark(),           /* error domain */
+                               1,                                      /* error code */
+                               "bad learn_threshold setting: %.2f",
+                               learn_threshold);
+               return FALSE;
        }
        /* If to_demote list is empty this message is already classified correctly */
-       if (max > ALPHA && to_demote == NULL) {
+       if (max > ALPHA && to_demote == NULL && !force_learn) {
                msg_info ("this message is already of class %s with threshold %.2f and weight %.2F",
                                sel_st->symbol, learn_threshold, max);
                goto end;
        }
+       data.learn_file = to_learn;
        do {
                cur = ctx->cfg->statfiles;
                data.fresh_run = TRUE;
@@ -434,12 +496,9 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
                        st = cur->data;
                        data.sum = 0;
                        data.count = 0;
+                       data.new_blocks = 0;
                        if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
-                               if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
-                                       msg_warn ("cannot open %s, skip it", st->path);
-                                       cur = g_list_next (cur);
-                                       continue;
-                               }
+                               return FALSE;
                        }
                        if (to_demote != NULL && g_list_find (to_demote, data.file) != NULL) {
                                data.do_demote = TRUE;
@@ -470,14 +529,20 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
                else {
                        data.multiplier *= WINNOW_PROMOTION;
                }
-       } while ((in_class ? sel != file : sel == file)  && iterations ++ < MAX_LEARN_ITERATIONS);
+       } while ((in_class ? sel != to_learn : sel == to_learn)  && iterations ++ < MAX_LEARN_ITERATIONS);
        
        if (iterations >= MAX_LEARN_ITERATIONS) {
                msg_warn ("learning statfile %s  was not fully successfull: iterations count is limited to %d, final sum is %G", 
-                               file->filename, MAX_LEARN_ITERATIONS, max);
+                               sel_st->symbol, MAX_LEARN_ITERATIONS, max);
+               g_set_error (err,
+                               winnow_error_quark(),           /* error domain */
+                               1,                                      /* error code */
+                               "learning statfile %s  was not fully successfull: iterations count is limited to %d",
+                               sel_st->symbol, MAX_LEARN_ITERATIONS);
+               return FALSE;
        }
        else {
-               msg_info ("learned statfile %s successfully with %d iterations and sum %G", file->filename, iterations + 1, max);
+               msg_info ("learned statfile %s successfully with %d iterations and sum %G", sel_st->symbol, iterations + 1, max);
        }
 
 
@@ -485,4 +550,5 @@ end:
        if (sum) {
                *sum = (double)max;
        }
+       return TRUE;
 }
index 6ef9f7e08f48d57e9d1c98aa628e7bdc7c3b3206..15b14cf8d63a106166a3ba823936c9521c43a523 100644 (file)
@@ -704,6 +704,7 @@ controller_read_socket (f_str_t * in, void *arg)
        struct mime_text_part          *part;
        GList                          *comp_list, *cur = NULL;
        GTree                          *tokens = NULL;
+       GError                         *err = NULL;
        f_str_t                         c;
        double                          sum;
 
@@ -818,26 +819,33 @@ controller_read_socket (f_str_t * in, void *arg)
                        return TRUE;
                }
        
+
+               /* Init classifier */
+               cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier);
                /* Get or create statfile */
                statfile = get_statfile_by_symbol (session->worker->srv->statfile_pool, session->learn_classifier,
-                                               session->learn_symbol, &st, TRUE);
-               if (statfile == NULL) {
-                       msg_info ("learn failed for message <%s>, no statfile found: %s", task->message_id, session->learn_symbol);
+                               session->learn_symbol, &st, TRUE);
+
+               if (statfile == NULL ||
+                       ! session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool,
+                                                                                                                               session->learn_symbol, tokens, session->in_class, &sum,
+                                                                                                                               session->learn_multiplier, &err)) {
+                       if (err) {
+                               i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, learn classifier error: %s" CRLF, err->message);
+                               msg_info ("learn failed for message <%s>, learn error: %s", task->message_id, err->message);
+                               g_error_free (err);
+                       }
+                       else {
+                               i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, unknown learn classifier error" CRLF);
+                               msg_info ("learn failed for message <%s>, unknown learn error", task->message_id);
+                       }
                        free_task (task, FALSE);
-                       i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, invalid symbol" CRLF);
                        if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
                                return FALSE;
                        }
+                       session->state = STATE_REPLY;
                        return TRUE;
                }
-               
-               /* Init classifier */
-               cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier);
-               
-               /* XXX: remove this awful legacy */
-               session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool, 
-                                                                                                                               statfile, tokens, session->in_class, &sum,
-                                                                                                                               session->learn_multiplier);
                session->worker->srv->stat->messages_learned++;
 
                maybe_write_binlog (session->learn_classifier, st, statfile, tokens);
index 9e2da0c57c38b8a74aa73280f4cd222bd8d8e600..90566ded90ba6692e48000ea68e4c2d0f2f16101 100644 (file)
@@ -438,7 +438,7 @@ process_autolearn (struct statfile *st, struct worker_task *task, GTree * tokens
                                return;
                        }
 
-                       classifier->learn_func (ctx, task->worker->srv->statfile_pool, statfile, tokens, TRUE, NULL, 1.);
+                       classifier->learn_func (ctx, task->worker->srv->statfile_pool, st->symbol, tokens, TRUE, NULL, 1., NULL);
                        maybe_write_binlog (ctx->cfg, st, statfile, tokens);
                }
        }