]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Lua_task: Allow to load data into the existing task
authorVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 20 Dec 2024 17:31:26 +0000 (17:31 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 20 Dec 2024 17:31:26 +0000 (17:31 +0000)
src/lua/lua_task.c

index b368ad4e65d73b0535afcf221b217f6437ee997e..3556808819b3087eac5bee803bcba8c3beb7ddbc 100644 (file)
@@ -1233,6 +1233,8 @@ static const struct luaL_reg tasklib_f[] = {
        {NULL, NULL}};
 
 static const struct luaL_reg tasklib_m[] = {
+       LUA_INTERFACE_DEF(task, load_from_file),
+       LUA_INTERFACE_DEF(task, load_from_string),
        LUA_INTERFACE_DEF(task, get_message),
        LUA_INTERFACE_DEF(task, set_message),
        LUA_INTERFACE_DEF(task, destroy),
@@ -1724,20 +1726,32 @@ lua_task_load_from_file(lua_State *L)
 {
        LUA_TRACE_POINT;
        struct rspamd_task *task = NULL, **ptask;
-       const char *fname = luaL_checkstring(L, 1), *err = NULL;
+       const char *fname, *err = NULL;
        struct rspamd_config *cfg = NULL;
-       gboolean res = FALSE;
+       gboolean res = FALSE, new_task = FALSE;
        gpointer map;
        gsize sz;
 
+       if (lua_type(L, 1) == LUA_TSTRING) {
+               fname = luaL_checkstring(L, 1);
+               new_task = TRUE;
+       }
+       else {
+               /* Method */
+               task = lua_check_task(L, 1);
+               fname = luaL_checkstring(L, 2);
+       }
+
        if (fname) {
 
-               if (lua_type(L, 2) == LUA_TUSERDATA) {
-                       gpointer p;
-                       p = rspamd_lua_check_udata_maybe(L, 2, rspamd_config_classname);
+               if (!task) {
+                       if (lua_type(L, 2) == LUA_TUSERDATA) {
+                               gpointer p;
+                               p = rspamd_lua_check_udata_maybe(L, 2, rspamd_config_classname);
 
-                       if (p) {
-                               cfg = *(struct rspamd_config **) p;
+                               if (p) {
+                                       cfg = *(struct rspamd_config **) p;
+                               }
                        }
                }
 
@@ -1763,11 +1777,17 @@ lua_task_load_from_file(lua_State *L)
                                }
                        }
 
-                       task = rspamd_task_new(NULL, cfg, NULL, NULL, NULL, FALSE);
+                       if (!task) {
+                               task = rspamd_task_new(NULL, cfg, NULL, NULL, NULL, FALSE);
+                       }
+
                        task->msg.begin = data->str;
                        task->msg.len = data->len;
                        rspamd_mempool_add_destructor(task->task_pool,
                                                                                  lua_task_free_dtor, data->str);
+                       if (data->len > 0) {
+                               task->flags &= ~RSPAMD_TASK_FLAG_EMPTY;
+                       }
                        res = TRUE;
                        g_string_free(data, FALSE); /* Buffer is still valid */
                }
@@ -1778,9 +1798,16 @@ lua_task_load_from_file(lua_State *L)
                                err = strerror(errno);
                        }
                        else {
-                               task = rspamd_task_new(NULL, cfg, NULL, NULL, NULL, FALSE);
+                               if (!task) {
+                                       task = rspamd_task_new(NULL, cfg, NULL, NULL, NULL, FALSE);
+                               }
+
                                task->msg.begin = map;
                                task->msg.len = sz;
+
+                               if (sz > 0) {
+                                       task->flags &= ~RSPAMD_TASK_FLAG_EMPTY;
+                               }
                                rspamd_mempool_add_destructor(task->task_pool,
                                                                                          lua_task_unmap_dtor, task);
                                res = TRUE;
@@ -1793,21 +1820,26 @@ lua_task_load_from_file(lua_State *L)
 
        lua_pushboolean(L, res);
 
-       if (res) {
+       if (res && new_task) {
                ptask = lua_newuserdata(L, sizeof(*ptask));
                *ptask = task;
                rspamd_lua_setclass(L, rspamd_task_classname, -1);
+
+               return 2;
        }
-       else {
+       else if (!res) {
                if (err) {
                        lua_pushstring(L, err);
                }
                else {
                        lua_pushnil(L);
                }
+               return 2;
+       }
+       else {
+               /* No new task */
+               return 1;
        }
-
-       return 2;
 }
 
 static int
@@ -1816,14 +1848,23 @@ lua_task_load_from_string(lua_State *L)
        LUA_TRACE_POINT;
        struct rspamd_task *task = NULL, **ptask;
        const char *str_message;
-       gsize message_len;
+       gsize message_len = 0;
        struct rspamd_config *cfg = NULL;
+       bool new_task = false;
 
-       str_message = luaL_checklstring(L, 1, &message_len);
+       if (lua_type(L, 1) == LUA_TSTRING) {
+               str_message = luaL_checklstring(L, 1, &message_len);
+               new_task = true;
+       }
+       else {
+               /* Method */
+               task = lua_check_task(L, 1);
+               str_message = luaL_checklstring(L, 2, &message_len);
+       }
 
        if (str_message) {
 
-               if (lua_type(L, 2) == LUA_TUSERDATA) {
+               if (!task && lua_type(L, 2) == LUA_TUSERDATA) {
                        gpointer p;
                        p = rspamd_lua_check_udata_maybe(L, 2, rspamd_config_classname);
 
@@ -1832,10 +1873,15 @@ lua_task_load_from_string(lua_State *L)
                        }
                }
 
-               task = rspamd_task_new(NULL, cfg, NULL, NULL, NULL, FALSE);
+               if (!task) {
+                       task = rspamd_task_new(NULL, cfg, NULL, NULL, NULL, FALSE);
+               }
                task->msg.begin = g_malloc(message_len);
                memcpy((char *) task->msg.begin, str_message, message_len);
                task->msg.len = message_len;
+               if (message_len > 0) {
+                       task->flags &= ~RSPAMD_TASK_FLAG_EMPTY;
+               }
                rspamd_mempool_add_destructor(task->task_pool, lua_task_free_dtor,
                                                                          (gpointer) task->msg.begin);
        }
@@ -1845,11 +1891,16 @@ lua_task_load_from_string(lua_State *L)
 
        lua_pushboolean(L, true);
 
-       ptask = lua_newuserdata(L, sizeof(*ptask));
-       *ptask = task;
-       rspamd_lua_setclass(L, rspamd_task_classname, -1);
+       if (new_task) {
+               ptask = lua_newuserdata(L, sizeof(*ptask));
+               *ptask = task;
+               rspamd_lua_setclass(L, rspamd_task_classname, -1);
 
-       return 2;
+               return 2;
+       }
+       else {
+               return 1;
+       }
 }
 
 static int