git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
* Improve the logic of learning messages: do not learn beyond a specified threshold
author Vsevolod Stakhov <vsevolod@rambler-co.ru>
Mon, 2 Aug 2010 16:27:48 +0000 (20:27 +0400)
committer Vsevolod Stakhov <vsevolod@rambler-co.ru>
Mon, 2 Aug 2010 16:27:48 +0000 (20:27 +0400)
* Fix inserting results for symbols that were defined incorrectly (for example, more than once) in the config file

src/classifiers/winnow.c
src/filter.c

index 7599b11503515a95ec3f77f2cd3dfe7f8d010418..481d3717d2d9d218ed2d1e804c68ff6c12cfadd2 100644 (file)
@@ -42,7 +42,7 @@
 
 #define MAX_WEIGHT G_MAXDOUBLE / 2.
 
-#define ALPHA 0.001
+#define ALPHA 0.01
 
 #define MAX_LEARN_ITERATIONS 100
 
@@ -55,6 +55,7 @@ struct winnow_callback_data {
        double                          multiplier;
        int                             count;
        gboolean                        in_class;
+       gboolean                        do_demote;
        gboolean                        fresh_run;
        time_t                          now;
 };
@@ -152,6 +153,11 @@ learn_callback (gpointer key, gpointer value, gpointer data)
                        }
                        statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, node->value);
                }
+               else if (cd->do_demote) {
+                       /* Demote blocks in file */
+                       node->value *= WINNOW_DEMOTION * cd->multiplier;
+                       statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, node->value);
+               }
        }
 
 
@@ -231,7 +237,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
                }
 
                if (data.count != 0) {
-                       res = data.sum / data.count;
+                       res = data.sum / (double)data.count;
                }
                else {
                        res = 0;
@@ -251,7 +257,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
                        max = st->normalizer (task->cfg, max, st->normalizer_data);
                }
                sumbuf = memory_pool_alloc (task->task_pool, 32);
-               snprintf (sumbuf, 32, "%.2Lg", max);
+               rspamd_snprintf (sumbuf, 32, "%.2F", max);
                cur = g_list_prepend (NULL, sumbuf);
                insert_result (task, sel->symbol, max, cur);
        }
@@ -305,7 +311,7 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu
 
                w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight));
                if (data.count != 0) {
-                       res = data.sum / data.count;
+                       res = data.sum / (double)data.count;
                }
                else {
                        res = 0;
@@ -334,10 +340,11 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
        };
        char                           *value;
        int                             nodes, minnodes, iterations = 0;
-       struct statfile                *st;
+       struct statfile                *st, *sel_st;
        stat_file_t                    *sel = NULL;
        long double                     res = 0., max = 0.;
-       GList                          *cur;
+       double                          learn_threshold = 1.0;
+       GList                          *cur, *to_demote = NULL;
 
        g_assert (pool != NULL);
        g_assert (ctx != NULL);
@@ -357,7 +364,67 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
                        return;
                }
        }
+       if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "learn_threshold")) != NULL) {
+               learn_threshold = strtod (value, NULL);
+       }
        
+       if (learn_threshold >= 1.0) {
+               /* Classify message and check target statfile score */
+               cur = ctx->cfg->statfiles;
+               /* Check target statfile */
+               data.file = file;
+               data.sum = 0;
+               data.count = 0;
+               data.file = file;
+               statfile_pool_lock_file (pool, data.file);
+               g_tree_foreach (input, classify_callback, &data);
+               statfile_pool_unlock_file (pool, data.file);
+               if (data.count > 0) {
+                       max = data.sum / (double)data.count;
+               }
+               else {
+                       max = 0;
+               }
+               while (cur) {
+                       st = cur->data;
+                       data.sum = 0;
+                       data.count = 0;
+                       if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
+                               if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
+                                       msg_warn ("cannot open %s, skip it", st->path);
+                                       cur = g_list_next (cur);
+                                       continue;
+                               }
+                       }
+                       statfile_pool_lock_file (pool, data.file);
+                       g_tree_foreach (input, classify_callback, &data);
+                       statfile_pool_unlock_file (pool, data.file);
+                       if (data.count != 0) {
+                               res = data.sum / data.count;
+                       }
+                       else {
+                               res = 0;
+                       }
+                       if (file != data.file && res / max > learn_threshold) {
+                               /* Demote tokens in this statfile */
+                               to_demote = g_list_prepend (to_demote, data.file);
+                       }
+                       else if (file == data.file) {
+                               sel_st = st;
+                       }
+                       cur = g_list_next (cur);
+               }
+       }
+       else {
+               msg_err ("learn threshold is less than 1, so cannot do learn, please check your configuration");
+               return;
+       }
+       /* If to_demote list is empty this message is already classified correctly */
+       if (max > ALPHA && to_demote == NULL) {
+               msg_info ("this message is already of class %s with threshold %.2f and weight %.2F",
+                               sel_st->symbol, learn_threshold, max);
+               goto end;
+       }
        do {
                cur = ctx->cfg->statfiles;
                data.fresh_run = TRUE;
@@ -372,6 +439,12 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
                                        continue;
                                }
                        }
+                       if (to_demote != NULL && g_list_find (to_demote, data.file) != NULL) {
+                               data.do_demote = TRUE;
+                       }
+                       else {
+                               data.do_demote = FALSE;
+                       }
                        statfile_pool_lock_file (pool, data.file);
                        g_tree_foreach (input, learn_callback, &data);
                        statfile_pool_unlock_file (pool, data.file);
@@ -402,11 +475,12 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
                                file->filename, MAX_LEARN_ITERATIONS, max);
        }
        else {
-               msg_info ("learned statfile %s successfully with %d iterations and sum %G", file->filename, iterations, max);
+               msg_info ("learned statfile %s successfully with %d iterations and sum %G", file->filename, iterations + 1, max);
        }
 
 
+end:
        if (sum) {
-               *sum = max;
+               *sum = (double)max;
        }
 }
index f1851d29543119f116cdc466fb283c1f100d0a64..d5e127b7aa6e0948e189125d27fafcdef37b2935 100644 (file)
@@ -84,7 +84,7 @@ insert_metric_result (struct worker_task *task, struct metric *metric, const cha
        metric_res->score += w;
 
        if ((s = g_hash_table_lookup (metric_res->symbols, symbol)) != NULL) {
-               if (s->options && opts) {
+               if (s->options && opts && opts != s->options) {
                        /* Append new options */
                        s->options = g_list_concat (s->options, opts);
                        /*