]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add symbols proxy for piecewise changes
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 14 Sep 2025 09:02:41 +0000 (10:02 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 14 Sep 2025 09:02:41 +0000 (10:02 +0100)
src/lua/lua_common.c
src/lua/lua_config.c

index f3622868027f4c74c3abc4a6c313c68e35182e59..7b14a85c2f8324c5d94b0a2dc2ccfd583669b3e6 100644 (file)
@@ -105,6 +105,7 @@ void rspamd_lua_new_class(lua_State *L,
        lua_createtable(L, 0, 3 + nmethods);
 
        if (!seen_index) {
+               /* Default __index = metatable only for plain classes without custom __index */
                lua_pushstring(L, "__index");
                lua_pushvalue(L, -2); /* pushes the metatable */
                lua_settable(L, -3);  /* metatable.__index = metatable */
index 215054b1ca5071dbd5959be9a6d0b53538f11b55..2954bd1390853fb1805139483b3010bc602dc9de 100644 (file)
@@ -1002,6 +1002,18 @@ LUA_FUNCTION_DEF(config, load_custom_tokenizers);
  */
 LUA_FUNCTION_DEF(config, unload_custom_tokenizers);
 
+/* Symbol proxy (for piecewise symbol updates via rspamd_config.__index) */
+LUA_FUNCTION_DEF(config, index);
+LUA_FUNCTION_DEF(symbol_proxy, index);
+LUA_FUNCTION_DEF(symbol_proxy, newindex);
+
+struct lua_symbol_proxy {
+       struct rspamd_config *cfg;
+       const char *name;
+};
+
+static const char *rspamd_config_symbol_proxy_classname = "rspamd{config_symbol_proxy}";
+
 static const struct luaL_reg configlib_m[] = {
        LUA_INTERFACE_DEF(config, get_module_opt),
        LUA_INTERFACE_DEF(config, get_mempool),
@@ -1092,6 +1104,7 @@ static const struct luaL_reg configlib_m[] = {
        LUA_INTERFACE_DEF(config, get_dns_timeout),
        LUA_INTERFACE_DEF(config, load_custom_tokenizers),
        LUA_INTERFACE_DEF(config, unload_custom_tokenizers),
+       {"__index", lua_config_index},
        {"__tostring", rspamd_lua_class_tostring},
        {"__newindex", lua_config_newindex},
        {NULL, NULL}};
@@ -1109,6 +1122,12 @@ static const struct luaL_reg monitoredlib_m[] = {
        {"__tostring", rspamd_lua_class_tostring},
        {NULL, NULL}};
 
+static const struct luaL_reg symbol_proxylib_m[] = {
+       {"__index", lua_symbol_proxy_index},
+       {"__newindex", lua_symbol_proxy_newindex},
+       {"__tostring", rspamd_lua_class_tostring},
+       {NULL, NULL}};
+
 static const uint64_t rspamd_lua_callback_magic = 0x32c118af1e3263c7ULL;
 
 struct rspamd_config *
@@ -3096,6 +3115,355 @@ lua_config_newindex(lua_State *L)
        return 0;
 }
 
+static int
+lua_config_index(lua_State *L)
+{
+       LUA_TRACE_POINT;
+       struct rspamd_config *cfg = lua_check_config(L, 1);
+       const char *key = NULL;
+
+       if (cfg == NULL) {
+               return luaL_error(L, "invalid arguments");
+       }
+
+       if (lua_type(L, 2) == LUA_TSTRING) {
+               key = lua_tostring(L, 2);
+
+               /* First, try to find a method/field on the metatable (preserve existing API) */
+               if (lua_getmetatable(L, 1)) {
+                       /* stack: obj, key, mt */
+                       lua_pushvalue(L, 2); /* push key */
+                       lua_rawget(L, -2);   /* mt[key] */
+
+                       if (!lua_isnil(L, -1)) {
+                               /* Found method/field */
+                               return 1;
+                       }
+
+                       /* Not found: pop nil and metatable */
+                       lua_pop(L, 2);
+               }
+
+               /* Return symbol proxy userdata for piecewise updates */
+               struct lua_symbol_proxy *proxy = lua_newuserdata(L, sizeof(*proxy));
+               proxy->cfg = cfg;
+               proxy->name = rspamd_mempool_strdup(cfg->cfg_pool, key);
+               rspamd_lua_setclass(L, rspamd_config_symbol_proxy_classname, -1);
+
+               return 1;
+       }
+
+       /* Non-string keys are not supported here; preserve legacy behaviour: nil */
+       lua_pushnil(L);
+       return 1;
+}
+
+static int
+lua_symbol_proxy_index(lua_State *L)
+{
+       LUA_TRACE_POINT;
+       struct lua_symbol_proxy *sp = (struct lua_symbol_proxy *) rspamd_lua_check_udata(L, 1,
+                                                                                                                                                                        rspamd_config_symbol_proxy_classname);
+       const char *field = NULL;
+
+       if (sp == NULL || sp->cfg == NULL) {
+               return luaL_error(L, "invalid symbol proxy");
+       }
+
+       if (lua_type(L, 2) != LUA_TSTRING) {
+               lua_pushnil(L);
+               return 1;
+       }
+
+       field = lua_tostring(L, 2);
+
+       if (g_ascii_strcasecmp(field, "name") == 0) {
+               lua_pushstring(L, sp->name);
+               return 1;
+       }
+
+       /* Try to extract basic metric data */
+       struct rspamd_symbol *sym = g_hash_table_lookup(sp->cfg->symbols, sp->name);
+
+       if (sym == NULL) {
+               lua_pushnil(L);
+               return 1;
+       }
+
+       if (g_ascii_strcasecmp(field, "score") == 0) {
+               lua_pushnumber(L, sym->score);
+               return 1;
+       }
+       else if (g_ascii_strcasecmp(field, "description") == 0) {
+               if (sym->description) {
+                       lua_pushstring(L, sym->description);
+               }
+               else {
+                       lua_pushnil(L);
+               }
+               return 1;
+       }
+       else if (g_ascii_strcasecmp(field, "group") == 0) {
+               if (sym->gr && sym->gr->name) {
+                       lua_pushstring(L, sym->gr->name);
+               }
+               else {
+                       lua_pushnil(L);
+               }
+               return 1;
+       }
+       else if (g_ascii_strcasecmp(field, "nshots") == 0) {
+               lua_pushinteger(L, sym->nshots);
+               return 1;
+       }
+
+       /* Unknown field */
+       lua_pushnil(L);
+       return 1;
+}
+
+static int
+lua_symbol_proxy_newindex(lua_State *L)
+{
+       LUA_TRACE_POINT;
+       struct lua_symbol_proxy *sp = (struct lua_symbol_proxy *) rspamd_lua_check_udata(L, 1,
+                                                                                                                                                                        rspamd_config_symbol_proxy_classname);
+       const char *field = NULL;
+
+       if (sp == NULL || sp->cfg == NULL) {
+               return luaL_error(L, "invalid symbol proxy");
+       }
+
+       if (lua_type(L, 2) != LUA_TSTRING) {
+               return luaL_error(L, "invalid field type: %s", lua_typename(L, lua_type(L, 2)));
+       }
+
+       field = lua_tostring(L, 2);
+
+       /* Special: condition */
+       if (g_ascii_strcasecmp(field, "condition") == 0) {
+               if (lua_type(L, 3) == LUA_TFUNCTION) {
+                       lua_pushvalue(L, 3);
+                       int condref = luaL_ref(L, LUA_REGISTRYINDEX);
+                       gboolean ret = rspamd_symcache_add_condition_delayed(sp->cfg->cache, sp->name, L, condref);
+
+                       if (!ret) {
+                               luaL_unref(L, LUA_REGISTRYINDEX, condref);
+                       }
+
+                       return 0;
+               }
+
+               /* Ignore non-function */
+               return 0;
+       }
+
+       /* Special: callback - set or create symbol */
+       if (g_ascii_strcasecmp(field, "callback") == 0) {
+               if (lua_type(L, 3) == LUA_TFUNCTION) {
+                       int id = rspamd_symcache_find_symbol(sp->cfg->cache, sp->name);
+
+                       if (id == -1) {
+                               /* Register new symbol with default params */
+                               lua_pushvalue(L, 3);
+                               int cbref = luaL_ref(L, LUA_REGISTRYINDEX);
+                               (void) rspamd_register_symbol_fromlua(L,
+                                                                                                         sp->cfg,
+                                                                                                         sp->name,
+                                                                                                         cbref,
+                                                                                                         1.0,
+                                                                                                         0,
+                                                                                                         SYMBOL_TYPE_NORMAL,
+                                                                                                         -1,
+                                                                                                         NULL, NULL,
+                                                                                                         FALSE);
+                       }
+                       else {
+                               /* Existing symbol: replace callback */
+                               struct rspamd_abstract_callback_data *abs_cbdata;
+                               struct lua_callback_data *cbd;
+
+                               abs_cbdata = rspamd_symcache_get_cbdata(sp->cfg->cache, sp->name);
+
+                               if (abs_cbdata != NULL && abs_cbdata->magic == rspamd_lua_callback_magic) {
+                                       cbd = (struct lua_callback_data *) abs_cbdata;
+                                       if (cbd->cb_is_ref) {
+                                               luaL_unref(L, LUA_REGISTRYINDEX, cbd->callback.ref);
+                                       }
+                                       else {
+                                               cbd->cb_is_ref = TRUE;
+                                       }
+                                       lua_pushvalue(L, 3);
+                                       cbd->callback.ref = luaL_ref(L, LUA_REGISTRYINDEX);
+                               }
+                       }
+
+                       return 0;
+               }
+
+               return luaL_error(L, "callback must be a function");
+       }
+
+       /* Metric fields: score, description, group, nshots, one_shot, one_param, flags, priority, groups */
+       struct rspamd_symbol *sym = g_hash_table_lookup(sp->cfg->symbols, sp->name);
+       const char *group = NULL;
+       const char *description = NULL;
+       unsigned int priority = 0;
+       unsigned int flags = 0;
+       int nshots = 0; /* 0 means keep default unless specified */
+       double score = NAN;
+
+       if (sym) {
+               group = sym->gr ? sym->gr->name : NULL;
+               description = sym->description;
+               priority = sym->priority;
+               flags = sym->flags;
+               nshots = sym->nshots;
+       }
+       else {
+               /* defaults */
+               group = NULL;
+               description = NULL;
+               priority = 0;
+               flags = 0;
+               nshots = sp->cfg->default_max_shots;
+       }
+
+       if (g_ascii_strcasecmp(field, "score") == 0) {
+               score = luaL_checknumber(L, 3);
+               rspamd_config_add_symbol(sp->cfg, sp->name, score,
+                                                                description, group, flags, priority, nshots);
+               return 0;
+       }
+       else if (g_ascii_strcasecmp(field, "description") == 0) {
+               description = luaL_checkstring(L, 3);
+               rspamd_config_add_symbol(sp->cfg, sp->name, score,
+                                                                description, group, flags, priority, nshots);
+               return 0;
+       }
+       else if (g_ascii_strcasecmp(field, "group") == 0) {
+               group = luaL_checkstring(L, 3);
+               rspamd_config_add_symbol(sp->cfg, sp->name, score,
+                                                                description, group, flags, priority, nshots);
+               return 0;
+       }
+       else if (g_ascii_strcasecmp(field, "nshots") == 0) {
+               nshots = luaL_checkinteger(L, 3);
+               rspamd_config_add_symbol(sp->cfg, sp->name, score,
+                                                                description, group, flags, priority, nshots);
+               return 0;
+       }
+       else if (g_ascii_strcasecmp(field, "one_shot") == 0) {
+               if (lua_toboolean(L, 3)) {
+                       nshots = 1;
+               }
+               else if (sym) {
+                       /* keep existing */
+                       nshots = sym->nshots;
+               }
+               rspamd_config_add_symbol(sp->cfg, sp->name, score,
+                                                                description, group, flags, priority, nshots);
+               return 0;
+       }
+       else if (g_ascii_strcasecmp(field, "one_param") == 0) {
+               if (lua_toboolean(L, 3)) {
+                       flags |= RSPAMD_SYMBOL_FLAG_ONEPARAM;
+               }
+               else {
+                       flags &= ~RSPAMD_SYMBOL_FLAG_ONEPARAM;
+               }
+               rspamd_config_add_symbol(sp->cfg, sp->name, score,
+                                                                description, group, flags, priority, nshots);
+               return 0;
+       }
+       else if (g_ascii_strcasecmp(field, "flags") == 0) {
+               /* Support a subset: ignore, one_param; one_shot handled via nshots */
+               if (lua_type(L, 3) == LUA_TSTRING) {
+                       const char *fls = lua_tostring(L, 3);
+                       if (strstr(fls, "ignore") != NULL) {
+                               flags |= RSPAMD_SYMBOL_FLAG_IGNORE_METRIC;
+                       }
+                       if (strstr(fls, "one_param") != NULL) {
+                               flags |= RSPAMD_SYMBOL_FLAG_ONEPARAM;
+                       }
+               }
+               else if (lua_type(L, 3) == LUA_TTABLE) {
+                       for (lua_pushnil(L); lua_next(L, 3); lua_pop(L, 1)) {
+                               if (lua_type(L, -1) == LUA_TSTRING) {
+                                       const char *fl = lua_tostring(L, -1);
+                                       if (g_ascii_strcasecmp(fl, "ignore") == 0) {
+                                               flags |= RSPAMD_SYMBOL_FLAG_IGNORE_METRIC;
+                                       }
+                                       else if (g_ascii_strcasecmp(fl, "one_param") == 0) {
+                                               flags |= RSPAMD_SYMBOL_FLAG_ONEPARAM;
+                                       }
+                               }
+                       }
+               }
+               rspamd_config_add_symbol(sp->cfg, sp->name, score,
+                                                                description, group, flags, priority, nshots);
+               return 0;
+       }
+       else if (g_ascii_strcasecmp(field, "priority") == 0) {
+               priority = luaL_checkinteger(L, 3);
+               rspamd_config_add_symbol(sp->cfg, sp->name, score,
+                                                                description, group, flags, priority, nshots);
+               return 0;
+       }
+       else if (g_ascii_strcasecmp(field, "groups") == 0) {
+               if (lua_type(L, 3) == LUA_TTABLE) {
+                       for (lua_pushnil(L); lua_next(L, 3); lua_pop(L, 1)) {
+                               if (lua_type(L, -1) == LUA_TSTRING) {
+                                       rspamd_config_add_symbol_group(sp->cfg, sp->name,
+                                                                                                  lua_tostring(L, -1));
+                               }
+                       }
+               }
+               return 0;
+       }
+       else if (g_ascii_strcasecmp(field, "augmentations") == 0) {
+               if (lua_type(L, 3) == LUA_TTABLE) {
+                       int id = rspamd_symcache_find_symbol(sp->cfg->cache, sp->name);
+                       if (id != -1) {
+                               for (lua_pushnil(L); lua_next(L, 3); lua_pop(L, 1)) {
+                                       if (lua_type(L, -1) == LUA_TSTRING) {
+                                               rspamd_symcache_add_symbol_augmentation(sp->cfg->cache, id,
+                                                                                                                               lua_tostring(L, -1), NULL);
+                                       }
+                               }
+                       }
+               }
+               return 0;
+       }
+       else if (g_ascii_strcasecmp(field, "allowed_ids") == 0 ||
+                        g_ascii_strcasecmp(field, "forbidden_ids") == 0) {
+               if (lua_type(L, 3) == LUA_TTABLE) {
+                       unsigned int len = rspamd_lua_table_size(L, 3);
+                       GArray *ids = g_array_sized_new(FALSE, FALSE, sizeof(uint32_t), len);
+                       for (lua_pushnil(L); lua_next(L, 3); lua_pop(L, 1)) {
+                               uint32_t v = lua_tointeger(L, -1);
+                               g_array_append_val(ids, v);
+                       }
+
+                       if (ids->len > 0) {
+                               if (g_ascii_strcasecmp(field, "allowed_ids") == 0) {
+                                       rspamd_symcache_set_allowed_settings_ids(sp->cfg->cache, sp->name,
+                                                                                                                        &g_array_index(ids, uint32_t, 0), ids->len);
+                               }
+                               else {
+                                       rspamd_symcache_set_forbidden_settings_ids(sp->cfg->cache, sp->name,
+                                                                                                                          &g_array_index(ids, uint32_t, 0), ids->len);
+                               }
+                       }
+
+                       g_array_free(ids, TRUE);
+               }
+               return 0;
+       }
+
+       /* Unknown field: ignore */
+       return 0;
+}
 static int
 lua_config_add_condition(lua_State *L)
 {
@@ -4940,6 +5308,10 @@ void luaopen_config(lua_State *L)
 
        lua_pop(L, 1);
 
+       /* Register symbol proxy class */
+       rspamd_lua_new_class(L, rspamd_config_symbol_proxy_classname, symbol_proxylib_m);
+       lua_pop(L, 1);
+
        /* Export constants */
        lua_pushinteger(L, RSPAMD_RE_CACHE_FLAG_LOADED);
        lua_setglobal(L, "RSPAMD_RE_CACHE_FLAG_LOADED");