git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add a generic lua classifier
author Vsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 6 Oct 2016 16:43:17 +0000 (17:43 +0100)
committer Vsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 6 Oct 2016 17:13:04 +0000 (18:13 +0100)
src/libserver/cfg_file.h
src/libstat/CMakeLists.txt
src/libstat/classifiers/bayes.c
src/libstat/classifiers/classifiers.h
src/libstat/classifiers/lua_classifier.c [new file with mode: 0644]
src/libstat/stat_config.c
src/libstat/stat_process.c

index 3bfeee98c94913e05f75a376714ff61f4096d482..2f671135e471ac975edf795559d1dae4e780537e 100644 (file)
@@ -140,6 +140,10 @@ struct rspamd_tokenizer_config {
  * (e.g. redis)
  */
 #define RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND (1 << 1)
+/*
+ * Classifier needs no backend: backend lookup and statfile learning are skipped
+ */
+#define RSPAMD_FLAG_CLASSIFIER_NO_BACKEND (1 << 2)
 
 /**
  * Classifier config definition
index 11f48bdc0707d57bc8e11d0261d6749a27e0cf1d..0bc92061657767a33db6b1c31c350c61ef1364a2 100644 (file)
@@ -5,7 +5,8 @@ SET(LIBSTATSRC          ${CMAKE_CURRENT_SOURCE_DIR}/stat_config.c
 SET(TOKENIZERSSRC      ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/tokenizers.c
                                        ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/osb.c)
 
-SET(CLASSIFIERSSRC     ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/bayes.c)
+SET(CLASSIFIERSSRC     ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/bayes.c
+                                       ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/lua_classifier.c) # generic lua classifier
 
 SET(BACKENDSSRC        ${CMAKE_CURRENT_SOURCE_DIR}/backends/mmaped_file.c
                                        ${CMAKE_CURRENT_SOURCE_DIR}/backends/sqlite3_backend.c)
index 5ebba8d56b7a705b64c66ac7282f3143bf1e6f3b..40dcdf36fa317ea862cfbef9c60a3e64ad4b4ae9 100644 (file)
@@ -185,10 +185,12 @@ bayes_normalize_prob (gdouble x)
        return a*x4 + b*x3 + c*x2 + d*xx;
 }
 
-void
+gboolean
 bayes_init (rspamd_mempool_t *pool, struct rspamd_classifier *cl)
 {
        cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_INTEGER;
+       /* Never fails; the gboolean result only matches the init_func prototype */
+       return TRUE;
 }
 
 gboolean
index 6bafa8507bc00e48ad098941528326597c1b2564..e30f2153aa01041b53f857fbb7266f2d9262858b 100644 (file)
@@ -16,7 +16,7 @@ struct token_node_s;
 
 struct rspamd_stat_classifier {
        char *name;
-       void (*init_func)(rspamd_mempool_t *pool,
+       gboolean (*init_func)(rspamd_mempool_t *pool,
                        struct rspamd_classifier *cl);
        gboolean (*classify_func)(struct rspamd_classifier * ctx,
                        GPtrArray *tokens,
@@ -30,7 +30,7 @@ struct rspamd_stat_classifier {
 };
 
 /* Bayes algorithm */
-void bayes_init (rspamd_mempool_t *pool,
+gboolean bayes_init (rspamd_mempool_t *pool,
                struct rspamd_classifier *);
 gboolean bayes_classify (struct rspamd_classifier *ctx,
                GPtrArray *tokens,
@@ -42,6 +42,20 @@ gboolean bayes_learn_spam (struct rspamd_classifier *ctx,
                gboolean unlearn,
                GError **err);
 
+/* Generic lua classifier (backend-less, see RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) */
+gboolean lua_classifier_init (rspamd_mempool_t *pool,
+               struct rspamd_classifier *);
+gboolean lua_classifier_classify (struct rspamd_classifier *ctx,
+               GPtrArray *tokens,
+               struct rspamd_task *task);
+gboolean lua_classifier_learn_spam (struct rspamd_classifier *ctx,
+               GPtrArray *tokens,
+               struct rspamd_task *task,
+               gboolean is_spam,
+               gboolean unlearn,
+               GError **err);
+
+
 #endif
 /*
  * vi:ts=4
diff --git a/src/libstat/classifiers/lua_classifier.c b/src/libstat/classifiers/lua_classifier.c
new file mode 100644 (file)
index 0000000..dea0f6a
--- /dev/null
@@ -0,0 +1,46 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "classifiers.h"
+#include "cfg_file.h"
+#include "stat_internal.h"
+
+gboolean
+lua_classifier_init (rspamd_mempool_t *pool,
+               struct rspamd_classifier *cl)
+{
+       /* Mark the classifier config as backend-less; `pool` is not used here */
+       cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_NO_BACKEND;
+       return TRUE;
+}
+gboolean
+lua_classifier_classify (struct rspamd_classifier *ctx,
+               GPtrArray *tokens,
+               struct rspamd_task *task)
+{
+       return TRUE; /* Stub: all arguments are ignored, always reports success */
+}
+
+gboolean
+lua_classifier_learn_spam (struct rspamd_classifier *ctx,
+               GPtrArray *tokens,
+               struct rspamd_task *task,
+               gboolean is_spam,
+               gboolean unlearn,
+               GError **err)
+{
+       return TRUE; /* Stub: nothing is learned and `err` is never set */
+}
index 3856fc1173c4eb8252fdad5814f1c37535cd0d9e..48f57246892d0705bcf3746510ac7e00ae6285af 100644 (file)
 #include "rspamd.h"
 #include "cfg_rcl.h"
 #include "stat_internal.h"
+#include "lua/lua_common.h"
 
 static struct rspamd_stat_ctx *stat_ctx = NULL;
 
+static struct rspamd_stat_classifier lua_classifier = {
+       .name = "lua", /* Replaced by the table key in every registered copy */
+       .init_func = lua_classifier_init,
+       .classify_func = lua_classifier_classify,
+       .learn_spam_func = lua_classifier_learn_spam,
+};
+
 static struct rspamd_stat_classifier stat_classifiers[] = {
        {
                .name = "bayes",
@@ -95,15 +103,55 @@ rspamd_stat_init (struct rspamd_config *cfg, struct event_base *ev_base)
        struct rspamd_classifier *cl;
        const ucl_object_t *cache_obj = NULL, *cache_name_obj;
        const gchar *cache_name = NULL;
+       lua_State *L = cfg->lua_state;
+       guint lua_classifiers_cnt = 0, i;
 
        if (stat_ctx == NULL) {
                stat_ctx = g_slice_alloc0 (sizeof (*stat_ctx));
        }
 
+       lua_getglobal (L, "rspamd_classifiers");
+
+       if (lua_type (L, -1) == LUA_TTABLE) {
+               lua_pushnil (L);
+
+               while (lua_next (L, -1) != 0) {
+                       lua_classifiers_cnt ++;
+                       lua_pop (L, 1);
+               }
+       }
+
+       lua_pop (L, 1);
+
+       stat_ctx->classifiers_count = G_N_ELEMENTS (stat_classifiers) +
+                               lua_classifiers_cnt;
+       stat_ctx->classifiers_subrs = g_new0 (struct rspamd_stat_classifier,
+                       stat_ctx->classifiers_count);
+
+       for (i = 0; i < G_N_ELEMENTS (stat_classifiers); i ++) {
+               memcpy (&stat_ctx->classifiers_subrs[i], &stat_classifiers[i],
+                               sizeof (struct rspamd_stat_classifier));
+       }
+
+       lua_getglobal (L, "rspamd_classifiers");
+
+       if (lua_type (L, -1) == LUA_TTABLE) {
+               lua_pushnil (L);
+
+               while (lua_next (L, -1) != 0) {
+                       lua_pushvalue (L, -2);
+                       memcpy (&stat_ctx->classifiers_subrs[i], &lua_classifier,
+                                                       sizeof (struct rspamd_stat_classifier));
+                       stat_ctx->classifiers_subrs[i].name = g_strdup (lua_tostring (L, -1));
+                       i ++;
+                       lua_pop (L, 2);
+               }
+       }
+
+       lua_pop (L, 1);
        stat_ctx->backends_subrs = stat_backends;
        stat_ctx->backends_count = G_N_ELEMENTS (stat_backends);
-       stat_ctx->classifiers_subrs = stat_classifiers;
-       stat_ctx->classifiers_count = G_N_ELEMENTS (stat_classifiers);
+
        stat_ctx->tokenizers_subrs = stat_tokenizers;
        stat_ctx->tokenizers_count = G_N_ELEMENTS (stat_tokenizers);
        stat_ctx->caches_subrs = stat_caches;
@@ -120,15 +168,37 @@ rspamd_stat_init (struct rspamd_config *cfg, struct event_base *ev_base)
 
        while (cur) {
                clf = cur->data;
-               bk = rspamd_stat_get_backend (clf->backend);
+               /* The runtime classifier must exist before init_func is called on it */
+               cl = g_slice_alloc0 (sizeof (*cl));
+               cl->cfg = clf;
+               cl->ctx = stat_ctx;
+               cl->statfiles_ids = g_array_new (FALSE, FALSE, sizeof (gint));
+               cl->subrs = rspamd_stat_get_classifier (clf->classifier);
+               g_assert (cl->subrs != NULL);
 
-               if (bk == NULL) {
-                       msg_err_config ("cannot get backend of type %s, so disable classifier"
-                                       " %s completely", clf->backend, clf->name);
+               if (!cl->subrs->init_func (cfg->cfg_pool, cl)) {
+                       /* A skipped classifier must not leak its id array or itself */
+                       g_array_free (cl->statfiles_ids, TRUE);
+                       g_slice_free1 (sizeof (*cl), cl);
+                       msg_err_config ("cannot init classifier type %s", clf->name);
                        cur = g_list_next (cur);
                        continue;
                }
 
+               /* init_func may set NO_BACKEND, so the flag is only tested after it */
+               if (!(clf->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) {
+                       bk = rspamd_stat_get_backend (clf->backend);
+
+                       if (bk == NULL) {
+                               msg_err_config ("cannot get backend of type %s, so disable classifier"
+                                               " %s completely", clf->backend, clf->name);
+                               g_array_free (cl->statfiles_ids, TRUE);
+                               g_slice_free1 (sizeof (*cl), cl);
+                               cur = g_list_next (cur);
+                               continue;
+                       }
+               }
+
                /* XXX:
                 * Here we get the first classifier tokenizer config as the only one
                 * We NO LONGER support multiple tokenizers per rspamd instance
@@ -140,14 +210,6 @@ rspamd_stat_init (struct rspamd_config *cfg, struct event_base *ev_base)
                                        clf->tokenizer, NULL);
                }
 
-               cl = g_slice_alloc0 (sizeof (*cl));
-               cl->cfg = clf;
-               cl->ctx = stat_ctx;
-               cl->statfiles_ids = g_array_new (FALSE, FALSE, sizeof (gint));
-               cl->subrs = rspamd_stat_get_classifier (clf->classifier);
-               g_assert (cl->subrs != NULL);
-               cl->subrs->init_func (cfg->cfg_pool, cl);
-
                /* Init classifier cache */
                cache_name = NULL;
 
index 6a1480ec5a733558cf3aaa370a8104c47b2a6781..228360fa69ae8093027de01ba2b1ca212722976b 100644 (file)
@@ -667,6 +667,11 @@ rspamd_stat_backends_learn (struct rspamd_stat_ctx *st_ctx,
                        continue;
                }
 
+               if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
+                       res = TRUE; /* Backend-less classifier: skip the statfiles loop below */
+                       continue;
+               }
+
                sel = cl;
 
                for (j = 0; j < cl->statfiles_ids->len; j ++) {
@@ -759,6 +764,11 @@ rspamd_stat_backends_post_learn (struct rspamd_stat_ctx *st_ctx,
                        cl->cache->learn (task, spam, cache_run);
                }
 
+               if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
+                       res = TRUE; /* Cache call above still ran; skip the statfiles loop below */
+                       continue;
+               }
+
                for (j = 0; j < cl->statfiles_ids->len; j ++) {
                        id = g_array_index (cl->statfiles_ids, gint, j);
                        st = g_ptr_array_index (st_ctx->statfiles, id);