]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
New chi2square based bayes normalizer.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 23 May 2013 15:15:46 +0000 (16:15 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 23 May 2013 15:15:46 +0000 (16:15 +0100)
src/classifiers/bayes.c
src/tokenizers/tokenizers.h

index 0e68f5d72327378f7908761ff0731c2d204cd739..f3ad365585c00ac48a2e4474f57f9aa751f3748c 100644 (file)
@@ -44,10 +44,7 @@ bayes_error_quark (void)
 struct bayes_statfile_data {
        guint64                         hits;
        guint64                         total_hits;
-       double                          local_probability;
-       double                          post_probability;
-       double                          corr;
-       double                          value;
+       double                         value;
        struct statfile                *st;
        stat_file_t                    *file;
 };
@@ -60,8 +57,11 @@ struct bayes_callback_data {
        stat_file_t                    *file;
        struct bayes_statfile_data     *statfiles;
        guint32                         statfiles_num;
-       guint64                         learned_tokens;
+       guint64                                                 total_spam;
+       guint64                                                 total_ham;
+       guint64                         processed_tokens;
        gsize                           max_tokens;
+       double                         spam_probability;
 };
 
 static                          gboolean
@@ -78,7 +78,7 @@ bayes_learn_callback (gpointer key, gpointer value, gpointer data)
        v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now);
        if (v == 0 && c > 0) {
                statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c);
-               cd->learned_tokens ++;
+               cd->processed_tokens ++;
        }
        else if (v != 0) {
                if (G_LIKELY (c > 0)) {
@@ -90,16 +90,50 @@ bayes_learn_callback (gpointer key, gpointer value, gpointer data)
                        }
                }
                statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, v);
-               cd->learned_tokens ++;
+               cd->processed_tokens ++;
        }
 
-       if (cd->max_tokens != 0 && cd->learned_tokens > cd->max_tokens) {
+       if (cd->max_tokens != 0 && cd->processed_tokens > cd->max_tokens) {
                /* Stop learning on max tokens */
                return TRUE;
        }
        return FALSE;
 }
 
+/**
+ * Returns probability of chisquare > value with specified number of freedom
+ * degrees
+ * @param value value to test
+ * @param freedom_deg number of degrees of freedom
+ * @return
+ */
+static gdouble
+inv_chi_square (gdouble value, gint freedom_deg)
+{
+       gdouble prob, sum;
+       gint i;
+
+       if ((freedom_deg & 1) != 0) {
+               msg_err ("non-odd freedom degrees count: %d", freedom_deg);
+               return 0;
+       }
+
+       value /= 2.;
+       errno = 0;
+       prob = exp (-value);
+       if (errno == ERANGE) {
+               msg_err ("exp overflow");
+               return 0;
+       }
+       sum = prob;
+       for (i = 1; i < freedom_deg / 2; i ++) {
+               prob *= value / (gdouble)i;
+               sum += prob;
+       }
+
+       return MIN (1.0, sum);
+}
+
 /*
  * In this callback we calculate local probabilities for tokens
  */
@@ -107,57 +141,39 @@ static gboolean
 bayes_classify_callback (gpointer key, gpointer value, gpointer data)
 {
 
-       token_node_t                   *node = key;
+       token_node_t                    *node = key;
        struct bayes_callback_data     *cd = data;
-       double                          renorm = 0;
        guint                            i;
-       double                          local_hits = 0;
        struct bayes_statfile_data     *cur;
+       guint64                                                  spam_count = 0, ham_count = 0, total_count = 0;
+       double                                                   spam_prob, spam_freq, ham_freq, bayes_spam_prob;
 
        for (i = 0; i < cd->statfiles_num; i ++) {
                cur = &cd->statfiles[i];
                cur->value = statfile_pool_get_block (cd->pool, cur->file, node->h1, node->h2, cd->now);
                if (cur->value > 0) {
-                       cur->total_hits ++;
-                       cur->hits = cur->value;
-                       local_hits += cur->value;
+                       cur->total_hits += cur->value;
+                       if (cur->st->is_spam) {
+                               spam_count ++;
+                       }
+                       else {
+                               ham_count ++;
+                       }
+                       total_count ++;
                }
        }
-       for (i = 0; i < cd->statfiles_num; i ++) {
-               cur = &cd->statfiles[i];
-               cur->local_probability = 0.5 + (cur->value - (local_hits - cur->value)) /
-                               (LOCAL_PROB_DENOM * (1.0 + local_hits));
-               renorm += cur->post_probability * cur->local_probability;
-       }
-
-       for (i = 0; i < cd->statfiles_num; i ++) {
-               cur = &cd->statfiles[i];
-               cur->post_probability = (cur->post_probability * cur->local_probability) / renorm;
-               if (cur->post_probability < G_MINDOUBLE * 100) {
-                       cur->post_probability = G_MINDOUBLE * 100;
-               }
 
-       }
-       renorm = 0;
-       for (i = 0; i < cd->statfiles_num; i ++) {
-               cur = &cd->statfiles[i];
-               renorm += cur->post_probability;
-       }
-       /* Renormalize to form sum of probabilities equal to 1 */
-       for (i = 0; i < cd->statfiles_num; i ++) {
-               cur = &cd->statfiles[i];
-               cur->post_probability /= renorm;
-               if (cur->post_probability < G_MINDOUBLE * 10) {
-                       cur->post_probability = G_MINDOUBLE * 100;
-               }
-               if (cd->ctx->debug) {
-                       msg_info ("token: %s, statfile: %s, probability: %.4f, post_probability: %.4f",
-                                       node->extra, cur->st->symbol, cur->value, cur->post_probability);
-               }
+       /* Probability for this token */
+       if (total_count > 0) {
+               spam_freq = ((double)spam_count / MAX (1., (double)cd->total_spam));
+               ham_freq = ((double)ham_count / MAX (1., (double)cd->total_ham));
+               spam_prob = spam_freq / (spam_freq + ham_freq);
+               bayes_spam_prob = (0.5 + spam_prob * total_count) / (double)total_count;
+               cd->spam_probability += log (bayes_spam_prob);
+               cd->processed_tokens ++;
        }
 
-       cd->learned_tokens ++;
-       if (cd->max_tokens != 0 && cd->learned_tokens > cd->max_tokens) {
+       if (cd->max_tokens != 0 && cd->processed_tokens > cd->max_tokens) {
                /* Stop classifying on max tokens */
                return TRUE;
        }
@@ -181,15 +197,15 @@ gboolean
 bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task, lua_State *L)
 {
        struct bayes_callback_data      data;
-       gchar                          *value;
-       gint                            nodes, i = 0, cnt, best_num_spam = 0, best_num_ham = 0;
-       gint                            minnodes;
-       guint64                         rev, total_learns = 0;
-       double                          best_spam = 0., best_ham = 0., total_spam = 0., total_ham = 0.;
+       gchar                           *value;
+       gint                             nodes, i = 0, selected_st = -1, cnt;
+       gint                             minnodes;
+       guint64                          maxhits = 0;
+       double                          final_prob;
        struct statfile                *st;
-       stat_file_t                    *file;
-       GList                          *cur;
-       char                           *sumbuf;
+       stat_file_t                     *file;
+       GList                           *cur;
+       char                            *sumbuf;
 
        g_assert (pool != NULL);
        g_assert (ctx != NULL);
@@ -219,7 +235,10 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input,
        data.now = time (NULL);
        data.ctx = ctx;
 
-       data.learned_tokens = 0;
+       data.processed_tokens = 0;
+       data.spam_probability = 0;
+       data.total_ham = 0;
+       data.total_spam = 0;
        if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) {
                minnodes = parse_limit (value, -1);
                data.max_tokens = minnodes;
@@ -241,10 +260,12 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input,
                }
                data.statfiles[i].file = file;
                data.statfiles[i].st = st;
-               data.statfiles[i].post_probability = 0.5;
-               data.statfiles[i].local_probability = 0.5;
-               statfile_get_revision (file, &rev, NULL);
-               total_learns += rev;
+               if (st->is_spam) {
+                       data.total_spam += statfile_get_used_blocks (file);
+               }
+               else {
+                       data.total_ham += statfile_get_used_blocks (file);
+               }
 
                cur = g_list_next (cur);
                i ++;
@@ -252,46 +273,39 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input,
 
        cnt = i;
 
-       /* Calculate correction factor */
-       for (i = 0; i < cnt; i ++) {
-               statfile_get_revision (data.statfiles[i].file, &rev, NULL);
-               data.statfiles[i].corr = ((double)rev / cnt) / (double)total_learns;
-       }
-
        g_tree_foreach (input, bayes_classify_callback, &data);
 
-       for (i = 0; i < cnt; i ++) {
-               debug_task ("got probability for symbol %s: %.2f", data.statfiles[i].st->symbol, data.statfiles[i].post_probability);
-
-               if (data.statfiles[i].st->is_spam) {
-                       total_spam += data.statfiles[i].post_probability;
-                       if (data.statfiles[i].post_probability > best_spam) {
-                               best_spam = data.statfiles[i].post_probability;
-                               best_num_spam = i;
-                       }
-               }
-               else {
-                       total_ham += data.statfiles[i].post_probability;
-                       if (data.statfiles[i].post_probability > best_ham) {
-                               best_ham = data.statfiles[i].post_probability;
-                               best_num_ham = i;
-                       }
-               }
+       if (data.processed_tokens == 0 || data.spam_probability == 0) {
+               final_prob = 0;
+       }
+       else {
+               final_prob = inv_chi_square (-2. * data.spam_probability, 2 * data.processed_tokens);
        }
 
-
-       if (total_ham > 0.5 || total_spam > 0.5) {
+       if (final_prob > 0 && fabs (final_prob - 0.5) > 0.0001) {
 
                sumbuf = memory_pool_alloc (task->task_pool, 32);
-               if (total_ham > total_spam) {
-                       rspamd_snprintf (sumbuf, 32, "%.2f", total_ham);
-                       cur = g_list_prepend (NULL, sumbuf);
-                       insert_result (task, data.statfiles[best_num_ham].st->symbol, total_ham, cur);
+               for (i = 0; i < cnt; i ++) {
+                       if ((final_prob > 0.5 && !data.statfiles[i].st->is_spam) ||
+                                       (final_prob < 0.5 && data.statfiles[i].st->is_spam)) {
+                               continue;
+                       }
+                       if (data.statfiles[i].total_hits > maxhits) {
+                               maxhits = data.statfiles[i].total_hits;
+                               selected_st = i;
+                       }
+               }
+               if (selected_st == -1) {
+                       msg_err ("unexpected classifier error: cannot select desired statfile");
                }
                else {
-                       rspamd_snprintf (sumbuf, 32, "%.2f", total_spam);
+                       /* Calculate ham probability correctly */
+                       if (final_prob < 0.5) {
+                               final_prob = 1. - final_prob;
+                       }
+                       rspamd_snprintf (sumbuf, 32, "%.2f", final_prob);
                        cur = g_list_prepend (NULL, sumbuf);
-                       insert_result (task, data.statfiles[best_num_spam].st->symbol, total_spam, cur);
+                       insert_result (task, data.statfiles[selected_st].st->symbol, final_prob, cur);
                }
        }
 
@@ -337,8 +351,8 @@ bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symb
        data.in_class = in_class;
        data.now = time (NULL);
        data.ctx = ctx;
-       data.learned_tokens = 0;
-       data.learned_tokens = 0;
+       data.processed_tokens = 0;
+       data.processed_tokens = 0;
        if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) {
                minnodes = parse_limit (value, -1);
                data.max_tokens = minnodes;
@@ -394,7 +408,7 @@ bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symb
        statfile_pool_unlock_file (pool, data.file);
 
        if (sum != NULL) {
-               *sum = data.learned_tokens;
+               *sum = data.processed_tokens;
        }
 
        return TRUE;
@@ -447,7 +461,7 @@ bayes_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool,
        data.now = time (NULL);
        data.ctx = ctx;
 
-       data.learned_tokens = 0;
+       data.processed_tokens = 0;
        if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "max_tokens")) != NULL) {
                minnodes = parse_limit (value, -1);
                data.max_tokens = minnodes;
@@ -503,70 +517,6 @@ bayes_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool,
 GList *
 bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task)
 {
-       struct bayes_callback_data      data;
-       char                           *value;
-       int                             nodes, minnodes, i, cnt;
-       struct classify_weight         *w;
-       struct statfile                *st;
-       stat_file_t                    *file;
-       GList                          *cur, *resl = NULL;
-
-       g_assert (pool != NULL);
-       g_assert (ctx != NULL);
-
-       if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
-               minnodes = strtol (value, NULL, 10);
-               nodes = g_tree_nnodes (input);
-               if (nodes > FEATURE_WINDOW_SIZE) {
-                       nodes = nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
-               }
-               if (nodes < minnodes) {
-                       return NULL;
-               }
-       }
-
-       data.statfiles_num = g_list_length (ctx->cfg->statfiles);
-       data.statfiles = g_new0 (struct bayes_statfile_data, data.statfiles_num);
-       data.pool = pool;
-       data.now = time (NULL);
-       data.ctx = ctx;
-
-       cur = ctx->cfg->statfiles;
-       i = 0;
-       while (cur) {
-               /* Select statfile to learn */
-               st = cur->data;
-               if ((file = statfile_pool_is_open (pool, st->path)) == NULL) {
-                       if ((file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
-                               msg_warn ("cannot open %s", st->path);
-                               cur = g_list_next (cur);
-                               data.statfiles_num --;
-                               continue;
-                       }
-               }
-               data.statfiles[i].file = file;
-               data.statfiles[i].st = st;
-               data.statfiles[i].post_probability = 0.5;
-               data.statfiles[i].local_probability = 0.5;
-               i ++;
-               cur = g_list_next (cur);
-       }
-       cnt = i;
-
-       g_tree_foreach (input, bayes_classify_callback, &data);
-
-       for (i = 0; i < cnt; i ++) {
-               w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight));
-               w->name = data.statfiles[i].st->symbol;
-               w->weight = data.statfiles[i].post_probability;
-               resl = g_list_prepend (resl, w);
-       }
-
-       g_free (data.statfiles);
-
-       if (resl != NULL) {
-               memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, resl);
-       }
-
-       return resl;
+       /* This function is unimplemented with new normalizer */
+       return NULL;
 }
index be0daac9b2c1a647ed5caaf9b5ca84d1a668b3db..51893ca4310df654b90e20ccd9cd21c81640db35 100644 (file)
@@ -12,7 +12,7 @@
 typedef struct token_node_s {
        guint32 h1;
        guint32 h2;
-       float value;
+       double value;
        uintptr_t extra;
 } token_node_t;