]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
Add lua_fann module
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 21 Dec 2015 14:46:13 +0000 (14:46 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 21 Dec 2015 14:46:13 +0000 (14:46 +0000)
doc/Makefile
src/CMakeLists.txt
src/lua/CMakeLists.txt
src/lua/lua_common.c
src/lua/lua_common.h
src/lua/lua_fann.c [new file with mode: 0644]

index f24fc25516e610483b53ab315d2479b8f172959f..02b3258f81b5f0bdcf1b35a0b0fc5fdde16a4612 100644 (file)
@@ -8,15 +8,15 @@ all: man
 man: rspamd.8 rspamc.1 rspamadm.1
 
 rspamd.8: rspamd.8.md
-       $(PANDOC) -s -f markdown -t man -o rspamd.8 rspamd.8.md 
+       $(PANDOC) -s -f markdown -t man -o rspamd.8 rspamd.8.md
 rspamc.1: rspamc.1.md
        $(PANDOC) -s -f markdown -t man -o rspamc.1 rspamc.1.md
 rspamadm.1: rspamadm.1.md
        $(PANDOC) -s -f markdown -t man -o rspamadm.1 rspamadm.1.md
-       
+
 lua-doc: lua_regexp lua_ip lua_config lua_task lua_ucl lua_http lua_trie \
        lua_dns lua_redis lua_upstream lua_expression lua_mimepart lua_logger lua_url \
-       lua_tcp lua_mempool lua_html lua_util
+       lua_tcp lua_mempool lua_html lua_util lua_fann
 
 lua_regexp: ../src/lua/lua_regexp.c
        $(LUADOC) < ../src/lua/lua_regexp.c > markdown/lua/regexp.md
@@ -53,4 +53,6 @@ lua_mempool: ../src/lua/lua_mempool.c
 lua_html: ../src/lua/lua_html.c
        $(LUADOC) < ../src/lua/lua_html.c > markdown/lua/html.md
 lua_util: ../src/lua/lua_util.c
-       $(LUADOC) < ../src/lua/lua_util.c > markdown/lua/util.md
\ No newline at end of file
+       $(LUADOC) < ../src/lua/lua_util.c > markdown/lua/util.md
+lua_fann: ../src/lua/lua_fann.c
+       $(LUADOC) < ../src/lua/lua_fann.c > markdown/lua/fann.md
index 385de9eb853d108a19127724c1c8a0e8bcc153c3..a37256c4137ea6c02ccb6030623d9b4871faf277 100644 (file)
@@ -90,7 +90,7 @@ SET(PLUGINSSRC        plugins/surbl.c
                                plugins/fuzzy_check.c
                                plugins/spf.c
                                plugins/dkim_check.c
-                               libserver/rspamd_control.c)
+                               libserver/rspamd_control.c lua/lua_fann.c)
 
 SET(MODULES_LIST surbl regexp chartable fuzzy_check spf dkim)
 SET(WORKERS_LIST normal controller smtp_proxy fuzzy lua http_proxy)
index ad526d53412547653b5d420e174bc58338a98b5b..4d74e775270795da92b9be56af14d4f5fde2e902 100644 (file)
@@ -23,6 +23,7 @@ SET(LUASRC                      ${CMAKE_CURRENT_SOURCE_DIR}/lua_common.c
                                          ${CMAKE_CURRENT_SOURCE_DIR}/lua_url.c
                                          ${CMAKE_CURRENT_SOURCE_DIR}/lua_util.c
                                          ${CMAKE_CURRENT_SOURCE_DIR}/lua_tcp.c
-                                         ${CMAKE_CURRENT_SOURCE_DIR}/lua_html.c)
+                                         ${CMAKE_CURRENT_SOURCE_DIR}/lua_html.c
+                                         ${CMAKE_CURRENT_SOURCE_DIR}/lua_fann.c)
 
-SET(RSPAMD_LUA ${LUASRC} PARENT_SCOPE)
\ No newline at end of file
+SET(RSPAMD_LUA ${LUASRC} PARENT_SCOPE)
index 0cbec0408a13661a5521393c177710845558e874..e475ace8327ba60d2cad7be31e629a8c5dcad18e 100644 (file)
@@ -247,6 +247,7 @@ rspamd_lua_init ()
        luaopen_util (L);
        luaopen_tcp (L);
        luaopen_html (L);
+       luaopen_fann (L);
 
        rspamd_lua_add_preload (L, "ucl", luaopen_ucl);
 
@@ -944,3 +945,21 @@ rspamd_lua_traceback (lua_State *L)
 
        return 1;
 }
+
+guint
+rspamd_lua_table_size (lua_State *L, gint tbl_pos)
+{
+       guint tbl_size = 0;
+
+       if (!lua_istable (L, tbl_pos)) {
+               return 0;
+       }
+
+#if LUA_VERSION_NUM >= 502
+       tbl_size = lua_rawlen (L, tbl_pos);
+#else
+       tbl_size = lua_objlen (L, tbl_pos);
+#endif
+
+       return tbl_size;
+}
index f51aee73189e07aee3f3ec260e0c4cfe459d5ec8..b41d338111aa1964b0f91d9bfe40bde997b41b4b 100644 (file)
@@ -222,6 +222,7 @@ void luaopen_text (lua_State *L);
 void luaopen_util (lua_State * L);
 void luaopen_tcp (lua_State * L);
 void luaopen_html (lua_State * L);
+void luaopen_fann (lua_State *L);
 
 gint rspamd_lua_call_filter (const gchar *function, struct rspamd_task *task);
 gint rspamd_lua_call_chain_filter (const gchar *function,
@@ -289,5 +290,10 @@ gboolean rspamd_lua_parse_table_arguments (lua_State *L, gint pos,
 
 
 gint rspamd_lua_traceback (lua_State *L);
+
+/**
+ * Returns size of table at position `tbl_pos`
+ */
+guint rspamd_lua_table_size (lua_State *L, gint tbl_pos);
 #endif /* WITH_LUA */
 #endif /* RSPAMD_LUA_H */
diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c
new file mode 100644 (file)
index 0000000..90c037d
--- /dev/null
@@ -0,0 +1,435 @@
+/*
+ * Copyright (c) 2015, Vsevolod Stakhov
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *      * Redistributions of source code must retain the above copyright
+ *        notice, this list of conditions and the following disclaimer.
+ *      * Redistributions in binary form must reproduce the above copyright
+ *        notice, this list of conditions and the following disclaimer in the
+ *        documentation and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY AUTHOR ''AS IS'' AND ANY
+ * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL AUTHOR BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include "lua_common.h"
+
+#ifdef WITH_FANN
+#include <fann.h>
+#endif
+
+/***
+ * @module rspamd_fann
+ * This module enables [fann](http://libfann.github.io) interaction in rspamd
+ * Please note, that this module works merely if you have `ENABLE_FANN=ON` option
+ * definition when building rspamd
+ */
+
+/*
+ * Fann functions
+ */
+LUA_FUNCTION_DEF (fann, is_enabled);
+LUA_FUNCTION_DEF (fann, create);
+LUA_FUNCTION_DEF (fann, load);
+
+/*
+ * Fann methods
+ */
+LUA_FUNCTION_DEF (fann, train);
+LUA_FUNCTION_DEF (fann, test);
+LUA_FUNCTION_DEF (fann, save);
+LUA_FUNCTION_DEF (fann, get_inputs);
+LUA_FUNCTION_DEF (fann, get_outputs);
+LUA_FUNCTION_DEF (fann, dtor);
+
+static const struct luaL_reg fannlib_f[] = {
+               LUA_INTERFACE_DEF (fann, is_enabled),
+               LUA_INTERFACE_DEF (fann, create),
+               LUA_INTERFACE_DEF (fann, load),
+               {NULL, NULL}
+};
+
+static const struct luaL_reg fannlib_m[] = {
+               LUA_INTERFACE_DEF (fann, train),
+               LUA_INTERFACE_DEF (fann, test),
+               LUA_INTERFACE_DEF (fann, save),
+               LUA_INTERFACE_DEF (fann, get_inputs),
+               LUA_INTERFACE_DEF (fann, get_outputs),
+               {"__gc", lua_fann_dtor},
+               {"__tostring", rspamd_lua_class_tostring},
+               {NULL, NULL}
+};
+
+#ifdef WITH_FANN
+struct fann *
+rspamd_lua_check_fann (lua_State *L, gint pos)
+{
+       void *ud = luaL_checkudata (L, pos, "rspamd{fann}");
+       luaL_argcheck (L, ud != NULL, pos, "'fann' expected");
+       return ud ? *((struct fann **) ud) : NULL;
+}
+#endif
+
+/***
+ * @function rspamd_fann.is_enabled()
+ * Checks if fann is enabled for this rspamd build
+ * @return {boolean} true if fann is enabled
+ */
+static gint
+lua_fann_is_enabled (lua_State *L)
+{
+#ifdef WITH_FANN
+       lua_pushboolean (L, true);
+#else
+       lua_pushboolean (L, false);
+#endif
+       return 1;
+}
+
+/***
+ * @function rspamd_fann.create(nlayers, [layer1, ... layern])
+ * Creates new neural network with `nlayers` that contains `layer1`...`layern`
+ * neurons in each layer
+ * @param {number} nlayers number of layers
+ * @param {number} layerI number of neurons in each layer
+ * @return {fann} fann object
+ */
+static gint
+lua_fann_create (lua_State *L)
+{
+#ifndef WITH_FANN
+       return 0;
+#else
+       struct fann *f, **pfann;
+       guint nlayers, *layers, i;
+
+       nlayers = luaL_checknumber (L, 1);
+
+       if (nlayers > 0) {
+               layers = g_malloc (nlayers * sizeof (layers[0]));
+
+               for (i = 0; i < nlayers; i ++) {
+                       layers[i] = luaL_checknumber (L, i + 2);
+               }
+
+               f = fann_create_standard_array (nlayers, layers);
+
+               if (f != NULL) {
+                       pfann = lua_newuserdata (L, sizeof (gpointer));
+                       *pfann = f;
+                       rspamd_lua_setclass (L, "rspamd{fann}", -1);
+               }
+               else {
+                       lua_pushnil (L);
+               }
+       }
+       else {
+               lua_pushnil (L);
+       }
+
+       return 1;
+#endif
+}
+
+/***
+ * @function rspamd_fann.load(file)
+ * Loads neural network from the file
+ * @param {string} file filename where fann is stored
+ * @return {fann} fann object
+ */
+static gint
+lua_fann_load (lua_State *L)
+{
+#ifndef WITH_FANN
+       return 0;
+#else
+       struct fann *f, **pfann;
+       const gchar *fname;
+
+       fname = luaL_checkstring (L, 1);
+
+       if (fname != NULL) {
+               f = fann_create_from_file (fname);
+
+               if (f != NULL) {
+                       pfann = lua_newuserdata (L, sizeof (gpointer));
+                       *pfann = f;
+                       rspamd_lua_setclass (L, "rspamd{fann}", -1);
+               }
+               else {
+                       lua_pushnil (L);
+               }
+       }
+       else {
+               lua_pushnil (L);
+       }
+
+       return 1;
+#endif
+}
+
+
+/**
+ * @method rspamd_fann:train(inputs, outputs)
+ * Trains neural network with samples. Inputs and outputs should be tables of
+ * equal size, each row in table should be N inputs and M outputs, e.g.
+ *     {0, 1, 1} -> {0}
+ *     {1, 0, 0} -> {1}
+ * @param {table/table} inputs input samples
+ * @param {table/table} outputs output samples
+ * @return {number} number of samples learned
+ */
+static gint
+lua_fann_train (lua_State *L)
+{
+#ifndef WITH_FANN
+       return 0;
+#else
+       struct fann *f = rspamd_lua_check_fann (L, 1);
+       guint ninputs, noutputs, i, j, cur_len;
+       float *cur_input, *cur_output;
+       gint ret = 0;
+
+       if (f != NULL) {
+               /* First check sanity, call for table.getn for that */
+               ninputs = rspamd_lua_table_size (L, 2);
+               noutputs = rspamd_lua_table_size (L, 3);
+
+               if (ninputs != noutputs) {
+                       msg_err ("bad number of inputs(%d) and output(%d) args for train",
+                                       ninputs, noutputs);
+               }
+               else {
+                       for (i = 0; i < ninputs; i ++) {
+                               /* Push table with inputs */
+                               lua_rawgeti (L, 2, i + 1);
+
+                               cur_len = rspamd_lua_table_size (L, -1);
+
+                               if (cur_len != fann_get_num_input (f)) {
+                                       msg_err (
+                                                       "bad number of input samples: %d, %d expected",
+                                                       cur_len,
+                                                       fann_get_num_input (f));
+                                       lua_pop (L, 1);
+                                       continue;
+                               }
+
+                               cur_input = g_malloc (cur_len * sizeof (gint));
+
+                               for (j = 0; j < cur_len; j ++) {
+                                       lua_rawgeti (L, -1, j + 1);
+                                       cur_input[i] = lua_tonumber (L, -1);
+                                       lua_pop (L, 1);
+                               }
+
+                               lua_pop (L, 1); /* Inputs table */
+
+                               /* Push table with outputs */
+                               lua_rawgeti (L, 3, i + 1);
+
+                               cur_len = rspamd_lua_table_size (L, -1);
+
+                               if (cur_len != fann_get_num_output (f)) {
+                                       msg_err (
+                                                       "bad number of output samples: %d, %d expected",
+                                                       cur_len,
+                                                       fann_get_num_output (f));
+                                       lua_pop (L, 1);
+                                       g_free (cur_input);
+                                       continue;
+                               }
+
+                               cur_output = g_malloc (cur_len * sizeof (gint));
+
+                               for (j = 0; j < cur_len; j++) {
+                                       lua_rawgeti (L, -1, j + 1);
+                                       cur_output[i] = lua_tonumber (L, -1);
+                                       lua_pop (L, 1);
+                               }
+
+                               lua_pop (L, 1); /* Outputs table */
+
+                               fann_train (f, cur_input, cur_output);
+                               g_free (cur_input);
+                               g_free (cur_output);
+                               ret ++;
+                       }
+               }
+       }
+
+       lua_pushnumber (L, ret);
+
+       return 1;
+#endif
+}
+
+/**
+ * @method rspamd_fann:test(inputs)
+ * Tests neural network with samples. Inputs is a single sample of input data.
+ * The function returns table of results, e.g.:
+ *     {0, 1, 1} -> {0}
+ * @param {table} inputs input sample
+ * @return {table/number} outputs values
+ */
+static gint
+lua_fann_test (lua_State *L)
+{
+#ifndef WITH_FANN
+       return 0;
+#else
+       struct fann *f = rspamd_lua_check_fann (L, 1);
+       guint ninputs, noutputs, i;
+       float *cur_input, *cur_output;
+
+       if (f != NULL) {
+               /* First check sanity, call for table.getn for that */
+               ninputs = rspamd_lua_table_size (L, 2);
+               cur_input = g_malloc (ninputs * sizeof (gint));
+
+               for (i = 0; i < ninputs; i++) {
+                       lua_rawgeti (L, 2, i + 1);
+                       cur_input[i] = lua_tonumber (L, -1);
+                       lua_pop (L, 1);
+               }
+
+               cur_output = fann_run (f, cur_input);
+               noutputs = fann_get_num_output (f);
+               lua_createtable (L, noutputs, 0);
+
+               for (i = 0; i < noutputs; i ++) {
+                       lua_pushnumber (L, cur_output[i]);
+                       lua_rawseti (L, -2, i + 1);
+               }
+       }
+       else {
+               lua_pushnil (L);
+       }
+
+       return 1;
+#endif
+}
+
+/***
+ * @method rspamd_fann:get_inputs()
+ * Returns number of inputs for neural network
+ * @return {number} number of inputs
+ */
+static gint
+lua_fann_get_inputs (lua_State *L)
+{
+#ifndef WITH_FANN
+       return 0;
+#else
+       struct fann *f = rspamd_lua_check_fann (L, 1);
+
+       if (f != NULL) {
+               lua_pushnumber (L, fann_get_num_input (f));
+       }
+       else {
+               lua_pushnil (L);
+       }
+
+       return 1;
+#endif
+}
+
+/***
+ * @method rspamd_fann:get_outputs()
+ * Returns number of outputs for neural network
+ * @return {number} number of outputs
+ */
+static gint
+lua_fann_get_outputs (lua_State *L)
+{
+#ifndef WITH_FANN
+       return 0;
+#else
+       struct fann *f = rspamd_lua_check_fann (L, 1);
+
+       if (f != NULL) {
+               lua_pushnumber (L, fann_get_num_output (f));
+       }
+       else {
+               lua_pushnil (L);
+       }
+
+       return 1;
+#endif
+}
+
+/***
+ * @method rspamd_fann:save(fname)
+ * Save fann to file named 'fname'
+ * @param {string} fname filename to save fann into
+ * @return {boolean} true if ann has been saved
+ */
+static gint
+lua_fann_save (lua_State *L)
+{
+#ifndef WITH_FANN
+       return 0;
+#else
+       struct fann *f = rspamd_lua_check_fann (L, 1);
+       const gchar *fname = luaL_checkstring (L, 2);
+
+       if (f != NULL && fname != NULL) {
+               if (fann_save (f, fname) == 0) {
+                       lua_pushboolean (L, true);
+               }
+               else {
+                       msg_err ("cannot save ANN to %s: %s", fname, strerror (errno));
+                       lua_pushboolean (L, false);
+               }
+       }
+       else {
+               lua_pushnil (L);
+       }
+
+       return 1;
+#endif
+}
+
+static gint
+lua_fann_dtor (lua_State *L)
+{
+#ifndef WITH_FANN
+       return 0;
+#else
+       struct fann *f = rspamd_lua_check_fann (L, 1);
+
+       if (f) {
+               fann_destroy (f);
+       }
+
+       return 0;
+#endif
+}
+
+static gint
+lua_load_fann (lua_State * L)
+{
+       lua_newtable (L);
+       luaL_register (L, NULL, fannlib_f);
+
+       return 1;
+}
+
+void
+luaopen_fann (lua_State * L)
+{
+       rspamd_lua_new_class (L, "rspamd{fann}", fannlib_m);
+       lua_pop (L, 1);
+
+       rspamd_lua_add_preload (L, "rspamd_fann", lua_load_fann);
+}