]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
* Add ability to learn rspamd via worker (without password) 0.3.11
authorVsevolod Stakhov <vsevolod@rambler-co.ru>
Thu, 31 Mar 2011 16:06:25 +0000 (20:06 +0400)
committerVsevolod Stakhov <vsevolod@rambler-co.ru>
Thu, 31 Mar 2011 16:06:25 +0000 (20:06 +0400)
src/controller.c
src/filter.c
src/filter.h
src/main.h
src/protocol.c
src/protocol.h
src/worker.c

index 380c477910f44933140e313d92e589ff78fa27d4..9504d3b1fca0b3247be59784da4fe2db70faebb9 100644 (file)
@@ -723,8 +723,6 @@ controller_read_socket (f_str_t * in, void *arg)
 {
        struct controller_session      *session = (struct controller_session *)arg;
        struct classifier_ctx          *cls_ctx;
-       stat_file_t                    *statfile;
-       struct statfile                *st;
        gint                            len, i, r;
        gchar                           *s, **params, *cmd, out_buf[128];
        struct worker_task             *task;
@@ -733,7 +731,6 @@ controller_read_socket (f_str_t * in, void *arg)
        GTree                          *tokens = NULL;
        GError                         *err = NULL;
        f_str_t                         c;
-       double                          sum;
 
        switch (session->state) {
        case STATE_COMMAND:
@@ -799,74 +796,14 @@ controller_read_socket (f_str_t * in, void *arg)
                        }
                        return FALSE;
                }
-               if ((s = g_hash_table_lookup (session->learn_classifier->opts, "header")) != NULL) {
-                       cur = message_get_header (task->task_pool, task->message, s, FALSE);
-                       if (cur) {
-                               memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
-                       }
-               }
-               else {
-                       cur = g_list_first (task->text_parts);
-               }
-               while (cur) {
-                       if (s != NULL) {
-                               c.len = strlen (cur->data);
-                               c.begin = cur->data;
-                       }
-                       else {
-                               part = cur->data;
-                               if (part->is_empty) {
-                                       cur = g_list_next (cur);
-                                       continue;
-                               }
-                               c.begin = part->content->data;
-                               c.len = part->content->len;
-                       }
-                       if (!session->learn_classifier->tokenizer->tokenize_func (session->learn_classifier->tokenizer, session->session_pool, &c, &tokens)) {
-                               i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, tokenizer error" CRLF);
-                               free_task (task, FALSE);
-                               if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
-                                       return FALSE;
-                               }
-                               session->state = STATE_REPLY;
-                               return TRUE;
-                       }
-                       cur = g_list_next (cur);
-               }
-               
-               /* Handle messages without text */
-               if (tokens == NULL) {
-                       i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, no tokens can be extracted (no text data)" CRLF END);
-                       msg_info ("learn failed for message <%s>, no tokens to extract", task->message_id);
-                       free_task (task, FALSE);
-                       if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
-                               return FALSE;
-                       }
-                       session->state = STATE_REPLY;
-                       return TRUE;
-               }
 
-               /* Take care of subject */
-               tokenize_subject (task, &tokens);
-
-               /* Init classifier */
-               cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier);
-               /* Get or create statfile */
-               statfile = get_statfile_by_symbol (session->worker->srv->statfile_pool, session->learn_classifier,
-                               session->learn_symbol, &st, TRUE);
-
-               if (statfile == NULL ||
-                       ! session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool,
-                                                                                                                               session->learn_symbol, tokens, session->in_class, &sum,
-                                                                                                                               session->learn_multiplier, &err)) {
+               if (!learn_task (session->learn_symbol, task, &err)) {
                        if (err) {
                                i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, learn classifier error: %s" CRLF END, err->message);
-                               msg_info ("learn failed for message <%s>, learn error: %s", task->message_id, err->message);
                                g_error_free (err);
                        }
                        else {
                                i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, unknown learn classifier error" CRLF END);
-                               msg_info ("learn failed for message <%s>, unknown learn error", task->message_id);
                        }
                        free_task (task, FALSE);
                        if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
@@ -875,18 +812,12 @@ controller_read_socket (f_str_t * in, void *arg)
                        session->state = STATE_REPLY;
                        return TRUE;
                }
-               session->worker->srv->stat->messages_learned++;
 
-               maybe_write_binlog (session->learn_classifier, st, statfile, tokens);
-               msg_info ("learn success for message <%s>, for statfile: %s, sum weight: %.2f",
-                               task->message_id, session->learn_symbol, sum);
-               statfile_pool_plan_invalidate (session->worker->srv->statfile_pool, DEFAULT_STATFILE_INVALIDATE_TIME, DEFAULT_STATFILE_INVALIDATE_JITTER);
                free_task (task, FALSE);
-               i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF END, sum);
+               i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok" CRLF END);
                if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
                        return FALSE;
                }
-
                session->state = STATE_REPLY;
                break;
        case STATE_WEIGHTS:
index ec7b5a5edd2ae145a0e22eaa67ff0ff919cefcb2..df8e1a9e05fadd091da12d8809995d13d54df2b6 100644 (file)
 #   include "lua/lua_common.h"
 #endif
 
+static inline                   GQuark
+filter_error_quark (void)
+{
+       return g_quark_from_static_string ("g-filter-error-quark");
+}
+
 static void
 insert_metric_result (struct worker_task *task, struct metric *metric, const gchar *symbol,
                double flag, GList * opts, gboolean single)
@@ -799,6 +805,109 @@ check_metric_action (double score, double required_score, struct metric *metric)
        }
 }
 
+gboolean
+learn_task (const gchar *statfile, struct worker_task *task, GError **err)
+{
+       GList                          *cur;
+       struct classifier_config       *cl;
+       struct classifier_ctx          *cls_ctx;
+       gchar                          *s;
+       f_str_t                         c;
+       GTree                          *tokens = NULL;
+       struct statfile                *st;
+       stat_file_t                    *stf;
+       gdouble                         sum;
+       struct mime_text_part          *part;
+
+       /* Load classifier by symbol */
+       cl = g_hash_table_lookup (task->cfg->classifiers_symbols, statfile);
+       if (cl == NULL) {
+               g_set_error (err, filter_error_quark(), 1, "Statfile %s is not configured in any classifier", statfile);
+               return FALSE;
+       }
+
+       /* If classifier has 'header' option just classify header of this type */
+       if ((s = g_hash_table_lookup (cl->opts, "header")) != NULL) {
+               cur = message_get_header (task->task_pool, task->message, s, FALSE);
+               if (cur) {
+                       memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
+               }
+       }
+       else {
+               /* Classify message otherwise */
+               cur = g_list_first (task->text_parts);
+       }
+
+       /* Get tokens from each element */
+       while (cur) {
+               if (s != NULL) {
+                       c.len = strlen (cur->data);
+                       c.begin = cur->data;
+               }
+               else {
+                       part = cur->data;
+                       /* Skip empty parts */
+                       if (part->is_empty) {
+                               cur = g_list_next (cur);
+                               continue;
+                       }
+                       c.begin = part->content->data;
+                       c.len = part->content->len;
+               }
+               /* Get tokens */
+               if (!cl->tokenizer->tokenize_func (
+                               cl->tokenizer, task->task_pool,
+                               &c, &tokens)) {
+                       g_set_error (err, filter_error_quark(), 2, "Cannot tokenize message");
+                       return FALSE;
+               }
+               cur = g_list_next (cur);
+       }
+
+       /* Handle messages without text */
+       if (tokens == NULL) {
+               g_set_error (err, filter_error_quark(), 3, "Cannot tokenize message, no text data");
+               msg_info ("learn failed for message <%s>, no tokens to extract", task->message_id);
+               return FALSE;
+       }
+
+       /* Take care of subject */
+       tokenize_subject (task, &tokens);
+
+       /* Init classifier */
+       cls_ctx = cl->classifier->init_func (
+                       task->task_pool, cl);
+       /* Get or create statfile */
+       stf = get_statfile_by_symbol (task->worker->srv->statfile_pool,
+                       cl, statfile, &st, TRUE);
+
+       /* Learn */
+       if (stf== NULL || !cl->classifier->learn_func (
+                       cls_ctx, task->worker->srv->statfile_pool,
+                       statfile, tokens, TRUE, &sum,
+                       1.0, err)) {
+               if (*err) {
+                       msg_info ("learn failed for message <%s>, learn error: %s", task->message_id, (*err)->message);
+                       return FALSE;
+               }
+               else {
+                       g_set_error (err, filter_error_quark(), 4, "Learn failed, unknown learn classifier error");
+                       msg_info ("learn failed for message <%s>, unknown learn error", task->message_id);
+                       return FALSE;
+               }
+       }
+       /* Increase statistics */
+       task->worker->srv->stat->messages_learned++;
+
+       maybe_write_binlog (cl, st, stf, tokens);
+       msg_info ("learn success for message <%s>, for statfile: %s, sum weight: %.2f",
+                       task->message_id, statfile, sum);
+       statfile_pool_plan_invalidate (task->worker->srv->statfile_pool,
+                       DEFAULT_STATFILE_INVALIDATE_TIME,
+                       DEFAULT_STATFILE_INVALIDATE_JITTER);
+
+       return TRUE;
+}
 
 /* 
  * vi:ts=4 
index cea49893b1a0389108520d72b219a47c3e56a954..2c3dde4fc0e824650171dcb942fd6b6352e3bf24 100644 (file)
@@ -123,6 +123,15 @@ void make_composites (struct worker_task *task);
  */
 double factor_consolidation_func (struct worker_task *task, const gchar *metric_name, const gchar *unused);
 
+/*
+ * Learn specified statfile with message in a task
+ * @param statfile symbol of statfile
+ * @param task worker's task object
+ * @param err pointer to GError
+ * @return true if learn succeed
+ */
+gboolean learn_task (const gchar *statfile, struct worker_task *task, GError **err);
+
 gboolean check_action_str (const gchar *data, gint *result);
 const gchar *str_action_metric (enum rspamd_metric_action action);
 gint check_metric_action (double score, double required_score, struct metric *metric);
index 78b3af14bd201879023c928c059d2c1e40fc67a4..581883a6eb2fb85f340efa281b97c2558d527c3f 100644 (file)
@@ -95,7 +95,7 @@ struct rspamd_main {
 
        memory_pool_t *server_pool;                                                                     /**< server's memory pool                                                       */
        statfile_pool_t *statfile_pool;                                                         /**< shared statfiles pool                                                      */
-    GHashTable *workers;                                        /**< workers pool indexed by pid                    */
+       GHashTable *workers;                                        /**< workers pool indexed by pid                    */
 };
 
 struct counter_data {
@@ -117,9 +117,9 @@ struct save_point {
  * Union that would be used for storing sockaddrs
  */
 union sa_union {
-  struct sockaddr_storage ss;
-  struct sockaddr_in s4;
-  struct sockaddr_in6 s6;
+       struct sockaddr_storage ss;
+       struct sockaddr_in s4;
+       struct sockaddr_in6 s6;
 };
 
 /**
@@ -151,9 +151,9 @@ struct controller_session {
        GList *parts;                                                                                           /**< extracted mime parts                                                       */
        gint in_class;                                                                                          /**< positive or negative learn                                         */
        void (*other_handler)(struct controller_session *session, 
-                                                               f_str_t *in);                                   /**< other command handler to execute at the end of processing */
+                       f_str_t *in);                                   /**< other command handler to execute at the end of processing */
        void *other_data;                                                                                       /**< and its data                                                                       */
-    struct rspamd_async_session* s;                                                            /**< async session object                                                       */
+       struct rspamd_async_session* s;                                                         /**< async session object                                                       */
 };
 
 typedef void (*controller_func_t)(gchar **args, struct controller_session *session);
@@ -178,10 +178,12 @@ struct worker_task {
        enum rspamd_command cmd;                                                                        /**< command                                                                            */
        struct custom_command *custom_cmd;                                                      /**< custom command if any                                                      */      
        gint sock;                                                                                                      /**< socket descriptor                                                          */
-    gboolean is_mime;                                           /**< if this task is mime task                      */
-    gboolean is_json;                                                                                  /**< output is JSON                                                                     */
-    gboolean is_http;                                                                                  /**< output is HTTP                                                                     */
-    gboolean is_skipped;                                        /**< whether message was skipped by configuration   */
+       gboolean is_mime;                                           /**< if this task is mime task                      */
+       gboolean is_json;                                                                                       /**< output is JSON                                                                     */
+       gboolean is_http;                                                                                       /**< output is HTTP                                                                     */
+       gboolean allow_learn;                                                                           /**< allow learning                                                                     */
+       gboolean is_skipped;                                        /**< whether message was skipped by configuration   */
+
        gchar *helo;                                                                                                    /**< helo header value                                                          */
        gchar *from;                                                                                                    /**< from header value                                                          */
        gchar *queue_id;                                                                                                /**< queue id if specified                                                      */
@@ -193,9 +195,10 @@ struct worker_task {
        gchar *deliver_to;                                                                                      /**< address to deliver                                                         */
        gchar *user;                                                                                                    /**< user to deliver                                                            */
        gchar *subject;                                                                                         /**< subject (for non-mime)                                                     */
+       gchar *statfile;                                                                                        /**< statfile for learning                                                      */
        f_str_t *msg;                                                                                           /**< message buffer                                                                     */
        rspamd_io_dispatcher_t *dispatcher;                                                     /**< IO dispatcher object                                                       */
-    struct rspamd_async_session* s;                                                            /**< async session object                                                       */
+       struct rspamd_async_session* s;                                                         /**< async session object                                                       */
        gint parts_count;                                                                                       /**< mime parts count                                                           */
        GMimeMessage *message;                                                                          /**< message, parsed with GMime                                         */
        GMimeObject *parser_parent_part;                                                        /**< current parent part                                                        */
@@ -209,9 +212,9 @@ struct worker_task {
        GList *images;                                                                                          /**< list of images                                                                     */
        GList *raw_headers_list;                                                                        /**< list of raw headers                                                        */
        GHashTable *results;                                                                            /**< hash table of metric_result indexed by 
-                                                                                                                                *    metric's name                                                                     */
+        *    metric's name                                                                     */
        GHashTable *tokens;                                                                                     /**< hash table of tokens indexed by tokenizer
-                                                                                                                                *    pointer                                                                           */
+        *    pointer                                                                           */
        GList *messages;                                                                                        /**< list of messages that would be reported            */
        GHashTable *re_cache;                                                                           /**< cache for matched or not matched regexps           */
        struct config_file *cfg;                                                                        /**< pointer to config object                                           */
index 8ffaddea1c2dd0236640c5571c039355096eb5bf..ac8515004882a9e59a5ba7dc31d7a3eda97e0e53 100644 (file)
  */
 #define MSG_CMD_PROCESS "process"
 
+/*
+ * Learn specified statfile using message
+ */
+#define MSG_CMD_LEARN "learn"
+
 /*
  * spamassassin greeting:
  */
@@ -81,6 +86,7 @@
 #define NRCPT_HEADER "Recipient-Number"
 #define RCPT_HEADER "Rcpt"
 #define SUBJECT_HEADER "Subject"
+#define STATFILE_HEADER "Statfile"
 #define QUEUE_ID_HEADER "Queue-ID"
 #define ERROR_HEADER "Error"
 #define USER_HEADER "User"
@@ -198,6 +204,22 @@ parse_check_command (struct worker_task *task, gchar *token)
                        return FALSE;
                }
                break;
+       case 'l':
+       case 'L':
+               if (g_ascii_strcasecmp (token + 1, MSG_CMD_LEARN + 1) == 0) {
+                       if (task->allow_learn) {
+                               task->cmd = CMD_LEARN;
+                       }
+                       else {
+                               msg_info ("learning is disabled");
+                               return FALSE;
+                       }
+               }
+               else {
+                       debug_task ("bad command: %s", token);
+                       return FALSE;
+               }
+               break;
        default:
                cur = custom_commands;
                while (cur) {
@@ -306,8 +328,8 @@ parse_http_command (struct worker_task *task, f_str_t * line)
                        }
                        else {
                                /* Copy command */
-                               cmd = memory_pool_alloc (task->task_pool, p - c);
-                               rspamd_strlcpy (cmd, c, p - c);
+                               cmd = memory_pool_alloc (task->task_pool, p - c + 1);
+                               rspamd_strlcpy (cmd, c, p - c + 1);
                                /* Skip the first '/' */
                                if (*cmd == '/') {
                                        cmd ++;
@@ -379,8 +401,22 @@ parse_header (struct worker_task *task, f_str_t * line)
                }
                else {
                        if (task->content_length > 0) {
-                               rspamd_set_dispatcher_policy (task->dispatcher, BUFFER_CHARACTER, task->content_length);
-                               task->state = READ_MESSAGE;
+                               if (task->cmd == CMD_LEARN) {
+                                       if (task->statfile != NULL) {
+                                               rspamd_set_dispatcher_policy (task->dispatcher, BUFFER_CHARACTER, task->content_length);
+                                               task->state = READ_MESSAGE;
+                                       }
+                                       else {
+                                               task->last_error = "Unknown statfile";
+                                               task->error_code = RSPAMD_STATFILE_ERROR;
+                                               task->state = WRITE_ERROR;
+                                               return FALSE;
+                                       }
+                               }
+                               else {
+                                       rspamd_set_dispatcher_policy (task->dispatcher, BUFFER_CHARACTER, task->content_length);
+                                       task->state = READ_MESSAGE;
+                               }
                        }
                        else {
                                task->last_error = "Unknown content length";
@@ -528,6 +564,9 @@ parse_header (struct worker_task *task, f_str_t * line)
                if (g_ascii_strncasecmp (headern, SUBJECT_HEADER, sizeof (SUBJECT_HEADER) - 1) == 0) {
                        task->subject = memory_pool_fstrdup (task->task_pool, line);
                }
+               else if (g_ascii_strncasecmp (headern, STATFILE_HEADER, sizeof (STATFILE_HEADER) - 1) == 0) {
+                       task->statfile = memory_pool_fstrdup (task->task_pool, line);
+               }
                else {
                        return FALSE;
                }
@@ -1433,7 +1472,7 @@ write_reply (struct worker_task *task)
                /* Write error message and error code to reply */
                if (task->is_http) {
                        r = rspamd_snprintf (outbuf, sizeof (outbuf), "HTTP/1.0 400 Bad request" CRLF
-                                       "Connection: close" CRLF CRLF);
+                                       "Connection: close" CRLF CRLF "Error: %d - %s" CRLF, task->error_code, task->last_error);
                }
                else {
                        if (task->proto == SPAMC_PROTO) {
@@ -1471,6 +1510,19 @@ write_reply (struct worker_task *task)
                                        (task->proto == SPAMC_PROTO) ? SPAMD_REPLY_BANNER : RSPAMD_REPLY_BANNER, rspamc_proto_str (task->proto_ver));
                        return rspamd_dispatcher_write (task->dispatcher, outbuf, r, FALSE, FALSE);
                        break;
+               case CMD_LEARN:
+                       if (task->is_http) {
+                               r = rspamd_snprintf (outbuf, sizeof (outbuf), "HTTP/1.0 200 Ok" CRLF
+                                                                       "Connection: close" CRLF CRLF "%s" CRLF, task->last_error);
+                       }
+                       else {
+                               r = rspamd_snprintf (outbuf, sizeof (outbuf), "%s/%s 0 LEARN" CRLF CRLF "%s" CRLF,
+                                               (task->proto == SPAMC_PROTO) ? SPAMD_REPLY_BANNER : RSPAMD_REPLY_BANNER,
+                                               rspamc_proto_str (task->proto_ver),
+                                               task->last_error);
+                       }
+                       return rspamd_dispatcher_write (task->dispatcher, outbuf, r, FALSE, FALSE);
+                       break;
                case CMD_OTHER:
                        return task->custom_cmd->func (task);
                }
index a15530a7b31d6f69e014c0b87295743fcc6c387f..de6d0ea03768e40f59b0e980920fe8dd8dbd49d1 100644 (file)
@@ -13,6 +13,7 @@
 #define RSPAMD_NETWORK_ERROR 2
 #define RSPAMD_PROTOCOL_ERROR 3
 #define RSPAMD_LENGTH_ERROR 4
+#define RSPAMD_STATFILE_ERROR 5
 
 #define RSPAMC_PROTO_1_0 "1.0"
 #define RSPAMC_PROTO_1_1 "1.1"
@@ -44,6 +45,7 @@ enum rspamd_command {
        CMD_SKIP,
        CMD_PING,
        CMD_PROCESS,
+       CMD_LEARN,
        CMD_OTHER,
 };
 
index 57ad4ecf133b012dc1a6aceaf380d3c3b33da66d..a9b05d64e8f62e70d5a129e907ec3e8d30994f37 100644 (file)
@@ -89,6 +89,8 @@ struct rspamd_worker_ctx {
        gboolean                        is_http;
        /* JSON output                                                                  */
        gboolean                        is_json;
+       /* Allow learning throught worker                               */
+       gboolean                        allow_learn;
        GList                          *custom_filters;
        /* DNS resolver */
        struct rspamd_dns_resolver     *resolver;
@@ -318,6 +320,7 @@ read_socket (f_str_t * in, void *arg)
        struct worker_task             *task = (struct worker_task *) arg;
        struct rspamd_worker_ctx       *ctx;
        ssize_t                         r;
+       GError                         *err = NULL;
 
        ctx = task->worker->ctx;
        switch (task->state) {
@@ -332,8 +335,10 @@ read_socket (f_str_t * in, void *arg)
                }
                else {
                        if (!read_rspamd_input_line (task, in)) {
-                               task->last_error = "Read error";
-                               task->error_code = RSPAMD_NETWORK_ERROR;
+                               if (!task->last_error) {
+                                       task->last_error = "Read error";
+                                       task->error_code = RSPAMD_NETWORK_ERROR;
+                               }
                                task->state = WRITE_ERROR;
                        }
                }
@@ -359,22 +364,38 @@ read_socket (f_str_t * in, void *arg)
                        task->state = WRITE_REPLY;
                        return write_socket (task);
                }
-               r = process_filters (task);
-               if (r == -1) {
-                       task->last_error = "Filter processing error";
-                       task->error_code = RSPAMD_FILTER_ERROR;
-                       task->state = WRITE_ERROR;
+               else if (task->cmd == CMD_LEARN) {
+                       if (!learn_task (task->statfile, task, &err)) {
+                               task->last_error = memory_pool_strdup (task->task_pool, err->message);
+                               task->error_code = err->code;
+                               g_error_free (err);
+                               task->state = WRITE_ERROR;
+                       }
+                       else {
+                               task->last_error = "learn ok";
+                               task->error_code = 0;
+                               task->state = WRITE_REPLY;
+                       }
                        return write_socket (task);
                }
-               else if (r == 0) {
-                       task->state = WAIT_FILTER;
-                       rspamd_dispatcher_pause (task->dispatcher);
-               }
                else {
-                       process_statfiles (task);
-                       lua_call_post_filters (task);
-                       task->state = WRITE_REPLY;
-                       return write_socket (task);
+                       r = process_filters (task);
+                       if (r == -1) {
+                               task->last_error = "Filter processing error";
+                               task->error_code = RSPAMD_FILTER_ERROR;
+                               task->state = WRITE_ERROR;
+                               return write_socket (task);
+                       }
+                       else if (r == 0) {
+                               task->state = WAIT_FILTER;
+                               rspamd_dispatcher_pause (task->dispatcher);
+                       }
+                       else {
+                               process_statfiles (task);
+                               lua_call_post_filters (task);
+                               task->state = WRITE_REPLY;
+                               return write_socket (task);
+                       }
                }
                break;
        case WRITE_REPLY:
@@ -515,9 +536,8 @@ construct_task (struct rspamd_worker *worker)
 {
        struct worker_task             *new_task;
 
-       new_task = g_malloc (sizeof (struct worker_task));
+       new_task = g_malloc0 (sizeof (struct worker_task));
 
-       bzero (new_task, sizeof (struct worker_task));
        new_task->worker = worker;
        new_task->state = READ_COMMAND;
        new_task->cfg = worker->srv->cfg;
@@ -605,10 +625,12 @@ accept_socket (gint fd, short what, void *arg)
                                sizeof (struct in_addr));
        }
 
+       /* Copy some variables */
        new_task->sock = nfd;
        new_task->is_mime = ctx->is_mime;
        new_task->is_json = ctx->is_json;
        new_task->is_http = ctx->is_http;
+       new_task->allow_learn = ctx->allow_learn;
 
        worker->srv->stat->connections_count++;
        new_task->resolver = ctx->resolver;
@@ -750,6 +772,7 @@ init_worker (void)
        register_worker_opt (TYPE_WORKER, "mime", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, is_mime));
        register_worker_opt (TYPE_WORKER, "http", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, is_http));
        register_worker_opt (TYPE_WORKER, "json", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, is_json));
+       register_worker_opt (TYPE_WORKER, "allow_learn", xml_handle_boolean, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, allow_learn));
        register_worker_opt (TYPE_WORKER, "timeout", xml_handle_seconds, ctx, G_STRUCT_OFFSET (struct rspamd_worker_ctx, timeout));
 
        return ctx;