]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Implement fasttext language detection
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 29 Apr 2023 14:47:15 +0000 (15:47 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 29 Apr 2023 14:47:15 +0000 (15:47 +0100)
src/libmime/lang_detection.c
src/libmime/lang_detection_fasttext.cxx
src/libmime/lang_detection_fasttext.h

index 09591438e2fd4d13059902b2a66f94842baafbfb..211dfe48bbe3794c403f58647e9826d4a40d064f 100644 (file)
@@ -1801,88 +1801,132 @@ rspamd_language_detector_detect (struct rspamd_task *task,
        }
 
        if (!ret) {
-               if (part->utf_words->len < default_short_text_limit) {
-                       r = rs_detect_none;
-                       msg_debug_lang_det ("text is too short for trigrams detection: "
-                                          "%d words; at least %d words required",
+               unsigned ndetected = 0;
+               if (rspamd_lang_detection_fasttext_is_enabled(d->fasttext_detector)) {
+                       rspamd_fasttext_predict_result_t fasttext_predict_result;
+                       fasttext_predict_result = rspamd_lang_detection_fasttext_detect(d->fasttext_detector,
+                               part->utf_stripped_content->data,
+                               part->utf_stripped_content->len, 4);
+
+                       ndetected = rspamd_lang_detection_fasttext_get_nlangs(fasttext_predict_result);
+
+                       if (ndetected > 0) {
+                               candidates = kh_init (rspamd_candidates_hash);
+                               kh_resize (rspamd_candidates_hash, candidates, ndetected);
+
+                               /* Now fill all results where probability is above threshold */
+                               float max_prob = rspamd_lang_detection_fasttext_get_prob(fasttext_predict_result, 0);
+
+                               for (unsigned int i = 0; i < ndetected; i ++) {
+                                       float prob = rspamd_lang_detection_fasttext_get_prob(fasttext_predict_result, i);
+                                       if (prob > max_prob * 0.75) {
+                                               char *lang = rspamd_mempool_strdup(task->task_pool,
+                                                       rspamd_lang_detection_fasttext_get_lang(fasttext_predict_result, i));
+                                               int tmp;
+                                               khiter_t k = kh_put (rspamd_candidates_hash, candidates, lang, &tmp);
+
+                                               kh_value(candidates, k) = rspamd_mempool_alloc0(task->task_pool, sizeof(*cand));
+                                               cand = kh_value(candidates, k);
+                                               cand->lang = lang;
+                                               cand->prob = rspamd_lang_detection_fasttext_get_prob(fasttext_predict_result, i);
+                                       }
+                               }
+
+                               if (kh_size(candidates) == 1) {
+                                       r = rs_detect_single;
+                               }
+                               else if (kh_size(candidates) > 1) {
+                                       r = rs_detect_multiple;
+                               }
+                               else {
+                                       r = rs_detect_none;
+                               }
+                       }
+               }
+               if (ndetected == 0) {
+                       if (part->utf_words->len < default_short_text_limit) {
+                               r = rs_detect_none;
+                               msg_debug_lang_det ("text is too short for trigrams detection: "
+                                                                       "%d words; at least %d words required",
                                        (int)part->utf_words->len,
                                        (int)default_short_text_limit);
-                       switch (cat) {
-                       case RSPAMD_LANGUAGE_CYRILLIC:
-                               rspamd_language_detector_set_language (task, part, "ru", NULL);
-                               break;
-                       case RSPAMD_LANGUAGE_DEVANAGARI:
-                               rspamd_language_detector_set_language (task, part, "hi", NULL);
-                               break;
-                       case RSPAMD_LANGUAGE_ARAB:
-                               rspamd_language_detector_set_language (task, part, "ar", NULL);
-                               break;
-                       default:
-                       case RSPAMD_LANGUAGE_LATIN:
-                               rspamd_language_detector_set_language (task, part, "en", NULL);
-                               break;
-                       }
-                       msg_debug_lang_det ("set %s language based on symbols category",
+                               switch (cat) {
+                               case RSPAMD_LANGUAGE_CYRILLIC:
+                                       rspamd_language_detector_set_language (task, part, "ru", NULL);
+                                       break;
+                               case RSPAMD_LANGUAGE_DEVANAGARI:
+                                       rspamd_language_detector_set_language (task, part, "hi", NULL);
+                                       break;
+                               case RSPAMD_LANGUAGE_ARAB:
+                                       rspamd_language_detector_set_language (task, part, "ar", NULL);
+                                       break;
+                               default:
+                               case RSPAMD_LANGUAGE_LATIN:
+                                       rspamd_language_detector_set_language (task, part, "en", NULL);
+                                       break;
+                               }
+                               msg_debug_lang_det ("set %s language based on symbols category",
                                        part->language);
 
-                       candidates = kh_init (rspamd_candidates_hash);
-               }
-               else {
-                       candidates = kh_init (rspamd_candidates_hash);
-                       kh_resize (rspamd_candidates_hash, candidates, 32);
+                               candidates = kh_init (rspamd_candidates_hash);
+                       }
+                       else {
+                               candidates = kh_init (rspamd_candidates_hash);
+                               kh_resize (rspamd_candidates_hash, candidates, 32);
 
-                       r = rspamd_language_detector_try_ngramm (task,
+                               r = rspamd_language_detector_try_ngramm (task,
                                        default_words,
                                        d,
                                        part->utf_words,
                                        cat,
                                        candidates);
 
-                       if (r == rs_detect_none) {
-                               msg_debug_lang_det ("no trigrams found, fallback to english");
-                               rspamd_language_detector_set_language (task, part, "en", NULL);
-                       } else if (r == rs_detect_multiple) {
-                               /* Check our guess */
-
-                               mean = 0.0;
-                               std = 0.0;
-                               cand_len = 0;
-
-                               /* Check distribution */
-                               kh_foreach_value (candidates, cand, {
-                                       if (!isnan (cand->prob)) {
-                                               mean += cand->prob;
-                                               cand_len++;
-                                       }
-                               });
+                               if (r == rs_detect_none) {
+                                       msg_debug_lang_det ("no trigrams found, fallback to english");
+                                       rspamd_language_detector_set_language (task, part, "en", NULL);
+                               } else if (r == rs_detect_multiple) {
+                                       /* Check our guess */
 
-                               if (cand_len > 0) {
-                                       mean /= cand_len;
+                                       mean = 0.0;
+                                       std = 0.0;
+                                       cand_len = 0;
 
+                                       /* Check distribution */
                                        kh_foreach_value (candidates, cand, {
-                                               gdouble err;
                                                if (!isnan (cand->prob)) {
-                                                       err = cand->prob - mean;
-                                                       std += fabs (err);
+                                                       mean += cand->prob;
+                                                       cand_len++;
                                                }
                                        });
 
-                                       std /= cand_len;
-                               }
+                                       if (cand_len > 0) {
+                                               mean /= cand_len;
 
-                               msg_debug_lang_det ("trigrams checked, %d candidates, %.3f mean, %.4f stddev",
+                                               kh_foreach_value (candidates, cand, {
+                                                       gdouble err;
+                                                       if (!isnan (cand->prob)) {
+                                                               err = cand->prob - mean;
+                                                               std += fabs (err);
+                                                       }
+                                               });
+
+                                               std /= cand_len;
+                                       }
+
+                                       msg_debug_lang_det ("trigrams checked, %d candidates, %.3f mean, %.4f stddev",
                                                cand_len, mean, std);
 
-                               if (cand_len > 0 && std / fabs (mean) < 0.25) {
-                                       msg_debug_lang_det ("apply frequency heuristic sorting");
-                                       frequency_heuristic_applied = TRUE;
-                                       cbd.d = d;
-                                       cbd.mean = mean;
-                                       cbd.std = std;
-                                       cbd.flags = RSPAMD_LANG_FLAG_DEFAULT;
+                                       if (cand_len > 0 && std / fabs (mean) < 0.25) {
+                                               msg_debug_lang_det ("apply frequency heuristic sorting");
+                                               frequency_heuristic_applied = TRUE;
+                                               cbd.d = d;
+                                               cbd.mean = mean;
+                                               cbd.std = std;
+                                               cbd.flags = RSPAMD_LANG_FLAG_DEFAULT;
 
-                                       if (part->nwords < default_words / 2) {
-                                               cbd.flags |= RSPAMD_LANG_FLAG_SHORT;
+                                               if (part->nwords < default_words / 2) {
+                                                       cbd.flags |= RSPAMD_LANG_FLAG_SHORT;
+                                               }
                                        }
                                }
                        }
@@ -1909,7 +1953,9 @@ rspamd_language_detector_detect (struct rspamd_task *task,
 
                        if (result->len > 0 && !frequency_heuristic_applied) {
                                cand = g_ptr_array_index (result, 0);
-                               cand->elt->occurrences++;
+                               if (cand->elt) {
+                                       cand->elt->occurrences++;
+                               }
                                d->total_occurrences++;
                        }
 
@@ -1918,6 +1964,7 @@ rspamd_language_detector_detect (struct rspamd_task *task,
                        }
 
                        part->languages = result;
+                       part->language = ((struct rspamd_lang_detector_res *)g_ptr_array_index (result, 0))->lang;
                        ret = TRUE;
                }
                else if (part->languages == NULL) {
index 9ede47a6edc0904254ca7b849cc200b4a035f73d..eda4c2850067e458f20f5a99e83fa6a84c4aaebc 100644 (file)
@@ -72,8 +72,8 @@ public:
 
        ~fasttext_langdet() = default;
 
-
-       auto detect_language(const char *in, size_t len, int k) -> std::vector<std::pair<fasttext::real, std::string>> *
+       auto is_enabled() const -> bool { return loaded; }
+       auto detect_language(const char *in, size_t len, int k) const -> std::vector<std::pair<fasttext::real, std::string>> *
        {
                if (!loaded) {
                        return nullptr;
@@ -135,6 +135,19 @@ char *rspamd_lang_detection_fasttext_show_info(void *ud)
 #endif
 }
 
+bool rspamd_lang_detection_fasttext_is_enabled(void *ud)
+{
+#ifdef WITH_FASTTEXT
+       auto *real_model = FASTTEXT_MODEL_TO_C_API(ud);
+
+       if (real_model) {
+               return real_model->is_enabled();
+       }
+#endif
+
+       return false;
+}
+
 rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud,
                                                                                           const char *in, size_t len, int k)
 {
@@ -155,27 +168,41 @@ void rspamd_lang_detection_fasttext_destroy(void *ud)
 #endif
 }
 
+
+guint
+rspamd_lang_detection_fasttext_get_nlangs(rspamd_fasttext_predict_result_t res)
+{
+#ifdef WITH_FASTTEXT
+       auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
+
+       if (real_res) {
+               return real_res->size();
+       }
+#endif
+       return 0;
+}
+
 const char *
-rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res)
+rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res, unsigned int idx)
 {
 #ifdef WITH_FASTTEXT
        auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
 
-       if (real_res && !real_res->empty()) {
-               return real_res->front().second.c_str();
+       if (real_res && real_res->size() < idx) {
+               return real_res->at(idx).second.c_str();
        }
 #endif
        return nullptr;
 }
 
 float
-rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res)
+rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, unsigned int idx)
 {
 #ifdef WITH_FASTTEXT
        auto *real_res = FASTTEXT_RESULT_TO_C_API(res);
 
-       if (real_res && !real_res->empty()) {
-               return real_res->front().first;
+       if (real_res && real_res->size() < idx) {
+               return real_res->at(idx).first;
        }
 #endif
        return 0.0f;
index 71e253940b3859199f7b7936247c94a567bb2dfc..2e8a9fe78c3eb68c8c8aae78560026d344e60f1c 100644 (file)
@@ -27,6 +27,13 @@ struct rspamd_config;
  */
 void* rspamd_lang_detection_fasttext_init(struct rspamd_config *cfg);
 
+/**
+ * Check if fasttext language detector is enabled
+ * @param ud
+ * @return
+ */
+bool rspamd_lang_detection_fasttext_is_enabled(void *ud);
+
 /**
  * Show info about fasttext language detector
  * @param ud
@@ -47,19 +54,25 @@ typedef  void * rspamd_fasttext_predict_result_t;
 rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud,
                const char *in, size_t len, int k);
 
+/**
+ * Get number of languages detected
+ * @param ud
+ * @return
+ */
+guint rspamd_lang_detection_fasttext_get_nlangs(rspamd_fasttext_predict_result_t ud);
 /**
  * Get language from fasttext result
  * @param res
  * @return
  */
-const char *rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res);
+const char *rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res, unsigned int idx);
 
 /**
  * Get probability from fasttext result
  * @param res
  * @return
  */
-float rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res);
+float rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, unsigned int idx);
 
 /**
  * Destroy fasttext result