git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Support augmentations with values
author Vsevolod Stakhov <vsevolod@rspamd.com>
Fri, 12 Aug 2022 19:44:35 +0000 (20:44 +0100)
committer Vsevolod Stakhov <vsevolod@rspamd.com>
Fri, 12 Aug 2022 19:44:35 +0000 (20:44 +0100)
src/libserver/rspamd_symcache.h
src/libserver/symcache/symcache_c.cxx
src/libserver/symcache/symcache_item.cxx
src/libserver/symcache/symcache_item.hxx
src/lua/lua_config.c

index a7258143243087b6962d841dbad2ac0d71f83bde..ee3f4862a9c7bd96ea4e5fa5c33ea8d2c90ea6f9 100644 (file)
@@ -134,7 +134,9 @@ gint rspamd_symcache_add_symbol (struct rspamd_symcache *cache,
  * @return
  */
 bool rspamd_symcache_add_symbol_augmentation(struct rspamd_symcache *cache,
-               int sym_id, const char *augmentation);
+               int sym_id,
+               const char *augmentation,
+               const char *value);
 
 /**
  * Add callback to be executed whenever symbol has peak value
index 71e0057ee4585786a0e14463593e6076b4d0dba0..6ab1206c09f876dd409b453456857c6acc273fb9 100644 (file)
@@ -86,7 +86,9 @@ rspamd_symcache_add_symbol(struct rspamd_symcache *cache,
 
 bool
 rspamd_symcache_add_symbol_augmentation(struct rspamd_symcache *cache,
-                                                                                        int sym_id, const char *augmentation)
+                                                                               int sym_id,
+                                                                               const char *augmentation,
+                                                                               const char *value)
 {
        auto *real_cache = C_API_SYMCACHE(cache);
        auto log_tag = [&]() { return real_cache->log_tag(); };
@@ -104,7 +106,7 @@ rspamd_symcache_add_symbol_augmentation(struct rspamd_symcache *cache,
                return false;
        }
 
-       return item->add_augmentation(*real_cache, augmentation);
+       return item->add_augmentation(*real_cache, augmentation, value);
 }
 
 void
index 5f37de8a79c0ce34316c877c3e82ee87ca095d8a..c41d3d68c7156d6da2ca6f5f77daa7b629722ef5 100644 (file)
 
 namespace rspamd::symcache {
 
+enum class augmentation_value_type {
+       NO_VALUE,
+       STRING_VALUE,
+       NUMBER_VALUE,
+};
+
 struct augmentation_info {
        int weight = 0;
        int implied_flags = 0;
+       augmentation_value_type value_type = augmentation_value_type::NO_VALUE;
 };
 
 /* A list of internal augmentations that are known to Rspamd with their weight */
@@ -411,45 +418,85 @@ auto cache_item::is_allowed(struct rspamd_task *task, bool exec_only) const -> b
 }
 
 auto
-cache_item::add_augmentation(const symcache &cache, std::string_view augmentation) -> bool {
+cache_item::add_augmentation(const symcache &cache, std::string_view augmentation,
+                                                        std::optional<std::string_view> value) -> bool {
        auto log_tag = [&]() { return cache.log_tag(); };
 
        if (augmentations.contains(augmentation)) {
                msg_warn_cache("duplicate augmentation: %s", augmentation.data());
+
+               return false;
        }
 
-       augmentations.insert(std::string(augmentation));
+       auto maybe_known = rspamd::find_map(known_augmentations, augmentation);
 
-       auto ret = rspamd::find_map(known_augmentations, augmentation);
+       if (maybe_known.has_value()) {
+               auto &known_info = maybe_known.value().get();
 
-       msg_debug_cache("added %s augmentation %s for symbol %s",
-                       ret.has_value() ? "known" : "unknown", augmentation.data(), symbol.data());
+               if (known_info.implied_flags) {
+                       if ((known_info.implied_flags & flags) == 0) {
+                               msg_info_cache("added implied flags (%bd) for symbol %s as it has %s augmentation",
+                                               known_info.implied_flags, symbol.data(), augmentation.data());
+                               flags |= known_info.implied_flags;
+                       }
+               }
 
-       if (ret.has_value()) {
-               auto info = ret.value().get();
+               if (known_info.value_type == augmentation_value_type::NO_VALUE) {
+                       if (value.has_value()) {
+                               msg_err_cache("value specified for augmentation %s, that has no value",
+                                               augmentation.data());
 
-               if (info.implied_flags) {
-                       if ((info.implied_flags & flags) == 0) {
-                               msg_info_cache("added implied flags (%bd) for symbol %s as it has %s augmentation",
-                                               info.implied_flags, symbol.data(), augmentation.data());
-                               flags |= info.implied_flags;
+                               return false;
+                       }
+                       return augmentations.try_emplace(std::string{augmentation}, known_info.weight).second;
+               }
+               else {
+                       if (!value.has_value()) {
+                               msg_err_cache("value is not specified for augmentation %s, that requires explicit value",
+                                               augmentation.data());
+
+                               return false;
+                       }
+
+                       if (known_info.value_type == augmentation_value_type::STRING_VALUE) {
+                               return augmentations.try_emplace(std::string{augmentation}, std::string{value.value()},
+                                               known_info.weight).second;
+                       }
+                       else if (known_info.value_type == augmentation_value_type::NUMBER_VALUE) {
+                               /* I wish it was supported properly */
+                               //auto conv_res = std::from_chars(value->data(), value->size(), num);
+                               char numbuf[128], *endptr = nullptr;
+                               rspamd_strlcpy(numbuf, value->data(), MIN(value->size(), sizeof(numbuf)));
+                               auto num = g_ascii_strtod(numbuf, &endptr);
+
+                               if (fabs (num) >= G_MAXFLOAT || std::isnan(num)) {
+                                       msg_err_cache("value for augmentation %s is not numeric: %*s",
+                                                       augmentation.data(),
+                                                       (int)value->size(), value->data());
+                                       return false;
+                               }
+
+                               return augmentations.try_emplace(std::string{augmentation}, num,
+                                               known_info.weight).second;
                        }
                }
        }
+       else {
+               msg_debug_cache("added unknown augmentation %s for symbol %s",
+                               "unknown", augmentation.data(), symbol.data());
+               return augmentations.try_emplace(std::string{augmentation}, 0).second;
+       }
 
-       return ret.has_value();
+       // Should not be reached
+       return false;
 }
 
 auto
 cache_item::get_augmentation_weight() const -> int
 {
        return std::accumulate(std::begin(augmentations), std::end(augmentations),
-                                                 0, [](int acc, const std::string &augmentation) {
-                               auto default_augmentation_info = augmentation_info{};
-                               return acc + rspamd::find_map(known_augmentations, augmentation)
-                                               .value_or(default_augmentation_info)
-                                               .get()
-                                               .weight;
+                                                 0, [](int acc, const auto &map_pair) {
+                               return acc + map_pair.second.weight;
        });
 }
 
index 435a19abf2b5b103095063f1a0adc12ece67dbd5..31706058b50f4a208a95eb5870830173eacdd266 100644 (file)
@@ -174,8 +174,20 @@ public:
        }
 };
 
+/*
+ * Used to store augmentation values
+ */
+struct item_augmentation {
+       std::variant<std::monostate, std::string, double> value;
+       int weight;
+
+       explicit item_augmentation(int weight) : value(std::monostate{}), weight(weight) {}
+       explicit item_augmentation(std::string str_value, int weight) : value(str_value), weight(weight) {}
+       explicit item_augmentation(double double_value, int weight) : value(double_value), weight(weight) {}
+};
+
 struct cache_item : std::enable_shared_from_this<cache_item> {
-       /* This block is likely shared */
+       /* The following fields will live in shared memory */
        struct rspamd_symcache_item_stat *st = nullptr;
        struct rspamd_counter_data *cd = nullptr;
 
@@ -205,7 +217,8 @@ struct cache_item : std::enable_shared_from_this<cache_item> {
        id_list forbidden_ids;
 
        /* Set of augmentations */
-       ankerl::unordered_dense::set<std::string, rspamd::smart_str_hash, rspamd::smart_str_equal> augmentations;
+       ankerl::unordered_dense::map<std::string, item_augmentation,
+               rspamd::smart_str_hash, rspamd::smart_str_equal> augmentations;
 
        /* Dependencies */
        std::vector<cache_dependency> deps;
@@ -395,7 +408,8 @@ public:
         * @param augmentation
         * @return
         */
-       auto add_augmentation(const symcache &cache, std::string_view augmentation) -> bool;
+       auto add_augmentation(const symcache &cache, std::string_view augmentation,
+                                                 std::optional<std::string_view> value) -> bool;
 
        /**
         * Return sum weight of all known augmentations
index 333c8cdde12db448594217f426f2e436344a7fda..948aa165fa697340f96c188f7e30abb3ec9df33e 100644 (file)
@@ -2089,7 +2089,7 @@ lua_config_register_symbol (lua_State * L)
                                        const char *augmentation = lua_tostring(L, -1);
 
                                        if (!rspamd_symcache_add_symbol_augmentation(cfg->cache, ret,
-                                                       augmentation)) {
+                                                       augmentation, NULL)) {
                                                lua_settop(L, prev_top);
 
                                                return luaL_error (L, "unknown augmentation %s in symbol %s",
@@ -2739,7 +2739,7 @@ lua_config_newindex (lua_State *L)
                                        int tbl_idx = lua_gettop(L);
                                        for (lua_pushnil(L); lua_next(L, tbl_idx); lua_pop (L, 1)) {
                                                rspamd_symcache_add_symbol_augmentation(cfg->cache, id,
-                                                               lua_tostring(L, -1));
+                                                               lua_tostring(L, -1), NULL);
                                        }
                                }