git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
Restore multiple classifiers support
author Vsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 23 Nov 2015 18:36:41 +0000 (18:36 +0000)
committer Vsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 23 Nov 2015 18:36:41 +0000 (18:36 +0000)
src/controller.c
src/libserver/cfg_rcl.c
src/libserver/task.c
src/libserver/task.h
src/libserver/worker_util.h
src/libstat/stat_api.h
src/libstat/stat_process.c
src/lua/lua_task.c

index c0783ce9db18e4a9c60ce1a99bc18fa331dba028..bb494108f385af404fd3f19982a05ee21ab77e16 100644 (file)
@@ -1168,7 +1168,7 @@ rspamd_controller_learn_fin_task (void *ud)
        conn_ent = task->fin_arg;
        session = conn_ent->ud;
 
-       if (rspamd_learn_task_spam (session->cl, task, session->is_spam, &err) ==
+       if (rspamd_learn_task_spam (task, session->is_spam, session->classifier, &err) ==
                        RSPAMD_STAT_PROCESS_ERROR) {
                msg_info_session ("cannot learn <%s>: %e", task->message_id, err);
                rspamd_controller_send_error (conn_ent, err->code, err->message);
@@ -1238,8 +1238,8 @@ rspamd_controller_handle_learn_common (
 {
        struct rspamd_controller_session *session = conn_ent->ud;
        struct rspamd_controller_worker_ctx *ctx;
-       struct rspamd_classifier_config *cl;
        struct rspamd_task *task;
+       const rspamd_ftok_t *cl_header;
 
        ctx = session->ctx;
 
@@ -1255,13 +1255,6 @@ rspamd_controller_handle_learn_common (
                return 0;
        }
 
-       /* XXX: now work with only bayes */
-       cl = rspamd_config_find_classifier (ctx->cfg, "bayes");
-       if (cl == NULL) {
-               rspamd_controller_send_error (conn_ent, 400, "Classifier not found");
-               return 0;
-       }
-
        task = rspamd_task_new (session->ctx->worker, session->cfg);
 
        task->resolver = ctx->resolver;
@@ -1277,8 +1270,14 @@ rspamd_controller_handle_learn_common (
        task->http_conn = rspamd_http_connection_ref (conn_ent->conn);;
        task->sock = -1;
        session->task = task;
-       session->cl = cl;
 
+       cl_header = rspamd_http_message_find_header (msg, "classifier");
+       if (cl_header) {
+               session->classifier = rspamd_mempool_ftokdup (session->pool, cl_header);
+       }
+       else {
+               session->classifier = NULL;
+       }
 
        if (!rspamd_task_load_message (task, msg, msg->body_buf.begin, msg->body_buf.len)) {
                rspamd_controller_send_error (conn_ent, task->err->code, task->err->message);
index dc74ece310dd6ab711ec8eb7cae5c87af77520af..e8e0818ca4501fabf5eff225e5cfc0a3a3ec5a2a 100644 (file)
@@ -1652,6 +1652,11 @@ rspamd_rcl_config_init (void)
                        rspamd_rcl_parse_struct_string,
                        G_STRUCT_OFFSET (struct rspamd_classifier_config, backend),
                        0);
+       rspamd_rcl_add_default_handler (sub,
+                       "name",
+                       rspamd_rcl_parse_struct_string,
+                       G_STRUCT_OFFSET (struct rspamd_classifier_config, name),
+                       0);
 
        /*
         * Statfile defaults
index c4ae1762c808dd881e90c49c732dda0861fbc415..7d34e830b599fc243f10aeae788bf491a941fd05 100644 (file)
@@ -645,12 +645,16 @@ rspamd_task_re_cache_check (struct rspamd_task *task, const gchar *re)
 }
 
 gboolean
-rspamd_learn_task_spam (struct rspamd_classifier_config *cl,
-       struct rspamd_task *task,
+rspamd_learn_task_spam (struct rspamd_task *task,
        gboolean is_spam,
+       const gchar *classifier,
        GError **err)
 {
-       return rspamd_stat_learn (task, is_spam, task->cfg->lua_state, err);
+       return rspamd_stat_learn (task,
+                       is_spam,
+                       task->cfg->lua_state,
+                       classifier,
+                       err);
 }
 
 static gboolean
index b29bcebf63d7493204ce48a14ced49c5b1c84175..49357e00b5e31625033582e5555149fa080f580c 100644 (file)
@@ -260,14 +260,14 @@ guint rspamd_task_re_cache_check (struct rspamd_task *task, const gchar *re);
 
 /**
  * Learn specified statfile with message in a task
- * @param statfile symbol of statfile
  * @param task worker's task object
+ * @param classifier classifier to learn (or NULL to learn all)
  * @param err pointer to GError
  * @return true if learn succeed
  */
-gboolean rspamd_learn_task_spam (struct rspamd_classifier_config *cl,
-       struct rspamd_task *task,
+gboolean rspamd_learn_task_spam (struct rspamd_task *task,
        gboolean is_spam,
+       const gchar *classifier,
        GError **err);
 
 /**
index 21c86f92e0b869354b599574b339474f1e808a04..837e6ac331f0636564f5f2f88b3e8e59c88039ef 100644 (file)
@@ -84,7 +84,7 @@ struct rspamd_controller_session {
        struct rspamd_worker *wrk;
        rspamd_mempool_t *pool;
        struct rspamd_task *task;
-       struct rspamd_classifier_config *cl;
+       gchar *classifier;
        rspamd_inet_addr_t *from_addr;
        struct rspamd_config *cfg;
        gboolean is_spam;
index 49335400733d1b88aa4a9cd37598decf61361f9d..ba5dc4a409b40cee8d39edfac46f1a7c0e22feca 100644 (file)
@@ -58,6 +58,8 @@ void rspamd_stat_close (void);
 /**
  * Classify the task specified and insert symbols if needed
  * @param task
+ * @param L lua state
+ * @param err error returned
  * @return TRUE if task has been classified
  */
 rspamd_stat_result_t rspamd_stat_classify (struct rspamd_task *task,
@@ -68,10 +70,13 @@ rspamd_stat_result_t rspamd_stat_classify (struct rspamd_task *task,
  * Learn task as spam or ham, task must be processed prior to this call
  * @param task task to learn
  * @param spam if TRUE learn spam, otherwise learn ham
+ * @param L lua state
+ * @param classifier NULL to learn all classifiers, name to learn a specific one
+ * @param err error returned
  * @return TRUE if task has been learned
  */
 rspamd_stat_result_t rspamd_stat_learn (struct rspamd_task *task,
-               gboolean spam, lua_State *L,
+               gboolean spam, lua_State *L, const gchar *classifier,
                GError **err);
 
 /**
index b19663893ac2e1a4f7afc67ce4926de6efef140c..952330b4937e4485a31499ba2ca267a78cc88343 100644 (file)
@@ -353,7 +353,11 @@ preprocess_init_stat_token (gpointer k, gpointer v, gpointer d)
 static GList*
 rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
                struct rspamd_task *task,
-               lua_State *L, gint op, gboolean spam, GError **err)
+               lua_State *L,
+               gint op,
+               gboolean spam,
+               const gchar *classifier,
+               GError **err)
 {
        struct rspamd_classifier_config *clcf;
        struct rspamd_statfile_config *stcf;
@@ -373,6 +377,15 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
                clcf = (struct rspamd_classifier_config *)cur->data;
                st_list = NULL;
 
+               if (classifier != NULL &&
+                                       (clcf->name == NULL || strcmp (clcf->name, classifier) != 0)) {
+                       /* Skip this classifier */
+                       msg_debug_task ("skip classifier %s, as we are requested to check %s only",
+                                       clcf->name, classifier);
+                       cur = g_list_next (cur);
+                       continue;
+               }
+
                if (clcf->pre_callbacks != NULL) {
                        st_list = rspamd_lua_call_cls_pre_callbacks (clcf, task, FALSE,
                                        FALSE, L);
@@ -518,6 +531,11 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
                g_tree_foreach (cbdata.tok->tokens, preprocess_init_stat_token,
                                &cbdata);
        }
+       else if (classifier != NULL) {
+               /* We likely cannot find any classifier with this name */
+               g_set_error (err, rspamd_stat_quark (), 404,
+                               "cannot find classifier %s", classifier);
+       }
 
        return cl_runtimes;
 }
@@ -538,7 +556,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err)
 
        /* Initialize classifiers and statfiles runtime */
        if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, L,
-                       RSPAMD_CLASSIFY_OP, FALSE, err)) == NULL) {
+                       RSPAMD_CLASSIFY_OP, FALSE, NULL, err)) == NULL) {
                return RSPAMD_STAT_PROCESS_OK;
        }
 
@@ -659,7 +677,10 @@ rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
 }
 
 rspamd_stat_result_t
-rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
+rspamd_stat_learn (struct rspamd_task *task,
+               gboolean spam,
+               lua_State *L,
+               const gchar *classifier,
                GError **err)
 {
        struct rspamd_stat_ctx *st_ctx;
@@ -669,7 +690,8 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
        struct preprocess_cb_data cbdata;
        GList *cl_runtimes;
        GList *cur, *curst;
-       gboolean ret = RSPAMD_STAT_PROCESS_ERROR, unlearn = FALSE;
+       gboolean unlearn = FALSE;
+       rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_ERROR;
        gulong nrev;
        rspamd_learn_t learn_res = RSPAMD_LEARN_OK;
        guint i;
@@ -698,8 +720,13 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
        }
 
        /* Initialize classifiers and statfiles runtime */
-       if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, L,
-                       unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP, spam, err)) == NULL) {
+       if ((cl_runtimes = rspamd_stat_preprocess (st_ctx,
+                       task,
+                       L,
+                       unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP,
+                       spam,
+                       classifier,
+                       err)) == NULL) {
                return RSPAMD_STAT_PROCESS_ERROR;
        }
 
index d9840275f6cf0747a4a6c5bd355f2af309f41920..0e43b7ab55f7b98e8e0ce41c3eae808d7f4a3328 100644 (file)
@@ -1945,8 +1945,7 @@ lua_task_learn (lua_State *L)
 {
        struct rspamd_task *task = lua_check_task (L, 1);
        gboolean is_spam = FALSE;
-       const gchar *clname;
-       struct rspamd_classifier_config *cl;
+       const gchar *clname = NULL;
        GError *err = NULL;
        int ret = 1;
 
@@ -1954,29 +1953,16 @@ lua_task_learn (lua_State *L)
        if (lua_gettop (L) > 2) {
                clname = luaL_checkstring (L, 3);
        }
-       else {
-               clname = "bayes";
-       }
-
-       cl = rspamd_config_find_classifier (task->cfg, clname);
 
-       if (cl == NULL) {
-               msg_warn_task ("classifier %s is not found", clname);
+       if (!rspamd_learn_task_spam (task, is_spam, clname, &err)) {
                lua_pushboolean (L, FALSE);
-               lua_pushstring (L, "classifier not found");
-               ret = 2;
+               if (err != NULL) {
+                       lua_pushstring (L, err->message);
+                       ret = 2;
+               }
        }
        else {
-               if (!rspamd_learn_task_spam (cl, task, is_spam, &err)) {
-                       lua_pushboolean (L, FALSE);
-                       if (err != NULL) {
-                               lua_pushstring (L, err->message);
-                               ret = 2;
-                       }
-               }
-               else {
-                       lua_pushboolean (L, TRUE);
-               }
+               lua_pushboolean (L, TRUE);
        }
 
        return ret;