]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add neural net serialization/deserialization
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 8 Oct 2016 12:42:56 +0000 (13:42 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 8 Oct 2016 15:35:42 +0000 (16:35 +0100)
src/lua/lua_fann.c

index 6df5d9470223030d4541692763737860c6170c3f..3d15c64170edf437aed5791db1f4d0f919692e5a 100644 (file)
@@ -19,6 +19,8 @@
 #include <fann.h>
 #endif
 
+#include "unix-std.h"
+
 /***
  * @module rspamd_fann
  * This module enables [fann](http://libfann.github.io) interaction in rspamd
@@ -31,7 +33,8 @@
  */
 LUA_FUNCTION_DEF (fann, is_enabled);
 LUA_FUNCTION_DEF (fann, create);
-LUA_FUNCTION_DEF (fann, load);
+LUA_FUNCTION_DEF (fann, load_file);
+LUA_FUNCTION_DEF (fann, load_data);
 
 /*
  * Fann methods
@@ -39,6 +42,7 @@ LUA_FUNCTION_DEF (fann, load);
 LUA_FUNCTION_DEF (fann, train);
 LUA_FUNCTION_DEF (fann, test);
 LUA_FUNCTION_DEF (fann, save);
+LUA_FUNCTION_DEF (fann, data);
 LUA_FUNCTION_DEF (fann, get_inputs);
 LUA_FUNCTION_DEF (fann, get_outputs);
 LUA_FUNCTION_DEF (fann, dtor);
@@ -46,7 +50,9 @@ LUA_FUNCTION_DEF (fann, dtor);
 static const struct luaL_reg fannlib_f[] = {
                LUA_INTERFACE_DEF (fann, is_enabled),
                LUA_INTERFACE_DEF (fann, create),
-               LUA_INTERFACE_DEF (fann, load),
+               LUA_INTERFACE_DEF (fann, load_file),
+               {"load", lua_fann_load_file},
+               LUA_INTERFACE_DEF (fann, load_data),
                {NULL, NULL}
 };
 
@@ -54,6 +60,7 @@ static const struct luaL_reg fannlib_m[] = {
                LUA_INTERFACE_DEF (fann, train),
                LUA_INTERFACE_DEF (fann, test),
                LUA_INTERFACE_DEF (fann, save),
+               LUA_INTERFACE_DEF (fann, data),
                LUA_INTERFACE_DEF (fann, get_inputs),
                LUA_INTERFACE_DEF (fann, get_outputs),
                {"__gc", lua_fann_dtor},
@@ -141,7 +148,7 @@ lua_fann_create (lua_State *L)
  * @return {fann} fann object
  */
 static gint
-lua_fann_load (lua_State *L)
+lua_fann_load_file (lua_State *L)
 {
 #ifndef WITH_FANN
        return 0;
@@ -171,6 +178,135 @@ lua_fann_load (lua_State *L)
 #endif
 }
 
+/***
+ * @function rspamd_fann.load_data(data)
+ * Loads neural network from the data
+ * @param {string} file filename where fann is stored
+ * @return {fann} fann object
+ */
+static gint
+lua_fann_load_data (lua_State *L)
+{
+#ifndef WITH_FANN
+       return 0;
+#else
+       struct fann *f, **pfann;
+       gint fd;
+       struct rspamd_lua_text *t;
+       gchar fpath[PATH_MAX];
+
+       if (lua_type (L, 1) == LUA_TUSERDATA) {
+               t = lua_check_text (L, 1);
+
+               if (!t) {
+                       return luaL_error (L, "text required");
+               }
+       }
+       else {
+               t = g_alloca (sizeof (*t));
+               t->start = lua_tolstring (L, 1, (gsize *)&t->len);
+               t->flags = 0;
+       }
+
+       /* We need to save data to file because of libfann stupidity */
+       rspamd_strlcpy (fpath, "/tmp/rspamd-fannXXXXXXXXXX", sizeof (fpath));
+       fd = mkstemp (fpath);
+
+       if (fd == -1) {
+               msg_warn ("cannot create tempfile: %s", strerror (errno));
+               lua_pushnil (L);
+       }
+       else {
+               if (write (fd, t->start, t->len) == -1) {
+                       msg_warn ("cannot write tempfile: %s", strerror (errno));
+                       lua_pushnil (L);
+                       unlink (fpath);
+                       close (fd);
+
+                       return 1;
+               }
+
+               f = fann_create_from_file (fpath);
+               unlink (fpath);
+               close (fd);
+
+               if (f != NULL) {
+                       pfann = lua_newuserdata (L, sizeof (gpointer));
+                       *pfann = f;
+                       rspamd_lua_setclass (L, "rspamd{fann}", -1);
+               }
+               else {
+                       lua_pushnil (L);
+               }
+       }
+
+       return 1;
+#endif
+}
+
+/***
+ * @function rspamd_fann:data()
+ * Returns serialized neural network
+ * @return {rspamd_text} fann data
+ */
+static gint
+lua_fann_data (lua_State *L)
+{
+#ifndef WITH_FANN
+       return 0;
+#else
+       struct fann *f = rspamd_lua_check_fann (L, 1);
+       gint fd;
+       struct rspamd_lua_text *res;
+       gchar fpath[PATH_MAX];
+       gpointer map;
+       gsize sz;
+
+       if (f == NULL) {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       /* We need to save data to file because of libfann stupidity */
+       rspamd_strlcpy (fpath, "/tmp/rspamd-fannXXXXXXXXXX", sizeof (fpath));
+       fd = mkstemp (fpath);
+
+       if (fd == -1) {
+               msg_warn ("cannot create tempfile: %s", strerror (errno));
+               lua_pushnil (L);
+       }
+       else {
+               if (fann_save (f, fpath) == -1) {
+                       msg_warn ("cannot write tempfile: %s", strerror (errno));
+                       lua_pushnil (L);
+                       unlink (fpath);
+                       close (fd);
+
+                       return 1;
+               }
+
+
+               (void)lseek (fd, 0, SEEK_SET);
+               map = rspamd_file_xmap (fpath, PROT_READ, &sz);
+               unlink (fpath);
+               close (fd);
+
+               if (map != NULL) {
+                       res = lua_newuserdata (L, sizeof (*res));
+                       res->len = sz;
+                       res->start = map;
+                       res->flags = RSPAMD_TEXT_FLAG_OWN|RSPAMD_TEXT_FLAG_MMAPED;
+                       rspamd_lua_setclass (L, "rspamd{text}", -1);
+               }
+               else {
+                       lua_pushnil (L);
+               }
+
+       }
+
+       return 1;
+#endif
+}
+
 
 /**
  * @method rspamd_fann:train(inputs, outputs)