]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Rework language detection ngramms structure
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 16 Jan 2018 08:00:48 +0000 (08:00 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 16 Jan 2018 08:00:48 +0000 (08:00 +0000)
src/libmime/lang_detection.c

index 4daed73d13e86cf6dc405f7a91b9f770af34e218..9c174ca23eb7d09e7cb4bcd83b4cc1be9516a737 100644 (file)
@@ -29,18 +29,29 @@ static const gsize default_words = 30;
 static const gdouble update_prob = 0.6;
 static const gchar *default_languages_path = RSPAMD_PLUGINSDIR "/languages";
 
+enum rspamd_language_elt_flags {
+       RS_LANGUAGE_DEFAULT = 0,
+       RS_LANGUAGE_LATIN = (1 <<0),
+};
+
 struct rspamd_language_elt {
        const gchar *name; /* e.g. "en" or "ru" */
+       enum rspamd_language_elt_flags flags;
        guint unigramms_total; /* total frequencies for unigramms */
-       GHashTable *unigramms; /* unigramms frequencies */
        guint bigramms_total; /* total frequencies for bigramms */
-       GHashTable *bigramms; /* bigramms frequencies */
        guint trigramms_total; /* total frequencies for trigramms */
-       GHashTable *trigramms; /* trigramms frequencies */
+};
+
+struct rspamd_ngramm_elt {
+       struct rspamd_language_elt *elt;
+       gdouble prob;
 };
 
 struct rspamd_lang_detector {
        GPtrArray *languages;
+       GHashTable *unigramms; /* unigramms frequencies */
+       GHashTable *bigramms; /* bigramms frequencies */
+       GHashTable *trigramms; /* trigramms frequencies */
        UConverter *uchar_converter;
        gsize short_text_limit;
 };
@@ -96,6 +107,82 @@ rspamd_language_detector_ucs_lowercase (UChar *s, gsize len)
        }
 }
 
+static gboolean
+rspamd_language_detector_ucs_is_latin (UChar *s, gsize len)
+{
+       gsize i;
+       gboolean ret = TRUE;
+
+       for (i = 0; i < len; i ++) {
+               if (!u_hasBinaryProperty (s[i], UCHAR_POSIX_ALNUM)) {
+                       ret = FALSE;
+                       break;
+               }
+       }
+
+       return ret;
+}
+
+static void
+rspamd_language_detector_init_ngramm (struct rspamd_config *cfg,
+               struct rspamd_lang_detector *d,
+               struct rspamd_language_elt *lelt,
+               UChar *s, guint len, guint freq, guint total)
+{
+       GHashTable *target;
+       GPtrArray *ar;
+       struct rspamd_ngramm_elt *elt;
+       guint i;
+       gboolean found;
+
+       switch (len) {
+       case 1:
+               target = d->unigramms;
+               break;
+       case 2:
+               target = d->bigramms;
+               break;
+       case 3:
+               target = d->trigramms;
+               break;
+       default:
+               g_assert_not_reached ();
+               break;
+       }
+
+       ar = g_hash_table_lookup (target, s);
+
+       if (ar == NULL) {
+               /* New element */
+               ar = g_ptr_array_sized_new (32);
+               elt = rspamd_mempool_alloc (cfg->cfg_pool, sizeof (*elt));
+               elt->elt = lelt;
+               elt->prob = ((gdouble)freq) / ((gdouble)total);
+               g_ptr_array_add (ar, elt);
+
+               g_hash_table_insert (target, s, ar);
+       }
+       else {
+               /* Check sanity */
+               found = FALSE;
+
+               PTR_ARRAY_FOREACH (ar, i, elt) {
+                       if (strcmp (elt->elt->name, lelt->name) == 0) {
+                               found = TRUE;
+                               elt->prob += ((gdouble)freq) / ((gdouble)total);
+                               break;
+                       }
+               }
+
+               if (!found) {
+                       elt = rspamd_mempool_alloc (cfg->cfg_pool, sizeof (*elt));
+                       elt->elt = lelt;
+                       elt->prob = ((gdouble)freq) / ((gdouble)total);
+                       g_ptr_array_add (ar, elt);
+               }
+       }
+}
+
 static void
 rspamd_language_detector_read_file (struct rspamd_config *cfg,
                struct rspamd_lang_detector *d,
@@ -108,6 +195,7 @@ rspamd_language_detector_read_file (struct rspamd_config *cfg,
        UErrorCode uc_err = U_ZERO_ERROR;
        struct rspamd_language_elt *nelt;
        gchar *pos;
+       guint total = 0, total_latin = 0, total_ngramms = 0;
 
        parser = ucl_parser_new (UCL_PARSER_NO_FILEVARS);
        if (!ucl_parser_add_file (parser, path)) {
@@ -138,9 +226,24 @@ rspamd_language_detector_read_file (struct rspamd_config *cfg,
        pos = strchr (nelt->name, '.');
        g_assert (pos != NULL);
        *pos = '\0';
-       nelt->unigramms = g_hash_table_new (rspamd_unigram_hash, rspamd_unigram_equal);
-       nelt->bigramms = g_hash_table_new (rspamd_bigram_hash, rspamd_bigram_equal);
-       nelt->trigramms = g_hash_table_new (rspamd_trigram_hash, rspamd_trigram_equal);
+
+       n_words = ucl_object_lookup (top, "n_words");
+
+       if (n_words == NULL || ucl_object_type (n_words) != UCL_ARRAY ||
+                       n_words->len != 3) {
+               msg_warn_config ("cannot find n_words in language %s", nelt->name);
+               ucl_object_unref (top);
+
+               return;
+       }
+       else {
+               nelt->unigramms_total = ucl_object_toint (ucl_array_find_index (n_words,
+                               0));
+               nelt->bigramms_total = ucl_object_toint (ucl_array_find_index (n_words,
+                               1));
+               nelt->trigramms_total = ucl_object_toint (ucl_array_find_index (n_words,
+                               2));
+       }
 
        while ((cur = ucl_object_iterate (freqs, &it, true)) != NULL) {
                const gchar *key;
@@ -166,49 +269,41 @@ rspamd_language_detector_read_file (struct rspamd_config *cfg,
                        }
 
                        rspamd_language_detector_ucs_lowercase (ucs_key, nsym);
-
                        if (nsym == 2) {
                                /* We have a digraph */
-                               g_hash_table_insert (nelt->bigramms, ucs_key,
-                                               GUINT_TO_POINTER (freq));
-                               nelt->bigramms_total += freq;
+                               total = nelt->bigramms_total;
                        }
                        else if (nsym == 3) {
-                               g_hash_table_insert (nelt->trigramms, ucs_key,
-                                               GUINT_TO_POINTER (freq));
-                               nelt->trigramms_total += freq;
+                               total = nelt->trigramms_total;
                        }
                        else if (nsym == 1) {
-                               g_hash_table_insert (nelt->unigramms, ucs_key,
-                                               GUINT_TO_POINTER (freq));
-                               nelt->unigramms_total += freq;
+                               total = nelt->unigramms_total;
                        }
                        else if (nsym > 3) {
                                msg_warn_config ("have more than 3 characters in key: %d", nsym);
+                               continue;
                        }
-               }
-       }
 
-       n_words = ucl_object_lookup (top, "n_words");
+                       rspamd_language_detector_init_ngramm (cfg, d, nelt, ucs_key, nsym,
+                                       freq, total);
 
-       if (n_words == NULL || ucl_object_type (n_words) != UCL_ARRAY ||
-                       n_words->len != 3) {
-               msg_warn_config ("cannot find n_words in language %s", nelt->name);
+                       if (rspamd_language_detector_ucs_is_latin (ucs_key, nsym)) {
+                               total_latin ++;
+                       }
+
+                       total_ngramms ++;
+               }
        }
-       else {
-               nelt->unigramms_total = ucl_object_toint (ucl_array_find_index (n_words,
-                               0));
-               nelt->bigramms_total = ucl_object_toint (ucl_array_find_index (n_words,
-                               1));
-               nelt->trigramms_total = ucl_object_toint (ucl_array_find_index (n_words,
-                               2));
+
+       if (total_latin >= total_ngramms * 2 / 3) {
+               nelt->flags |= RS_LANGUAGE_LATIN;
        }
 
        msg_info_config ("loaded %s language, %d unigramms, %d digramms, %d trigramms",
                        nelt->name,
-                       (gint)g_hash_table_size (nelt->unigramms),
-                       (gint)g_hash_table_size (nelt->bigramms),
-                       (gint)g_hash_table_size (nelt->trigramms));
+                       (gint)nelt->unigramms_total,
+                       (gint)nelt->bigramms_total,
+                       (gint)nelt->trigramms_total);
 
        g_ptr_array_add (d->languages, nelt);
        ucl_object_unref (top);
@@ -254,6 +349,13 @@ rspamd_language_detector_init (struct rspamd_config *cfg)
        ret->languages = g_ptr_array_sized_new (gl.gl_pathc);
        ret->uchar_converter = ucnv_open ("UTF-8", &uc_err);
        ret->short_text_limit = short_text_limit;
+       /* Map from ngramm in ucs32 to GPtrArray of rspamd_language_elt */
+       ret->unigramms = g_hash_table_new_full (rspamd_unigram_hash,
+                       rspamd_unigram_equal, NULL, rspamd_ptr_array_free_hard);
+       ret->bigramms = g_hash_table_new_full (rspamd_bigram_hash,
+                       rspamd_bigram_equal, NULL, rspamd_ptr_array_free_hard);
+       ret->trigramms = g_hash_table_new_full (rspamd_trigram_hash,
+                       rspamd_trigram_equal, NULL, rspamd_ptr_array_free_hard);
 
        g_assert (uc_err == U_ZERO_ERROR);
 
@@ -417,146 +519,43 @@ rspamd_language_detector_process_ngramm_full (struct rspamd_task *task,
                GHashTable *candidates)
 {
        guint i;
-       gdouble freq, class_freq;
-       struct rspamd_language_elt *elt;
+       GPtrArray *ar;
+       struct rspamd_ngramm_elt *elt;
        struct rspamd_lang_detector_res *cand;
        GHashTable *ngramms;
 
-       for (i = 0; i < d->languages->len; i ++) {
-               elt = g_ptr_array_index (d->languages, i);
-
-               switch (type) {
-               case rs_unigramm:
-                       ngramms = elt->unigramms;
-                       class_freq = elt->unigramms_total;
-                       break;
-               case rs_bigramm:
-                       ngramms = elt->bigramms;
-                       class_freq = elt->bigramms_total;
-                       break;
-               case rs_trigramm:
-                       ngramms = elt->trigramms;
-                       class_freq = elt->trigramms_total;
-                       break;
-               }
-
-               freq = ((gdouble)GPOINTER_TO_UINT (
-                               g_hash_table_lookup (ngramms, window))) / class_freq;
-
-               if (freq > 0) {
-                       cand = g_hash_table_lookup (candidates, elt->name);
-
-                       if (cand == NULL) {
-                               cand = g_malloc (sizeof (*cand));
-                               cand->elt = elt;
-                               cand->lang = elt->name;
-                               cand->prob = freq;
-
-                               g_hash_table_insert (candidates, (gpointer)elt->name, cand);
-                       } else {
-                               /* Update guess */
-                               cand->prob += freq;
-                       }
-               }
-       }
-}
-
-/*
- * Check only candidates, if none found, switch to full version
- */
-static gboolean
-rspamd_language_detector_process_ngramm_update (struct rspamd_task *task,
-               struct rspamd_lang_detector *d,
-               UChar *window, enum rspamd_language_gramm_type type,
-               GHashTable *candidates)
-{
-       gdouble freq, total_freq = 0.0, class_freq;
-       struct rspamd_language_elt *elt;
-       struct rspamd_lang_detector_res *cand;
-       GHashTableIter it;
-       gpointer k, v;
-       GHashTable *ngramms;
-
-       g_hash_table_iter_init (&it, candidates);
-
-       while (g_hash_table_iter_next (&it, &k, &v)) {
-               cand = (struct rspamd_lang_detector_res *)v;
-               elt = cand->elt;
-
-               switch (type) {
-               case rs_unigramm:
-                       ngramms = elt->unigramms;
-                       class_freq = elt->unigramms_total;
-                       break;
-               case rs_bigramm:
-                       ngramms = elt->bigramms;
-                       class_freq = elt->bigramms_total;
-                       break;
-               case rs_trigramm:
-                       ngramms = elt->trigramms;
-                       class_freq = elt->trigramms_total;
-                       break;
-               }
-
-               freq = ((gdouble)GPOINTER_TO_UINT (
-                               g_hash_table_lookup (ngramms, window))) / class_freq;
-
-               cand->prob += freq;
-               total_freq += freq;
-       }
-
-       if (total_freq == 0) {
-               /* Nothing found , do full scan which will also update candidates */
-               rspamd_language_detector_process_ngramm_full (task, d, window,
-                               type, candidates);
-
-               return FALSE;
-       }
-
-       return TRUE;
-}
-
-static gboolean
-rspamd_language_detector_update_guess (struct rspamd_task *task,
-               struct rspamd_lang_detector *d,
-               rspamd_stat_token_t *tok, GHashTable *candidates,
-               enum rspamd_language_gramm_type type)
-{
-       guint wlen;
-       UChar window[3];
-       goffset cur = 0;
-       gboolean ret = TRUE;
-
        switch (type) {
        case rs_unigramm:
-               wlen = 1;
+               ngramms = d->unigramms;
                break;
        case rs_bigramm:
-               wlen = 2;
+               ngramms = d->bigramms;
                break;
        case rs_trigramm:
-               wlen = 3;
+               ngramms = d->trigramms;
                break;
        }
 
-       /* Split words */
-       while ((cur = rspamd_language_detector_next_ngramm (tok, window, wlen, cur))
-                       != -1) {
 
-               if (rspamd_random_double_fast () > update_prob) {
-                       if (!rspamd_language_detector_process_ngramm_update (task, d, window,
-                                       type, candidates)) {
-                               ret = FALSE;
+       ar = g_hash_table_lookup (ngramms, window);
+
+       if (ar) {
+               PTR_ARRAY_FOREACH (ar, i, elt) {
+                       cand = g_hash_table_lookup (candidates, elt->elt->name);
+
+                       if (cand == NULL) {
+                               cand = g_malloc (sizeof (*cand));
+                               cand->elt = elt->elt;
+                               cand->lang = elt->elt->name;
+                               cand->prob = elt->prob;
+
+                               g_hash_table_insert (candidates, (gpointer)cand->lang, cand);
+                       } else {
+                               /* Update guess */
+                               cand->prob += elt->prob;
                        }
                }
-               else {
-                       /* Try to do full update in case if we are missing some candidates */
-                       rspamd_language_detector_process_ngramm_full (task, d, window, type,
-                                       candidates);
-               }
        }
-
-       return ret;
 }
 
 static void
@@ -647,8 +646,7 @@ rspamd_language_detector_detect_type (struct rspamd_task *task,
                struct rspamd_lang_detector *d,
                GArray *ucs_tokens,
                GHashTable *candidates,
-               enum rspamd_language_gramm_type type,
-               gboolean start_over)
+               enum rspamd_language_gramm_type type)
 {
        guint nparts = MIN (ucs_tokens->len, nwords);
        goffset *selected_words;
@@ -662,16 +660,11 @@ rspamd_language_detector_detect_type (struct rspamd_task *task,
        /* Deal with the first word in a special case */
        tok = &g_array_index (ucs_tokens, rspamd_stat_token_t, selected_words[0]);
 
-       if (start_over) {
-               rspamd_language_detector_detect_word (task, d, tok, candidates, type);
-       }
-       else {
-               rspamd_language_detector_update_guess (task, d, tok, candidates, type);
-       }
+       rspamd_language_detector_detect_word (task, d, tok, candidates, type);
 
        for (i = 1; i < nparts; i ++) {
                tok = &g_array_index (ucs_tokens, rspamd_stat_token_t, selected_words[i]);
-               rspamd_language_detector_update_guess (task, d, tok, candidates, type);
+               rspamd_language_detector_detect_word (task, d, tok, candidates, type);
        }
 
        /* Filter negligible candidates */
@@ -711,8 +704,12 @@ rspamd_language_detector_try_ngramm (struct rspamd_task *task,
 {
        guint cand_len;
 
-       rspamd_language_detector_detect_type (task, nwords, d, ucs_tokens, candidates,
-                       type, TRUE);
+       rspamd_language_detector_detect_type (task,
+                       nwords,
+                       d,
+                       ucs_tokens,
+                       candidates,
+                       type);
 
        cand_len = g_hash_table_size (candidates);