]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
Fix learning.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 6 Jan 2016 15:08:48 +0000 (15:08 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 6 Jan 2016 15:08:48 +0000 (15:08 +0000)
src/controller.c
src/libserver/task.c
src/libserver/task.h

index aa90fc8b6e6596870e11b040ba70aed23f8241f6..0a359a9a6d461e9fabf8f0b040edd4e27f3a582e 100644 (file)
@@ -1163,25 +1163,55 @@ rspamd_controller_learn_fin_task (void *ud)
        struct rspamd_task *task = ud;
        struct rspamd_controller_session *session;
        struct rspamd_http_connection_entry *conn_ent;
-       GError *err = NULL;
 
        conn_ent = task->fin_arg;
        session = conn_ent->ud;
 
-       if (rspamd_learn_task_spam (task, session->is_spam, session->classifier, &err) ==
-                       RSPAMD_STAT_PROCESS_ERROR) {
-               msg_info_session ("cannot learn <%s>: %e", task->message_id, err);
-               rspamd_controller_send_error (conn_ent, err->code, err->message);
+       if (task->err != NULL) {
+               msg_info_session ("cannot learn <%s>: %e", task->message_id, task->err);
+               rspamd_controller_send_error (conn_ent, task->err->code,
+                               task->err->message);
 
                return TRUE;
        }
 
-       /* Successful learn */
-       msg_info_session ("<%s> learned message as %s: %s",
-               rspamd_inet_address_to_string (session->from_addr),
-               session->is_spam ? "spam" : "ham",
-               task->message_id);
-       rspamd_controller_send_string (conn_ent, "{\"success\":true}");
+       if (RSPAMD_TASK_IS_PROCESSED (task)) {
+               /* Successful learn */
+               msg_info_session ("<%s> learned message as %s: %s",
+                               rspamd_inet_address_to_string (session->from_addr),
+                               session->is_spam ? "spam" : "ham",
+                                               task->message_id);
+               rspamd_controller_send_string (conn_ent, "{\"success\":true}");
+               return TRUE;
+       }
+
+       if (!rspamd_task_process (task, RSPAMD_TASK_PROCESS_LEARN)) {
+               msg_info_session ("cannot learn <%s>: %e", task->message_id, task->err);
+
+               if (task->err) {
+                       rspamd_controller_send_error (conn_ent, task->err->code,
+                                       task->err->message);
+               }
+               else {
+                       rspamd_controller_send_error (conn_ent, 500,
+                                                               "Internal error");
+               }
+       }
+
+       if (RSPAMD_TASK_IS_PROCESSED (task)) {
+               msg_info_session ("<%s> learned message as %s: %s",
+                               rspamd_inet_address_to_string (session->from_addr),
+                               session->is_spam ? "spam" : "ham",
+                                               task->message_id);
+               rspamd_controller_send_string (conn_ent, "{\"success\":true}");
+               return TRUE;
+       }
+
+       /* One more iteration */
+       return FALSE;
+
+
+
 
        return TRUE;
 }
@@ -1284,6 +1314,8 @@ rspamd_controller_handle_learn_common (
                return 0;
        }
 
+       rspamd_learn_task_spam (task, is_spam, session->classifier, NULL);
+
        if (!rspamd_task_process (task, RSPAMD_TASK_PROCESS_LEARN)) {
                msg_warn_session ("<%s> message cannot be processed", task->message_id);
                rspamd_controller_send_error (conn_ent, task->err->code, task->err->message);
index 579cc3461f4c244aafc09a662e3f41482736fcb4..91ed48e8663c3db50cd3f1445fe2b6f143e1ebf5 100644 (file)
@@ -457,6 +457,21 @@ rspamd_task_process (struct rspamd_task *task, guint stages)
                rspamd_lua_call_post_filters (task);
                break;
 
+       case RSPAMD_TASK_STAGE_LEARN:
+       case RSPAMD_TASK_STAGE_LEARN_PRE:
+       case RSPAMD_TASK_STAGE_LEARN_POST:
+               if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM|RSPAMD_TASK_FLAG_LEARN_HAM)) {
+                       if (!rspamd_stat_learn (task,
+                                       task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM,
+                                       task->cfg->lua_state, task->classifier,
+                                       st, &stat_error)) {
+                               msg_err_task ("learn error: %e", stat_error);
+                               task->err = stat_error;
+                               task->processed_stages |= RSPAMD_TASK_STAGE_DONE;
+                       }
+               }
+               break;
+
        case RSPAMD_TASK_STAGE_DONE:
                task->processed_stages |= RSPAMD_TASK_STAGE_DONE;
                break;
@@ -610,7 +625,16 @@ rspamd_learn_task_spam (struct rspamd_task *task,
        const gchar *classifier,
        GError **err)
 {
-       return FALSE;
+       if (is_spam) {
+               task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+       }
+       else {
+               task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+       }
+
+       task->classifier = classifier;
+
+       return TRUE;
 }
 
 static gboolean
index 901067ba4c1d158a6286e63cd149653d40b1aec5..7ede95b317f5d3bd85021acf09158a4fcbbb25e1 100644 (file)
@@ -92,6 +92,9 @@ enum rspamd_task_stage {
                RSPAMD_TASK_STAGE_CLASSIFIERS_PRE | \
                RSPAMD_TASK_STAGE_CLASSIFIERS | \
                RSPAMD_TASK_STAGE_CLASSIFIERS_POST | \
+               RSPAMD_TASK_STAGE_LEARN_PRE | \
+               RSPAMD_TASK_STAGE_LEARN | \
+               RSPAMD_TASK_STAGE_LEARN_POST | \
                RSPAMD_TASK_STAGE_DONE)
 
 #define RSPAMD_TASK_FLAG_MIME (1 << 0)
@@ -110,6 +113,8 @@ enum rspamd_task_stage {
 #define RSPAMD_TASK_FLAG_NO_STAT (1 << 13)
 #define RSPAMD_TASK_FLAG_UNLEARN (1 << 14)
 #define RSPAMD_TASK_FLAG_ALREADY_LEARNED (1 << 15)
+#define RSPAMD_TASK_FLAG_LEARN_SPAM (1 << 16)
+#define RSPAMD_TASK_FLAG_LEARN_HAM (1 << 17)
 
 #define RSPAMD_TASK_IS_SKIPPED(task) (((task)->flags & RSPAMD_TASK_FLAG_SKIP))
 #define RSPAMD_TASK_IS_JSON(task) (((task)->flags & RSPAMD_TASK_FLAG_JSON))
@@ -192,6 +197,8 @@ struct rspamd_task {
        } pre_result;                                                                   /**< Result of pre-filters                                                      */
 
        ucl_object_t *settings;                                                 /**< Settings applied to task                                           */
+
+       const gchar *classifier;                                                /**< Classifier to learn (if needed)                            */
 };
 
 /**