]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Added coroutines and async events support to rspamadm commands
authorMikhail Galanin <mgalanin@mimecast.com>
Wed, 5 Sep 2018 12:39:53 +0000 (13:39 +0100)
committerMikhail Galanin <mgalanin@mimecast.com>
Wed, 5 Sep 2018 12:39:53 +0000 (13:39 +0100)
src/rspamadm/lua_repl.c
src/rspamadm/rspamadm.c
src/rspamadm/rspamadm.h

index 26b642871a912f4d19c9ac906d429a41cb7b52e8..39fa0748080129605000a8a740bf4da7fc2bf68e 100644 (file)
@@ -20,6 +20,7 @@
 #include "libutil/http_private.h"
 #include "printf.h"
 #include "lua/lua_common.h"
+#include "lua/lua_thread_pool.h"
 #include "message.h"
 #include "unix-std.h"
 #include "linenoise.h"
@@ -37,6 +38,7 @@ static gchar *serve = NULL;
 static gchar *exec_line = NULL;
 static gint batch = -1;
 static gboolean per_line = FALSE;
+extern struct rspamd_async_session *rspamadm_session;
 
 static const char *default_history_file = ".rspamd_repl.hist";
 
@@ -70,10 +72,23 @@ struct rspamadm_lua_dot_command {
        rspamadm_lua_dot_handler handler;
 };
 
+struct lua_call_data {
+       gint top;
+       gint ret;
+       gpointer ud;
+};
+
 static void rspamadm_lua_help_handler (lua_State *L, gint argc, gchar **argv);
 static void rspamadm_lua_load_handler (lua_State *L, gint argc, gchar **argv);
 static void rspamadm_lua_exec_handler (lua_State *L, gint argc, gchar **argv);
 static void rspamadm_lua_message_handler (lua_State *L, gint argc, gchar **argv);
+static void lua_execute_and_wait (gint narg);
+
+static void lua_thread_error_cb (struct thread_entry *thread, int ret, const char *msg);
+static void lua_thread_finish_cb (struct thread_entry *thread, int ret);
+static gint lua_repl_thread_call (struct thread_entry *thread, gint narg,
+               gpointer ud, lua_thread_error_t error_func);
+
 
 static struct rspamadm_lua_dot_command cmds[] = {
        {
@@ -172,16 +187,41 @@ rspamadm_lua_add_path (lua_State *L, const gchar *path)
        g_string_free (new_path, TRUE);
 }
 
+
+static void
+lua_thread_finish_cb (struct thread_entry *thread, int ret)
+{
+       struct lua_call_data *cd = thread->cd;
+
+       cd->ret = ret;
+}
+
+static void
+lua_thread_error_cb (struct thread_entry *thread, int ret, const char *msg)
+{
+       struct lua_call_data *cd = thread->cd;
+
+       rspamd_fprintf (stderr, "call failed: %s\n", msg);
+
+       cd->ret = ret;
+}
+
+static void
+lua_thread_str_error_cb (struct thread_entry *thread, int ret, const char *msg)
+{
+       struct lua_call_data *cd = thread->cd;
+       const char *what = cd->ud;
+
+       rspamd_fprintf (stderr, "call to %s failed: %s\n", what, msg);
+
+       cd->ret = ret;
+}
+
 static gboolean
 rspamadm_lua_load_script (lua_State *L, const gchar *path)
 {
-       GString *tb;
-       gint err_idx = 0;
-
-       if (!per_line) {
-               lua_pushcfunction (L, &rspamd_lua_traceback);
-               err_idx = lua_gettop (L);
-       }
+       struct thread_entry *thread = lua_thread_pool_get_for_config (rspamd_main->cfg);
+       L = thread->lua_state;
 
        if (luaL_loadfile (L, path) != 0) {
                rspamd_fprintf (stderr, "cannot load script %s: %s\n",
@@ -192,12 +232,8 @@ rspamadm_lua_load_script (lua_State *L, const gchar *path)
        }
 
        if (!per_line) {
-               if (lua_pcall (L, 0, 0, err_idx) != 0) {
-                       tb = lua_touserdata (L, -1);
-                       rspamd_fprintf (stderr, "call to %s failed: %v", path, tb);
-                       g_string_free (tb, TRUE);
-                       lua_settop (L, 0);
 
+               if (lua_repl_thread_call (thread, 0, (void *)path, lua_thread_str_error_cb) != 0) {
                        return FALSE;
                }
 
@@ -211,27 +247,29 @@ static void
 rspamadm_exec_input (lua_State *L, const gchar *input)
 {
        GString *tb;
-       gint err_idx, i, cbref;
+       gint i, cbref;
+       int top = 0;
        gchar outbuf[8192];
 
-       lua_pushcfunction (L, &rspamd_lua_traceback);
-       err_idx = lua_gettop (L);
+       struct thread_entry *thread = lua_thread_pool_get_for_config (rspamd_main->cfg);
+       L = thread->lua_state;
 
        /* First try return + input */
        tb = g_string_sized_new (strlen (input) + sizeof ("return "));
        rspamd_printf_gstring (tb, "return %s", input);
 
-       if (luaL_loadstring (L, tb->str) != 0) {
+       int r = luaL_loadstring (L, tb->str);
+       if (r != 0) {
                /* Reset stack */
                lua_settop (L, 0);
-               lua_pushcfunction (L, &rspamd_lua_traceback);
-               err_idx = lua_gettop (L);
                /* Try with no return */
                if (luaL_loadstring (L, input) != 0) {
                        rspamd_fprintf (stderr, "cannot load string %s\n",
                                        input);
                        g_string_free (tb, TRUE);
                        lua_settop (L, 0);
+
+                       lua_thread_pool_return (rspamd_main->cfg->lua_thread_pool, thread);
                        return;
                }
        }
@@ -239,31 +277,63 @@ rspamadm_exec_input (lua_State *L, const gchar *input)
        g_string_free (tb, TRUE);
 
        if (!per_line) {
-               if (lua_pcall (L, 0, LUA_MULTRET, err_idx) != 0) {
-                       tb = lua_touserdata (L, -1);
-                       rspamd_fprintf (stderr, "call failed: %v\n", tb);
-                       g_string_free (tb, TRUE);
-                       lua_settop (L, 0);
-                       return;
-               }
 
-               /* Print output */
-               for (i = err_idx + 1; i <= lua_gettop (L); i++) {
-                       if (lua_isfunction (L, i)) {
-                               lua_pushvalue (L, i);
-                               cbref = luaL_ref (L, LUA_REGISTRYINDEX);
+               top = lua_gettop (L);
+
+               if (lua_repl_thread_call (thread, 0, NULL, NULL) == 0) {
+                       /* Print output */
+                       for (i = top; i <= lua_gettop (L); i++) {
+                               if (lua_isfunction (L, i)) {
+                                       lua_pushvalue (L, i);
+                                       cbref = luaL_ref (L, LUA_REGISTRYINDEX);
 
-                               rspamd_printf ("local function: %d\n", cbref);
-                       } else {
-                               lua_logger_out_type (L, i, outbuf, sizeof (outbuf));
-                               rspamd_printf ("%s\n", outbuf);
+                                       rspamd_printf ("local function: %d\n", cbref);
+                               } else {
+                                       lua_logger_out_type (L, i, outbuf, sizeof (outbuf));
+                                       rspamd_printf ("%s\n", outbuf);
+                               }
                        }
                }
+       }
+}
 
-               lua_settop (L, 0);
+void
+wait_session_events ()
+{
+       /* XXX: it's probably worth to add timeout here - not to wait forever */
+       while (rspamd_session_events_pending (rspamadm_session) > 0) {
+               event_base_loop (rspamd_main->ev_base, EVLOOP_ONCE);
        }
 }
 
+static gint
+lua_repl_thread_call (struct thread_entry *thread, gint narg, gpointer ud, lua_thread_error_t error_func)
+{
+       int ret;
+       struct lua_call_data *cd = g_new0 (struct lua_call_data, 1);
+       cd->top = lua_gettop (L);
+       cd->ud = ud;
+
+       thread->finish_callback = lua_thread_finish_cb;
+       if (error_func) {
+               thread->error_callback = error_func;
+       }
+       else {
+               thread->error_callback = lua_thread_error_cb;
+       }
+       thread->cd = cd;
+
+       lua_thread_call (thread, narg);
+
+       wait_session_events ();
+
+       ret = cd->ret;
+
+       g_free (cd);
+
+       return ret;
+}
+
 static void
 rspamadm_lua_help_handler (lua_State *L, gint argc, gchar **argv)
 {
@@ -308,12 +378,12 @@ rspamadm_lua_load_handler (lua_State *L, gint argc, gchar **argv)
 static void
 rspamadm_lua_exec_handler (lua_State *L, gint argc, gchar **argv)
 {
-       GString *tb;
-       gint err_idx, i;
+       gint i;
+
+       struct thread_entry *thread = lua_thread_pool_get_for_config (rspamd_main->cfg);
+       L = thread->lua_state;
 
        for (i = 1; argv[i] != NULL; i ++) {
-               lua_pushcfunction (L, &rspamd_lua_traceback);
-               err_idx = lua_gettop (L);
 
                if (luaL_loadfile (L, argv[i]) != 0) {
                        rspamd_fprintf (stderr, "cannot load script %s: %s\n",
@@ -323,16 +393,7 @@ rspamadm_lua_exec_handler (lua_State *L, gint argc, gchar **argv)
                        return;
                }
 
-               if (lua_pcall (L, 0, 0, err_idx) != 0) {
-                       tb = lua_touserdata (L, -1);
-                       rspamd_fprintf (stderr, "call to %s failed: %v", argv[i], tb);
-                       g_string_free (tb, TRUE);
-                       lua_settop (L, 0);
-
-                       return;
-               }
-
-               lua_settop (L, 0);
+               lua_repl_thread_call (thread, 0, argv[i], lua_thread_str_error_cb);
        }
 }
 
@@ -340,11 +401,10 @@ static void
 rspamadm_lua_message_handler (lua_State *L, gint argc, gchar **argv)
 {
        gulong cbref;
-       gint err_idx, func_idx, i, j;
+       gint old_top, func_idx, i, j;
        struct rspamd_task *task, **ptask;
        gpointer map;
        gsize len;
-       GString *tb;
        gchar outbuf[8192];
 
        if (argv[1] == NULL) {
@@ -352,22 +412,26 @@ rspamadm_lua_message_handler (lua_State *L, gint argc, gchar **argv)
                return;
        }
 
-       if (rspamd_strtoul (argv[1], strlen (argv[1]), &cbref)) {
-               lua_rawgeti (L, LUA_REGISTRYINDEX, cbref);
-       }
-       else {
-               lua_getglobal (L, argv[1]);
-       }
+       for (i = 2; argv[i] != NULL; i ++) {
+               struct thread_entry *thread = lua_thread_pool_get_for_config (rspamd_main->cfg);
+               L = thread->lua_state;
 
-       if (lua_type (L, -1) != LUA_TFUNCTION) {
-               rspamd_printf ("bad callback type: %s\n", lua_typename (L, lua_type (L, -1)));
-               return;
-       }
+               if (rspamd_strtoul (argv[1], strlen (argv[1]), &cbref)) {
+                       lua_rawgeti (L, LUA_REGISTRYINDEX, cbref);
+               }
+               else {
+                       lua_getglobal (L, argv[1]);
+               }
 
-       /* Save index to reuse */
-       func_idx = lua_gettop (L);
+               if (lua_type (L, -1) != LUA_TFUNCTION) {
+                       rspamd_printf ("bad callback type: %s\n", lua_typename (L, lua_type (L, -1)));
+                       lua_thread_pool_return (rspamd_main->cfg->lua_thread_pool, thread);
+                       return;
+               }
+
+               /* Save index to reuse */
+               func_idx = lua_gettop (L);
 
-       for (i = 2; argv[i] != NULL; i ++) {
                map = rspamd_file_xmap (argv[i], PROT_READ, &len, TRUE);
 
                if (map == NULL) {
@@ -391,23 +455,18 @@ rspamadm_lua_message_handler (lua_State *L, gint argc, gchar **argv)
                        }
 
                        rspamd_message_process (task);
-                       lua_pushcfunction (L, &rspamd_lua_traceback);
-                       err_idx = lua_gettop (L);
+                       old_top = lua_gettop (L);
 
                        lua_pushvalue (L, func_idx);
                        ptask = lua_newuserdata (L, sizeof (*ptask));
                        *ptask = task;
                        rspamd_lua_setclass (L, "rspamd{task}", -1);
 
-                       if (lua_pcall (L, 1, LUA_MULTRET, err_idx) != 0) {
-                               tb = lua_touserdata (L, -1);
-                               rspamd_printf ("lua callback for %s failed: %v\n", argv[i], tb);
-                               g_string_free (tb, TRUE);
-                       }
-                       else {
+
+                       if (lua_repl_thread_call (thread, 1, argv[i], lua_thread_str_error_cb) == 0) {
                                rspamd_printf ("lua callback for %s returned:\n", argv[i]);
 
-                               for (j = err_idx + 1; j <= lua_gettop (L); j ++) {
+                               for (j = old_top + 1; j <= lua_gettop (L); j ++) {
                                        lua_logger_out_type (L, j, outbuf, sizeof (outbuf));
                                        rspamd_printf ("%s\n", outbuf);
                                }
@@ -578,6 +637,18 @@ rspamadm_lua_finish_handler (struct rspamd_http_connection_entry *conn_ent)
        g_free (session);
 }
 
+static void
+lua_thread_http_error_cb (struct thread_entry *thread, int ret, const char *msg)
+{
+       struct lua_call_data *cd = thread->cd;
+       struct rspamd_http_connection_entry *conn_ent = cd->ud;
+
+       rspamd_controller_send_error (conn_ent, 500, "call failed: %s\n", msg);
+
+       cd->ret = ret;
+}
+
+
 /*
  * Exec command handler:
  * request: /exec
@@ -598,7 +669,10 @@ rspamadm_lua_handle_exec (struct rspamd_http_connection_entry *conn_ent,
        gsize body_len;
 
        ctx = session->ctx;
-       L = ctx->L;
+
+       struct thread_entry *thread = lua_thread_pool_get_for_config (rspamd_main->cfg);
+       L = thread->lua_state;
+
        body = rspamd_http_message_get_body (msg, &body_len);
 
        if (body == NULL) {
@@ -629,12 +703,7 @@ rspamadm_lua_handle_exec (struct rspamd_http_connection_entry *conn_ent,
 
        g_string_free (tb, TRUE);
 
-       if (lua_pcall (L, 0, LUA_MULTRET, err_idx) != 0) {
-               tb = lua_touserdata (L, -1);
-               rspamd_controller_send_error (conn_ent, 500, "call failed: %v\n", tb);
-               g_string_free (tb, TRUE);
-               lua_settop (L, 0);
-
+       if (lua_repl_thread_call (thread, 0, conn_ent, lua_thread_http_error_cb) != 0) {
                return 0;
        }
 
@@ -830,12 +899,10 @@ again:
                        lua_pushlstring (L, buf->str, MIN (buf->len, end_pos));
                        lua_setglobal (L, "input");
 
-                       if (lua_pcall (L, 0, 0, 0) != 0) {
-                               rspamd_fprintf (stderr, "call to script failed: %s",
-                                               lua_tostring (L, -1));
-                               lua_settop (L, 0);
-                               break;
-                       }
+                       struct thread_entry *thread = lua_thread_pool_get_for_config (rspamd_main->cfg);
+                       L = thread->lua_state;
+
+                       lua_repl_thread_call (thread, 0, NULL, NULL);
 
                        lua_settop (L, old_top);
                }
index d5a656a516f45facfd539b0d31b4574a3c7fe3a4..092e4ff58e8d8659ca610aa3dbb4f01fde3a8a68 100644 (file)
@@ -31,6 +31,7 @@ static gboolean show_help = FALSE;
 static gboolean show_version = FALSE;
 GHashTable *ucl_vars = NULL;
 struct rspamd_main *rspamd_main = NULL;
+struct rspamd_async_session *rspamadm_session = NULL;
 lua_State *L = NULL;
 
 /* Defined in modules.c */
@@ -317,6 +318,28 @@ rspamadm_command_maybe_match_name (const gchar *cmd, const gchar *input)
        return FALSE;
 }
 
+
+
+static void
+rspamadm_add_lua_globals()
+{
+       struct rspamd_async_session  **psession;
+       struct event_base **pev_base;
+
+       rspamadm_session = rspamd_session_create (rspamd_main->cfg->cfg_pool, NULL,
+                       NULL, (event_finalizer_t )NULL, NULL);
+
+       psession = lua_newuserdata (L, sizeof (struct rspamd_async_session*));
+       rspamd_lua_setclass (L, "rspamd{session}", -1);
+       *psession = rspamadm_session;
+       lua_setglobal (L, "rspamadm_session");
+
+       pev_base = lua_newuserdata (L, sizeof (struct event_base *));
+       rspamd_lua_setclass (L, "rspamd{ev_base}", -1);
+       *pev_base = rspamd_main->ev_base;
+       lua_setglobal (L, "rspamadm_ev_base");
+}
+
 gint
 main (gint argc, gchar **argv, gchar **env)
 {
@@ -343,6 +366,8 @@ main (gint argc, gchar **argv, gchar **env)
        rspamd_main->type = process_quark;
        rspamd_main->server_pool = rspamd_mempool_new (rspamd_mempool_suggest_size (),
                        "rspamadm");
+       rspamd_main->ev_base = event_init ();
+
        rspamadm_fill_internal_commands (all_commands);
        help_command.command_data = all_commands;
 
@@ -417,6 +442,7 @@ main (gint argc, gchar **argv, gchar **env)
        L = cfg->lua_state;
        rspamd_lua_set_path (L, NULL, ucl_vars);
        rspamd_lua_set_globals (cfg, L, ucl_vars);
+       rspamadm_add_lua_globals();
 
        /* Init rspamadm global */
        lua_newtable (L);
@@ -498,6 +524,8 @@ main (gint argc, gchar **argv, gchar **env)
                cmd->run (0, NULL, cmd);
        }
 
+       event_base_loopexit (rspamd_main->ev_base, NULL);
+
        REF_RELEASE (rspamd_main->cfg);
        rspamd_log_close (rspamd_main->logger, TRUE);
        g_free (rspamd_main);
index 02ecb2f479526ed08fb74f3c0421c867d898add3..ff12849425ec8525110161ad06df2c56a829b608 100644 (file)
@@ -24,6 +24,7 @@
 
 extern GHashTable *ucl_vars;
 extern lua_State *L;
+extern struct rspamd_main *rspamd_main;
 
 GQuark rspamadm_error (void);