]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Use pointer set instead of key map for task validation
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 22 Dec 2025 10:06:19 +0000 (10:06 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 22 Dec 2025 10:06:19 +0000 (10:06 +0000)
Store task pointers in a khash set and validate them on lookup
from Lua. This works with all code paths that create task userdata
directly without going through rspamd_lua_task_push.

src/libserver/task.c
src/libserver/task.h
src/lua/lua_task.c

index 27b15b3a63f42e449c515cf4e8dd482887a8d3a8..c722c73be1ff4df229f997dc80fa60d09ac7d50b 100644 (file)
@@ -51,41 +51,35 @@ __KHASH_IMPL(rspamd_req_headers_hash, static inline,
                         rspamd_ftok_t *, struct rspamd_request_header_chain *, 1,
                         rspamd_ftok_icase_hash, rspamd_ftok_icase_equal)
 
-/* Task registry: maps lua_key -> task pointer for safe Lua references */
-KHASH_INIT(rspamd_task_registry, uint64_t, struct rspamd_task *, 1,
-                  kh_int64_hash_func, kh_int64_hash_equal);
+/* Task pointer set for validating Lua task references */
+KHASH_SET_INIT_INT64(rspamd_task_set);
 
-static khash_t(rspamd_task_registry) *task_registry = NULL;
-static uint64_t task_lua_key_counter = 0;
+static khash_t(rspamd_task_set) *task_registry = NULL;
 
 void rspamd_task_registry_init(void)
 {
        if (task_registry == NULL) {
-               task_registry = kh_init(rspamd_task_registry);
+               task_registry = kh_init(rspamd_task_set);
        }
 }
 
 void rspamd_task_registry_destroy(void)
 {
        if (task_registry != NULL) {
-               kh_destroy(rspamd_task_registry, task_registry);
+               kh_destroy(rspamd_task_set, task_registry);
                task_registry = NULL;
        }
 }
 
-struct rspamd_task *
-rspamd_task_by_lua_key(uint64_t lua_key)
+gboolean
+rspamd_task_is_valid(struct rspamd_task *task)
 {
-       if (task_registry == NULL || lua_key == 0) {
-               return NULL;
-       }
-
-       khiter_t k = kh_get(rspamd_task_registry, task_registry, lua_key);
-       if (k != kh_end(task_registry)) {
-               return kh_value(task_registry, k);
+       if (task_registry == NULL || task == NULL) {
+               return FALSE;
        }
 
-       return NULL;
+       khiter_t k = kh_get(rspamd_task_set, task_registry, (uint64_t) (uintptr_t) task);
+       return k != kh_end(task_registry);
 }
 
 static inline void
@@ -95,28 +89,21 @@ rspamd_task_registry_add(struct rspamd_task *task)
                rspamd_task_registry_init();
        }
 
-       task->lua_key = ++task_lua_key_counter;
-
        int ret;
-       khiter_t k = kh_put(rspamd_task_registry, task_registry, task->lua_key, &ret);
-       if (ret > 0) {
-               kh_value(task_registry, k) = task;
-       }
+       kh_put(rspamd_task_set, task_registry, (uint64_t) (uintptr_t) task, &ret);
 }
 
 static inline void
 rspamd_task_registry_remove(struct rspamd_task *task)
 {
-       if (task_registry == NULL || task->lua_key == 0) {
+       if (task_registry == NULL) {
                return;
        }
 
-       khiter_t k = kh_get(rspamd_task_registry, task_registry, task->lua_key);
+       khiter_t k = kh_get(rspamd_task_set, task_registry, (uint64_t) (uintptr_t) task);
        if (k != kh_end(task_registry)) {
-               kh_del(rspamd_task_registry, task_registry, k);
+               kh_del(rspamd_task_set, task_registry, k);
        }
-
-       task->lua_key = 0;
 }
 
 static GQuark
index b3daa5ab6e6b3d40164c9f8f3dc91ad186e81796..da7ffb5b2272d3c1607bcbf2ecdd2a1148df6388 100644 (file)
@@ -168,7 +168,6 @@ KHASH_INIT(rspamd_task_lua_cache, char *, struct rspamd_lua_cached_entry, 1, kh_
  */
 struct rspamd_task {
        struct rspamd_worker *worker; /**< pointer to worker object                                             */
-       uint64_t lua_key;             /**< unique key for Lua task registry                             */
        enum rspamd_command cmd;      /**< command                                                                              */
        int sock;                     /**< socket descriptor                                                            */
        uint32_t dns_requests;        /**< number of DNS requests per this task                 */
@@ -418,11 +417,11 @@ void rspamd_task_timeout(EV_P_ ev_timer *w, int revents);
 void rspamd_worker_guard_handler(EV_P_ ev_io *w, int revents);
 
 /*
- * Task registry for safe Lua task references
+ * Task validity set for safe Lua task references
  */
 void rspamd_task_registry_init(void);
 void rspamd_task_registry_destroy(void);
-struct rspamd_task *rspamd_task_by_lua_key(uint64_t lua_key);
+gboolean rspamd_task_is_valid(struct rspamd_task *task);
 
 #ifdef __cplusplus
 }
index 613d7f080d217cc3655477e218cf44a1eface3af..92b588596c4e7ba087dc3be6cb4b96b243d5f0c5 100644 (file)
@@ -1493,8 +1493,10 @@ lua_check_task(lua_State *L, int pos)
        void *ud = rspamd_lua_check_udata(L, pos, rspamd_task_classname);
        luaL_argcheck(L, ud != NULL, pos, "'task' expected");
        if (ud) {
-               uint64_t lua_key = *((uint64_t *) ud);
-               return rspamd_task_by_lua_key(lua_key);
+               struct rspamd_task *task = *((struct rspamd_task **) ud);
+               if (rspamd_task_is_valid(task)) {
+                       return task;
+               }
        }
        return NULL;
 }
@@ -1505,8 +1507,10 @@ lua_check_task_maybe(lua_State *L, int pos)
        void *ud = rspamd_lua_check_udata_maybe(L, pos, rspamd_task_classname);
 
        if (ud) {
-               uint64_t lua_key = *((uint64_t *) ud);
-               return rspamd_task_by_lua_key(lua_key);
+               struct rspamd_task *task = *((struct rspamd_task **) ud);
+               if (rspamd_task_is_valid(task)) {
+                       return task;
+               }
        }
        return NULL;
 }
@@ -8094,9 +8098,9 @@ void luaopen_image(lua_State *L)
 
 void rspamd_lua_task_push(lua_State *L, struct rspamd_task *task)
 {
-       uint64_t *pkey;
+       struct rspamd_task **ptask;
 
-       pkey = lua_newuserdata(L, sizeof(uint64_t));
+       ptask = lua_newuserdata(L, sizeof(gpointer));
        rspamd_lua_setclass(L, rspamd_task_classname, -1);
-       *pkey = task->lua_key;
+       *ptask = task;
 }