]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add support for fuzzy learn and unlearn from lua
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 6 Sep 2016 15:14:53 +0000 (16:14 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 6 Sep 2016 15:14:53 +0000 (16:14 +0100)
src/plugins/fuzzy_check.c

index bc04b753b3bb0dc668a4f214cedfbd2329affbf2..fca27570eb49d5623efc8e99367c3b32a07308f0 100644 (file)
@@ -157,6 +157,8 @@ gint fuzzy_check_module_config (struct rspamd_config *cfg);
 gint fuzzy_check_module_reconfig (struct rspamd_config *cfg);
 static gint fuzzy_attach_controller (struct module_ctx *ctx,
        GHashTable *commands);
+static gint fuzzy_lua_learn_handler (lua_State *L);
+static gint fuzzy_lua_unlearn_handler (lua_State *L);
 
 module_t fuzzy_check_module = {
        "fuzzy_check",
@@ -595,6 +597,8 @@ fuzzy_parse_rule (struct rspamd_config *cfg, const ucl_object_t *obj,
 gint
 fuzzy_check_module_init (struct rspamd_config *cfg, struct module_ctx **ctx)
 {
+       lua_State *L = cfg->lua_state;
+
        fuzzy_module_ctx = g_malloc0 (sizeof (struct fuzzy_ctx));
 
        fuzzy_module_ctx->fuzzy_pool = rspamd_mempool_new (rspamd_mempool_suggest_size (), NULL);
@@ -825,6 +829,25 @@ fuzzy_check_module_init (struct rspamd_config *cfg, struct module_ctx **ctx)
                        NULL,
                        0);
 
+       /* Register global methods */
+       lua_getglobal (L, "rspamd_plugins");
+
+       if (lua_type (L, -1) == LUA_TTABLE) {
+               lua_pushstring (L, "fuzzy_check");
+               lua_createtable (L, 0, 2);
+               /* Set methods */
+               lua_pushstring (L, "unlearn");
+               lua_pushcfunction (L, fuzzy_lua_unlearn_handler);
+               lua_settable (L, -3);
+               lua_pushstring (L, "learn");
+               lua_pushcfunction (L, fuzzy_lua_learn_handler);
+               lua_settable (L, -3);
+               /* Finish fuzzy_check key */
+               lua_settable (L, -3);
+       }
+
+       lua_pop (L, 1); /* Remove global function */
+
        return 0;
 }
 
@@ -1905,7 +1928,11 @@ fuzzy_controller_io_callback (gint fd, short what, void *arg)
         * XXX: please, please, change this code some day
         */
        (*session->saved)--;
-       rspamd_http_connection_unref (session->http_entry->conn);
+
+       if (session->http_entry) {
+               rspamd_http_connection_unref (session->http_entry->conn);
+       }
+
        event_del (&session->ev);
        event_del (&session->timev);
        close (session->fd);
@@ -1926,18 +1953,27 @@ cleanup:
         */
 
        if (*(session->err) != NULL) {
-               rspamd_controller_send_error (session->http_entry,
-                               (*session->err)->code, (*session->err)->message);
+               if (session->http_entry) {
+                       rspamd_controller_send_error (session->http_entry,
+                                       (*session->err)->code, (*session->err)->message);
+               }
+
                g_error_free (*session->err);
        }
        else {
                rspamd_upstream_ok (session->server);
-               rspamd_controller_send_string (session->http_entry,
+
+               if (session->http_entry) {
+                       rspamd_controller_send_string (session->http_entry,
                                "{\"success\":true}");
+               }
        }
 
        if (session->task != NULL) {
-               rspamd_task_free (session->task);
+               if (session->http_entry) {
+                       rspamd_task_free (session->task);
+               }
+
                session->task = NULL;
        }
 
@@ -1954,8 +1990,12 @@ fuzzy_controller_timer_callback (gint fd, short what, void *arg)
 
        if (session->retransmits >= fuzzy_module_ctx->retransmits) {
                rspamd_upstream_fail (session->server);
-               rspamd_controller_send_error (session->http_entry,
-                               500, "IO timeout with fuzzy storage");
+
+               if (session->http_entry) {
+                       rspamd_controller_send_error (session->http_entry,
+                                       500, "IO timeout with fuzzy storage");
+               }
+
                msg_err_task ("got IO timeout with server %s(%s), after %d retransmits",
                                rspamd_upstream_name (session->server),
                                rspamd_inet_address_to_string (session->addr),
@@ -1964,12 +2004,18 @@ fuzzy_controller_timer_callback (gint fd, short what, void *arg)
                if (*session->saved > 0 ) {
                        (*session->saved)--;
                        if (*session->saved == 0) {
-                               rspamd_task_free (session->task);
+                               if (session->http_entry) {
+                                       rspamd_task_free (session->task);
+                               }
+
                                session->task = NULL;
                        }
                }
 
-               rspamd_http_connection_unref (session->http_entry->conn);
+               if (session->http_entry) {
+                       rspamd_http_connection_unref (session->http_entry->conn);
+               }
+
                event_del (&session->ev);
                event_del (&session->timev);
                close (session->fd);
@@ -2505,6 +2551,242 @@ fuzzy_controller_handler (struct rspamd_http_connection_entry *conn_ent,
        return 0;
 }
 
+static inline gint
+fuzzy_check_send_lua_learn (struct fuzzy_rule *rule,
+       struct rspamd_task *task,
+       GPtrArray *commands,
+       gint *saved,
+       GError **err)
+{
+       struct fuzzy_learn_session *s;
+       struct upstream *selected;
+       rspamd_inet_addr_t *addr;
+       gint sock;
+       gint ret = -1;
+
+       /* Get upstream */
+
+       while ((selected = rspamd_upstream_get (rule->servers,
+                       RSPAMD_UPSTREAM_SEQUENTIAL, NULL, 0))) {
+               /* Create UDP socket */
+               addr = rspamd_upstream_addr (selected);
+
+               if ((sock = rspamd_inet_address_connect (addr,
+                               SOCK_DGRAM, TRUE)) == -1) {
+                       rspamd_upstream_fail (selected);
+               }
+               else {
+                       s =
+                               rspamd_mempool_alloc0 (task->task_pool,
+                                       sizeof (struct fuzzy_learn_session));
+
+                       msec_to_tv (fuzzy_module_ctx->io_timeout, &s->tv);
+                       s->task = task;
+                       s->addr = addr;
+                       s->commands = commands;
+                       s->http_entry = NULL;
+                       s->server = selected;
+                       s->saved = saved;
+                       s->fd = sock;
+                       s->err = err;
+                       s->rule = rule;
+
+                       event_set (&s->ev, sock, EV_WRITE, fuzzy_controller_io_callback, s);
+                       event_base_set (task->ev_base, &s->ev);
+                       event_add (&s->ev, NULL);
+
+                       evtimer_set (&s->timev, fuzzy_controller_timer_callback, s);
+                       event_base_set (s->task->ev_base, &s->timev);
+                       event_add (&s->timev, &s->tv);
+
+                       (*saved)++;
+                       ret = 1;
+               }
+       }
+
+       return ret;
+}
+
+static gboolean
+fuzzy_check_lua_process_learn (struct rspamd_task *task,
+               gint cmd, gint value, gint flag)
+{
+       struct fuzzy_rule *rule;
+       gboolean processed = FALSE, res = TRUE;
+       GList *cur;
+       GError **err;
+       GPtrArray *commands;
+       gint *saved, rules = 0;
+
+       saved = rspamd_mempool_alloc0 (task->task_pool, sizeof (gint));
+       err = rspamd_mempool_alloc0 (task->task_pool, sizeof (GError *));
+
+       cur = fuzzy_module_ctx->fuzzy_rules;
+
+       while (cur && res) {
+               rule = cur->data;
+
+               if (rule->read_only) {
+                       cur = g_list_next (cur);
+                       continue;
+               }
+
+               /* Check for flag */
+               if (g_hash_table_lookup (rule->mappings,
+                               GINT_TO_POINTER (flag)) == NULL) {
+                       msg_info_task ("skip rule %s as it has no flag %d defined"
+                                       " false", rule->name, flag);
+                       cur = g_list_next (cur);
+                       continue;
+               }
+
+               rules ++;
+
+               res = 0;
+               commands = fuzzy_generate_commands (task, rule, cmd, flag, value);
+               if (commands != NULL) {
+                       res = fuzzy_check_send_lua_learn (rule, task, commands,
+                                       saved, err);
+               }
+
+               if (res) {
+                       processed = TRUE;
+               }
+
+               cur = g_list_next (cur);
+       }
+
+       if (res == -1) {
+               msg_warn_task ("<%s>: cannot send fuzzy request: %s", task->message_id,
+                               strerror (errno));
+       }
+       else if (!processed) {
+               if (rules) {
+                       msg_warn_task ("<%s>: no content to generate fuzzy",
+                                       task->message_id);
+
+                       return FALSE;
+               }
+               else {
+                       msg_warn_task ("<%s>: no fuzzy rules found for flag %d",
+                                       task->message_id,
+                               flag);
+                       return FALSE;
+               }
+       }
+
+       return TRUE;
+}
+
+static gint
+fuzzy_lua_learn_handler (lua_State *L)
+{
+       struct rspamd_task *task = lua_check_task (L, 1);
+       guint flag = 0, weight = 1.0;
+       const gchar *symbol;
+
+
+       if (task) {
+               if (lua_type (L, 2) == LUA_TNUMBER) {
+                       flag = lua_tonumber (L, 2);
+               }
+               else if (lua_type (L, 2) == LUA_TSTRING) {
+                       struct fuzzy_rule *rule;
+                       GList *cur;
+                       GHashTableIter it;
+                       gpointer k, v;
+                       struct fuzzy_mapping *map;
+
+                       symbol = lua_tostring (L, 2);
+
+                       for (cur = fuzzy_module_ctx->fuzzy_rules; cur != NULL && flag == 0;
+                                       cur = g_list_next (cur)) {
+                               rule = cur->data;
+
+                               g_hash_table_iter_init (&it, rule->mappings);
+
+                               while (g_hash_table_iter_next (&it, &k, &v)) {
+                                       map = v;
+
+                                       if (g_ascii_strcasecmp (symbol, map->symbol) == 0) {
+                                               flag = map->fuzzy_flag;
+                                               break;
+                                       }
+                               }
+                       }
+               }
+
+               if (flag == 0) {
+                       return luaL_error (L, "bad flag");
+               }
+
+               if (lua_type (L, 3) == LUA_TNUMBER) {
+                       weight = lua_tonumber (L, 3);
+               }
+
+               lua_pushboolean (L,
+                               fuzzy_check_lua_process_learn (task, FUZZY_WRITE, weight, flag));
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       return 1;
+}
+
+static gint
+fuzzy_lua_unlearn_handler (lua_State *L)
+{
+       struct rspamd_task *task = lua_check_task (L, 1);
+       guint flag = 0, weight = 1.0;
+       const gchar *symbol;
+
+       if (task) {
+               if (lua_type (L, 2) == LUA_TNUMBER) {
+                       flag = lua_tonumber (L, 1);
+               }
+               else if (lua_type (L, 2) == LUA_TSTRING) {
+                       struct fuzzy_rule *rule;
+                       GList *cur;
+                       GHashTableIter it;
+                       gpointer k, v;
+                       struct fuzzy_mapping *map;
+
+                       for (cur = fuzzy_module_ctx->fuzzy_rules; cur != NULL && flag == 0;
+                                       cur = g_list_next (cur)) {
+                               rule = cur->data;
+
+                               g_hash_table_iter_init (&it, rule->mappings);
+
+                               while (g_hash_table_iter_next (&it, &k, &v)) {
+                                       map = v;
+
+                                       if (g_ascii_strcasecmp (symbol, map->symbol) == 0) {
+                                               flag = map->fuzzy_flag;
+                                               break;
+                                       }
+                               }
+                       }
+               }
+
+               if (flag == 0) {
+                       return luaL_error (L, "bad flag");
+               }
+
+               if (lua_type (L, 3) == LUA_TNUMBER) {
+                       weight = lua_tonumber (L, 3);
+               }
+
+               lua_pushboolean (L,
+                               fuzzy_check_lua_process_learn (task, FUZZY_DEL, weight, flag));
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       return 1;
+}
+
 static gboolean
 fuzzy_add_handler (struct rspamd_http_connection_entry *conn_ent,
        struct rspamd_http_message *msg, struct module_ctx *ctx)