git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
* Implement new learning system, now rspamd should be much more intelligent while...
author Vsevolod Stakhov <vsevolod@rambler-co.ru>
Thu, 27 May 2010 13:33:31 +0000 (17:33 +0400)
committer Vsevolod Stakhov <vsevolod@rambler-co.ru>
Thu, 27 May 2010 13:33:31 +0000 (17:33 +0400)
src/classifiers/winnow.c
src/tokenizers/osb.c
src/tokenizers/tokenizers.h

index a5e7b3cf8bb3e424ba35df0445e772efed5946da..637be759d01e96f1ddb0f84d0bb3f4beb5ba118f 100644 (file)
 #define WINNOW_PROMOTION 1.23
 #define WINNOW_DEMOTION 0.83
 
+#define MEDIAN_WINDOW_SIZE 5
+
+#define MAX_WEIGHT G_MAXDOUBLE / 2.
+
+#define ALPHA 0.001
+
+#define MAX_LEARN_ITERATIONS 100
+
 struct winnow_callback_data {
        statfile_pool_t                *pool;
        struct classifier_ctx          *ctx;
        stat_file_t                    *file;
+       stat_file_t                    *learn_file;
        double                          sum;
        double                          multiplier;
        int                             count;
-       int                             in_class;
+       gboolean                        in_class;
+       gboolean                        fresh_run;
        time_t                          now;
 };
 
+static const double max_common_weight = MAX_WEIGHT * WINNOW_DEMOTION;
+
 static                          gboolean
 classify_callback (gpointer key, gpointer value, gpointer data)
 {
@@ -58,9 +70,9 @@ classify_callback (gpointer key, gpointer value, gpointer data)
 
        /* Consider that not found blocks have value 1 */
        v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now);
-       if (fabs (v) > 0.00001) {
-        if (cd->sum + v > G_MAXDOUBLE / 2.) {
-            cd->sum = G_MAXDOUBLE / 2.;
+       if (fabs (v) > ALPHA) {
+        if (cd->sum + v > MAX_WEIGHT) {
+            cd->sum = MAX_WEIGHT;
         }
         else {
                    cd->sum += v;
@@ -78,31 +90,78 @@ learn_callback (gpointer key, gpointer value, gpointer data)
 {
        token_node_t                   *node = key;
        struct winnow_callback_data    *cd = data;
-       double                           v, c;
+       double                          v, c;
        
        c = (cd->in_class) ? WINNOW_PROMOTION : WINNOW_DEMOTION;
        c *= cd->multiplier;
 
        /* Consider that not found blocks have value 1 */
        v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now);
-       if (fabs (v) < 0.00001) {
-               statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c);
-               node->value = c;
+       if (fabs (v) < ALPHA) {
+               /* Block not found, insert new */
+               if (cd->file == cd->learn_file) {
+                       statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c);
+                       node->value = c;
+               }
        }
        else {
-               statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, v * c);
-        /* Set some limit on growing */
-        if (v > G_MAXDOUBLE / 2.) {
-            node->value = v;
-        }
-        else {
-                   node->value = v * c;
-        }
+               /* Here we just increase the extra value of block */
+               if (cd->fresh_run) {
+                       node->extra = 0;
+               }
+               else {
+                       node->extra ++;
+               }
+               node->value = v;
+               
+               if (node->extra > 1) {
+                       /* 
+                        * Assume that this node is common for several statfiles, so
+                        * decrease its weight proportionally
+                        */
+                       if (node->value > max_common_weight) {
+                               /* Static fluctuation */
+                               statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, 0.);
+                               node->value = 0.;
+                       }
+                       else if (node->value > WINNOW_PROMOTION) {
+                               /* Try to decrease its value */
+                               /* XXX: it is more intelligent to add some adaptive filter here */
+                               if (cd->file == cd->learn_file) {
+                                       if (node->value > max_common_weight / 2.) {
+                                               node->value *= c;
+                                       }
+                                       else {
+                                               /* 
+                                                * Too high token value that exists also in other
+                                                * statfiles, may be statistic error, so decrease it
+                                                * slightly
+                                                */
+                                               node->value *= WINNOW_DEMOTION * cd->multiplier;
+                                       }
+                               }
+                               else {
+                                       node->value = sqrt (node->value);
+                               }
+                               statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, node->value);
+                       } 
+               }
+               else if (cd->file == cd->learn_file) {
+                       /* New block or block that is in only one statfile */
+                       /* Set some limit on growing */
+                       if (v > MAX_WEIGHT) {
+                               node->value = v;
+                       }
+                       else {
+                               node->value *= c;
+                       }
+                       statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, node->value);
+               }
        }
 
 
-    if (cd->sum + node->value > G_MAXDOUBLE / 2.) {
-        cd->sum = G_MAXDOUBLE / 2.;
+    if (cd->sum + node->value > MAX_WEIGHT) {
+        cd->sum = MAX_WEIGHT;
     }
     else {
            cd->sum += node->value;
@@ -223,8 +282,6 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu
        g_assert (ctx != NULL);
 
        data.pool = pool;
-       data.sum = 0;
-       data.count = 0;
        data.now = time (NULL);
        data.ctx = ctx;
 
@@ -240,6 +297,8 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu
        cur = ctx->cfg->statfiles;
        while (cur) {
                st = cur->data;
+               data.sum = 0;
+               data.count = 0;
                if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
                        if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
                                msg_warn ("cannot open %s, skip it", st->path);
@@ -254,7 +313,7 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu
                        statfile_pool_unlock_file (pool, data.file);
                }
 
-               w = memory_pool_alloc (task->task_pool, sizeof (struct classify_weight));
+               w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight));
                if (data.count != 0) {
                        res = data.sum / data.count;
                }
@@ -281,12 +340,14 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
 {
        struct winnow_callback_data     data = {
                .file = NULL,
-               .sum = 0,
-               .count = 0,
                .multiplier = multiplier
        };
        char                           *value;
-       int                             nodes, minnodes;
+       int                             nodes, minnodes, iterations = 0;
+       struct statfile                *st;
+       stat_file_t                    *sel;
+       double                          res = 0., max = 0.;
+       GList                          *cur;
 
        g_assert (pool != NULL);
        g_assert (ctx != NULL);
@@ -295,8 +356,7 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
        data.in_class = in_class;
        data.now = time (NULL);
        data.ctx = ctx;
-
-       data.file = file;
+       data.learn_file = file;
 
        if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
                minnodes = strtol (value, NULL, 10);
@@ -307,12 +367,45 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
                        return;
                }
        }
-
-       if (data.file != NULL) {
-               statfile_pool_lock_file (pool, data.file);
-               g_tree_foreach (input, learn_callback, &data);
-               statfile_pool_unlock_file (pool, data.file);
-       }
+       
+       do {
+               cur = ctx->cfg->statfiles;
+               data.fresh_run = TRUE;
+               while (cur) {
+                       st = cur->data;
+                       data.sum = 0;
+                       data.count = 0;
+                       if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
+                               if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
+                                       msg_warn ("cannot open %s, skip it", st->path);
+                                       cur = g_list_next (cur);
+                                       continue;
+                               }
+                       }
+                       statfile_pool_lock_file (pool, data.file);
+                       g_tree_foreach (input, learn_callback, &data);
+                       statfile_pool_unlock_file (pool, data.file);
+                       if (data.count != 0) {
+                               res = data.sum / data.count;
+                       }
+                       else {
+                               res = 0;
+                       }
+                       if (res > max) {
+                               max = res;
+                               sel = data.file;
+                       }
+                       cur = g_list_next (cur);
+                       data.fresh_run = FALSE;
+               }
+               
+               if (data.multiplier > 1) {
+                       data.multiplier *= data.multiplier;
+               }
+               else {
+                       data.multiplier *= WINNOW_PROMOTION;
+               }
+       } while ((in_class ? sel != file : sel == file)  && iterations ++ < MAX_LEARN_ITERATIONS);
        
        if (sum) {
                if (data.count != 0) {
index d36047efde93b8a1836ff7f4640705dd1effa0d7..ae59cf8ea70bdeeed1cf7609165cba79a254ac69 100644 (file)
@@ -66,7 +66,7 @@ osb_tokenize_text (struct tokenizer *tokenizer, memory_pool_t * pool, f_str_t *
                for (i = 1; i < FEATURE_WINDOW_SIZE; i++) {
                        h1 = hashpipe[0] * primes[0] + hashpipe[i] * primes[i << 1];
                        h2 = hashpipe[0] * primes[1] + hashpipe[i] * primes[(i << 1) - 1];
-                       new = memory_pool_alloc (pool, sizeof (token_node_t));
+                       new = memory_pool_alloc0 (pool, sizeof (token_node_t));
                        new->h1 = h1;
                        new->h2 = h2;
 
index fda5bded340c9e38aedf9ee31594f4a1fdc934f9..9a16e907c6078f331e6818f1a352a5ccdf76d76b 100644 (file)
@@ -18,6 +18,7 @@ typedef struct token_node_s {
        uint32_t h1;
        uint32_t h2;
        float value;
+       uintptr_t extra;
 } token_node_t;
 
 /* Common tokenizer structure */