]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] lua_task: bulk and regexp symbol lookups
authorVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 12 May 2026 14:43:45 +0000 (15:43 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 12 May 2026 14:43:45 +0000 (15:43 +0100)
Add table-form overloads to task:has_symbol() and task:get_symbol()
that accept {S1, S2, ..., Sn} and return true / a {name -> info} map
if any of the listed symbols fired. Both keep the legacy single-name
form (with optional shadow_result_name) untouched.

Introduce task:has_symbol_regexp(re [, shadow_result_name]) and
task:get_symbol_regexp(re [, shadow_result_name]) that match fired
symbol names against an rspamd_regexp userdata.

src/lua/lua_task.c

index 97f017b513d387512f96a3d1fc0731720b33190c..a29be75fd4d0ee818a3f70672e2fb506205fbaab 100644 (file)
@@ -783,22 +783,42 @@ LUA_FUNCTION_DEF(task, get_archives);
 LUA_FUNCTION_DEF(task, get_dkim_results);
 /***
  * @method task:get_symbol(name, [shadow_result_name])
- * Searches for a symbol `name` in all metrics results and returns a list of tables
- * one per metric that describes the symbol inserted.
- * Please note, that for using this function you need to ensure that the symbol
- * being queried is already checked. This is guaranteed if there is a dependency
- * between the caller symbol and the checked symbol (either virtual or real).
- * Please check `rspamd_config:register_dependency` method for details.
- * The symbols are returned as the list of the following tables:
+ * @method task:get_symbol(names_table, [shadow_result_name])
+ * Searches for a symbol or symbols in scan results and returns the matching
+ * symbol descriptors. Each descriptor is a table with the fields:
  *
- * - `metric` - name of metric
+ * - `metric` - name of metric (only for the single-name positional form)
  * - `score` - score of a symbol in that metric
  * - `options` - a table of strings representing options of a symbol
  * - `group` - a group of symbol (or 'ungrouped')
- * @param {string} name symbol's name
- * @return {list of tables} list of tables or nil if symbol was not found
+ *
+ * Forms:
+ *
+ * - Single positional name (legacy): `get_symbol("FOO")` or
+ *   `get_symbol("FOO", "shadow_result")`. Returns a list with one table or
+ *   `nil` if not found (kept for backwards compatibility).
+ * - Table of names: `get_symbol({"FOO", "BAR"})` or
+ *   `get_symbol({"FOO", "BAR"}, "shadow_result")`. Returns a map
+ *   `{name -> info}` for the names that have fired, or `nil` if none did.
+ *
+ * Please note, that for using this function you need to ensure that the
+ * symbols being queried are already checked. This is guaranteed if there is
+ * a dependency between the caller symbol and the checked symbol (either
+ * virtual or real). Please check `rspamd_config:register_dependency` method
+ * for details.
+ * @return {table|nil} list/map of tables or nil if no symbol was found
  */
 LUA_FUNCTION_DEF(task, get_symbol);
+/***
+ * @method task:get_symbol_regexp(regexp, [shadow_result_name])
+ * Returns a map `{symbol_name -> info_table}` for every fired symbol whose
+ * name matches the supplied `rspamd_regexp` userdata. Returns `nil` when no
+ * symbol matched.
+ * @param {rspamd_regexp} regexp pattern to match symbol names against
+ * @param {string} shadow_result_name optional name of a shadow scan result
+ * @return {table|nil} map of matched symbols or nil
+ */
+LUA_FUNCTION_DEF(task, get_symbol_regexp);
 /***
  * @method task:get_symbols_all()
  * Returns array of symbols matched in default metric with all metadata
@@ -847,15 +867,35 @@ LUA_FUNCTION_DEF(task, process_ann_tokens);
 
 /***
  * @method task:has_symbol(name, [shadow_result_name])
- * Fast path to check if a specified symbol is in the task's results.
- * Please note, that for using this function you need to ensure that the symbol
- * being queried is already checked. This is guaranteed if there is a dependency
- * between the caller symbol and the checked symbol (either virtual or real).
- * Please check `rspamd_config:register_dependency` method for details.
- * @param {string} name symbol's name
- * @return {boolean} `true` if symbol has been found
+ * @method task:has_symbol(names_table, [shadow_result_name])
+ * Fast path to check if a specified symbol (or any of several symbols) is in
+ * the task's results.
+ *
+ * Forms:
+ *
+ * - Single positional name (legacy): `has_symbol("FOO")` or
+ *   `has_symbol("FOO", "shadow_result")`.
+ * - Table of names: `has_symbol({"FOO", "BAR"})` or
+ *   `has_symbol({"FOO", "BAR"}, "shadow_result")`. Returns `true` if any of
+ *   the names is present.
+ *
+ * Please note, that for using this function you need to ensure that the
+ * symbols being queried are already checked. This is guaranteed if there is
+ * a dependency between the caller symbol and the checked symbol (either
+ * virtual or real). Please check `rspamd_config:register_dependency` method
+ * for details.
+ * @return {boolean} `true` if (any of the) symbol(s) has been found
  */
 LUA_FUNCTION_DEF(task, has_symbol);
+/***
+ * @method task:has_symbol_regexp(regexp, [shadow_result_name])
+ * Checks whether any fired symbol has a name matching the supplied
+ * `rspamd_regexp` userdata.
+ * @param {rspamd_regexp} regexp pattern to match symbol names against
+ * @param {string} shadow_result_name optional name of a shadow scan result
+ * @return {boolean} `true` if at least one matching symbol has been found
+ */
+LUA_FUNCTION_DEF(task, has_symbol_regexp);
 /***
  * @method task:enable_symbol(name)
  * Enable specified symbol for this particular task
@@ -1412,6 +1452,7 @@ static const struct luaL_reg tasklib_m[] = {
        LUA_INTERFACE_DEF(task, get_archives),
        LUA_INTERFACE_DEF(task, get_dkim_results),
        LUA_INTERFACE_DEF(task, get_symbol),
+       LUA_INTERFACE_DEF(task, get_symbol_regexp),
        LUA_INTERFACE_DEF(task, get_symbols),
        LUA_INTERFACE_DEF(task, get_symbols_all),
        LUA_INTERFACE_DEF(task, get_symbols_numeric),
@@ -1419,6 +1460,7 @@ static const struct luaL_reg tasklib_m[] = {
        LUA_INTERFACE_DEF(task, get_groups),
        LUA_INTERFACE_DEF(task, process_ann_tokens),
        LUA_INTERFACE_DEF(task, has_symbol),
+       LUA_INTERFACE_DEF(task, has_symbol_regexp),
        LUA_INTERFACE_DEF(task, enable_symbol),
        LUA_INTERFACE_DEF(task, disable_symbol),
        LUA_INTERFACE_DEF(task, get_date),
@@ -5256,26 +5298,92 @@ lua_push_symbol_result(lua_State *L,
        return FALSE;
 }
 
+/*
+ * Resolve a shadow result name (Lua string) to an rspamd_scan_result.
+ * Returns NULL on error and pushes a Lua error via luaL_error; the caller
+ * should propagate that error. *out_sres is set on success (or left NULL
+ * when shadow_idx is not a string).
+ */
+static gboolean
+lua_task_resolve_shadow_result(lua_State *L,
+                                                          struct rspamd_task *task,
+                                                          int shadow_idx,
+                                                          struct rspamd_scan_result **out_sres)
+{
+       *out_sres = NULL;
+
+       if (lua_isstring(L, shadow_idx)) {
+               *out_sres = rspamd_find_metric_result(task, lua_tostring(L, shadow_idx));
+
+               if (*out_sres == NULL) {
+                       luaL_error(L, "invalid scan result: %s", lua_tostring(L, shadow_idx));
+                       return FALSE;
+               }
+       }
+
+       return TRUE;
+}
+
+/* Push a name+info pair into the map table currently at top of stack. */
+static inline void
+lua_task_symbol_push_into_map(lua_State *L,
+                                                         struct rspamd_task *task,
+                                                         const char *name,
+                                                         struct rspamd_symbol_result *s,
+                                                         struct rspamd_scan_result *sres,
+                                                         unsigned int *count)
+{
+       if (lua_push_symbol_result(L, task, name, s, sres, FALSE, FALSE)) {
+               lua_setfield(L, -2, name);
+               (*count)++;
+       }
+}
+
 static int
 lua_task_get_symbol(lua_State *L)
 {
        LUA_TRACE_POINT;
        struct rspamd_task *task = lua_check_task(L, 1);
-       const char *symbol;
-       gboolean found = FALSE;
+       struct rspamd_scan_result *sres = NULL;
 
-       symbol = luaL_checkstring(L, 2);
+       if (!task) {
+               return luaL_error(L, "invalid arguments");
+       }
 
-       if (task && symbol) {
-               struct rspamd_scan_result *sres = NULL;
+       /* Table form: get_symbol({names}, [shadow_result_name]) */
+       if (lua_istable(L, 2)) {
+               unsigned int count = 0;
 
-               if (lua_isstring(L, 3)) {
-                       sres = rspamd_find_metric_result(task, lua_tostring(L, 3));
+               if (!lua_task_resolve_shadow_result(L, task, 3, &sres)) {
+                       return 0;
+               }
+
+               lua_createtable(L, 0, 4);
 
-                       if (sres == NULL) {
-                               return luaL_error(L, "invalid scan result: %s",
-                                                                 lua_tostring(L, 3));
+               lua_pushnil(L);
+               while (lua_next(L, 2) != 0) {
+                       if (lua_type(L, -1) == LUA_TSTRING) {
+                               const char *name = lua_tostring(L, -1);
+                               lua_task_symbol_push_into_map(L, task, name, NULL, sres, &count);
                        }
+                       lua_pop(L, 1);
+               }
+
+               if (count == 0) {
+                       lua_pop(L, 1);
+                       lua_pushnil(L);
+               }
+
+               return 1;
+       }
+
+       /* Legacy single-name form: get_symbol(name [, shadow_result_name]) */
+       if (lua_type(L, 2) == LUA_TSTRING) {
+               const char *symbol = lua_tostring(L, 2);
+               gboolean found = FALSE;
+
+               if (!lua_task_resolve_shadow_result(L, task, 3, &sres)) {
+                       return 0;
                }
 
                /* Always push as a table for compatibility :( */
@@ -5286,54 +5394,155 @@ lua_task_get_symbol(lua_State *L)
                        lua_rawseti(L, -2, 1);
                }
                else {
-                       /* Pop table */
                        lua_pop(L, 1);
+                       lua_pushnil(L);
                }
+
+               return 1;
        }
-       else {
+
+       return luaL_error(L, "invalid arguments");
+}
+
+static int
+lua_task_get_symbol_regexp(lua_State *L)
+{
+       LUA_TRACE_POINT;
+       struct rspamd_task *task = lua_check_task(L, 1);
+       struct rspamd_lua_regexp *re = lua_check_regexp(L, 2);
+       struct rspamd_scan_result *sres = NULL;
+       struct rspamd_symbol_result *s;
+       unsigned int count = 0;
+
+       if (!task || !re || !re->re) {
                return luaL_error(L, "invalid arguments");
        }
 
-       if (!found) {
+       if (!lua_task_resolve_shadow_result(L, task, 3, &sres)) {
+               return 0;
+       }
+
+       if (!sres) {
+               sres = task->result;
+       }
+
+       if (!sres) {
+               lua_pushnil(L);
+               return 1;
+       }
+
+       lua_createtable(L, 0, 4);
+
+       kh_foreach_value(sres->symbols, s, {
+               if (!(s->flags & RSPAMD_SYMBOL_RESULT_IGNORED) && s->name) {
+                       if (rspamd_regexp_match(re->re, s->name, strlen(s->name), FALSE)) {
+                               lua_task_symbol_push_into_map(L, task, s->name, s, sres, &count);
+                       }
+               }
+       });
+
+       if (count == 0) {
+               lua_pop(L, 1);
                lua_pushnil(L);
        }
 
        return 1;
 }
 
+static inline gboolean
+lua_task_check_single_symbol(struct rspamd_task *task,
+                                                        const char *symbol,
+                                                        struct rspamd_scan_result *sres)
+{
+       struct rspamd_symbol_result *s = rspamd_task_find_symbol_result(task, symbol, sres);
+
+       return (s != NULL && !(s->flags & RSPAMD_SYMBOL_RESULT_IGNORED));
+}
+
 static int
 lua_task_has_symbol(lua_State *L)
 {
        LUA_TRACE_POINT;
        struct rspamd_task *task = lua_check_task(L, 1);
-       struct rspamd_symbol_result *s;
-       const char *symbol;
+       struct rspamd_scan_result *sres = NULL;
        gboolean found = FALSE;
 
-       symbol = luaL_checkstring(L, 2);
+       if (!task) {
+               return luaL_error(L, "invalid arguments");
+       }
 
-       if (task && symbol) {
-               if (lua_isstring(L, 3)) {
-                       s = rspamd_task_find_symbol_result(task, symbol,
-                                                                                          rspamd_find_metric_result(task, lua_tostring(L, 3)));
+       /* Table form: has_symbol({names}, [shadow_result_name]) */
+       if (lua_istable(L, 2)) {
+               if (!lua_task_resolve_shadow_result(L, task, 3, &sres)) {
+                       return 0;
+               }
 
-                       if (s && !(s->flags & RSPAMD_SYMBOL_RESULT_IGNORED)) {
-                               found = TRUE;
+               lua_pushnil(L);
+               while (lua_next(L, 2) != 0) {
+                       if (lua_type(L, -1) == LUA_TSTRING) {
+                               if (lua_task_check_single_symbol(task, lua_tostring(L, -1), sres)) {
+                                       found = TRUE;
+                                       lua_pop(L, 2); /* value + key */
+                                       break;
+                               }
                        }
+                       lua_pop(L, 1);
                }
-               else {
-                       s = rspamd_task_find_symbol_result(task, symbol, NULL);
 
-                       if (s && !(s->flags & RSPAMD_SYMBOL_RESULT_IGNORED)) {
-                               found = TRUE;
-                       }
+               lua_pushboolean(L, found);
+               return 1;
+       }
+
+       /* Legacy single-name form: has_symbol(name [, shadow_result_name]) */
+       if (lua_type(L, 2) == LUA_TSTRING) {
+               const char *symbol = lua_tostring(L, 2);
+
+               if (!lua_task_resolve_shadow_result(L, task, 3, &sres)) {
+                       return 0;
                }
+
+               found = lua_task_check_single_symbol(task, symbol, sres);
                lua_pushboolean(L, found);
+               return 1;
        }
-       else {
+
+       return luaL_error(L, "invalid arguments");
+}
+
+static int
+lua_task_has_symbol_regexp(lua_State *L)
+{
+       LUA_TRACE_POINT;
+       struct rspamd_task *task = lua_check_task(L, 1);
+       struct rspamd_lua_regexp *re = lua_check_regexp(L, 2);
+       struct rspamd_scan_result *sres = NULL;
+       struct rspamd_symbol_result *s;
+       gboolean found = FALSE;
+
+       if (!task || !re || !re->re) {
                return luaL_error(L, "invalid arguments");
        }
 
+       if (!lua_task_resolve_shadow_result(L, task, 3, &sres)) {
+               return 0;
+       }
+
+       if (!sres) {
+               sres = task->result;
+       }
+
+       if (sres) {
+               kh_foreach_value(sres->symbols, s, {
+                       if (!(s->flags & RSPAMD_SYMBOL_RESULT_IGNORED) && s->name) {
+                               if (rspamd_regexp_match(re->re, s->name, strlen(s->name), FALSE)) {
+                                       found = TRUE;
+                                       break;
+                               }
+                       }
+               });
+       }
+
+       lua_pushboolean(L, found);
        return 1;
 }