From: Vsevolod Stakhov Date: Sun, 17 Apr 2022 14:10:09 +0000 (+0100) Subject: [Project] Implement validation logic X-Git-Tag: 3.3~293^2~30 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=330cb456def3c55ecc1928c4d9b9b2b93c739c4d;p=thirdparty%2Frspamd.git [Project] Implement validation logic --- diff --git a/src/libserver/rspamd_symcache.h b/src/libserver/rspamd_symcache.h index 303544d7b1..915da9b155 100644 --- a/src/libserver/rspamd_symcache.h +++ b/src/libserver/rspamd_symcache.h @@ -171,15 +171,6 @@ gboolean rspamd_symcache_stat_symbol (struct rspamd_symcache *cache, gdouble *tm, guint *nhits); -/** - * Find symbol in cache by its id - * @param cache - * @param id - * @return symbol's name or NULL - */ -const gchar *rspamd_symcache_symbol_by_id (struct rspamd_symcache *cache, - gint id); - /** * Returns number of symbols registered in symbols cache * @param cache @@ -187,16 +178,6 @@ const gchar *rspamd_symcache_symbol_by_id (struct rspamd_symcache *cache, */ guint rspamd_symcache_stats_symbols_count (struct rspamd_symcache *cache); -/** - * Call function for cached symbol using saved callback - * @param task task object - * @param cache symbols cache - * @param saved_item pointer to currently saved item - */ -gboolean rspamd_symcache_process_symbols (struct rspamd_task *task, - struct rspamd_symcache *cache, - gint stage); - /** * Validate cache items against theirs weights defined in metrics * @param cache symbols cache @@ -207,6 +188,16 @@ gboolean rspamd_symcache_validate (struct rspamd_symcache *cache, struct rspamd_config *cfg, gboolean strict); +/** + * Call function for cached symbol using saved callback + * @param task task object + * @param cache symbols cache + * @param saved_item pointer to currently saved item + */ +gboolean rspamd_symcache_process_symbols (struct rspamd_task *task, + struct rspamd_symcache *cache, + gint stage); + /** * Return statistics about the cache as ucl object (array of objects one per item) * @param cache diff --git a/src/libserver/symcache/symcache_c.cxx b/src/libserver/symcache/symcache_c.cxx index 88a0d5605b..d081d7841f 100644 --- a/src/libserver/symcache/symcache_c.cxx +++ b/src/libserver/symcache/symcache_c.cxx @@ -24,23 +24,23 @@ #define C_API_SYMCACHE_ITEM(ptr) (reinterpret_cast(ptr)) void -rspamd_symcache_destroy (struct rspamd_symcache *cache) +rspamd_symcache_destroy(struct rspamd_symcache *cache) { auto *real_cache = C_API_SYMCACHE(cache); delete real_cache; } -struct rspamd_symcache* -rspamd_symcache_new (struct rspamd_config *cfg) +struct rspamd_symcache * +rspamd_symcache_new(struct rspamd_config *cfg) { auto *ncache = new rspamd::symcache::symcache(cfg); - return (struct rspamd_symcache*)ncache; + return (struct rspamd_symcache *) ncache; } gboolean -rspamd_symcache_init (struct rspamd_symcache *cache) +rspamd_symcache_init(struct rspamd_symcache *cache) { auto *real_cache = C_API_SYMCACHE(cache); @@ -48,7 +48,7 @@ rspamd_symcache_init (struct rspamd_symcache *cache) } void -rspamd_symcache_save (struct rspamd_symcache *cache) +rspamd_symcache_save(struct rspamd_symcache *cache) { auto *real_cache = C_API_SYMCACHE(cache); @@ -56,13 +56,13 @@ rspamd_symcache_save (struct rspamd_symcache *cache) } gint -rspamd_symcache_add_symbol (struct rspamd_symcache *cache, - const gchar *name, - gint priority, - symbol_func_t func, - gpointer user_data, - enum rspamd_symbol_type type, - gint parent) +rspamd_symcache_add_symbol(struct rspamd_symcache *cache, + const gchar *name, + gint priority, + symbol_func_t func, + gpointer user_data, + enum rspamd_symbol_type type, + gint parent) { auto *real_cache = C_API_SYMCACHE(cache); @@ -77,7 +77,7 @@ rspamd_symcache_add_symbol (struct rspamd_symcache *cache, } void -rspamd_symcache_set_peak_callback (struct rspamd_symcache *cache, gint cbref) +rspamd_symcache_set_peak_callback(struct rspamd_symcache *cache, gint cbref) { auto *real_cache = C_API_SYMCACHE(cache); @@ -85,8 +85,8 @@ rspamd_symcache_set_peak_callback (struct rspamd_symcache *cache, gint cbref) } gboolean -rspamd_symcache_add_condition_delayed (struct rspamd_symcache *cache, - const gchar *sym, lua_State *L, gint cbref) +rspamd_symcache_add_condition_delayed(struct rspamd_symcache *cache, + const gchar *sym, lua_State *L, gint cbref) { auto *real_cache = C_API_SYMCACHE(cache); @@ -95,8 +95,8 @@ rspamd_symcache_add_condition_delayed (struct rspamd_symcache *cache, return TRUE; } -gint rspamd_symcache_find_symbol (struct rspamd_symcache *cache, - const gchar *name) +gint rspamd_symcache_find_symbol(struct rspamd_symcache *cache, + const gchar *name) { auto *real_cache = C_API_SYMCACHE(cache); @@ -109,12 +109,12 @@ gint rspamd_symcache_find_symbol (struct rspamd_symcache *cache, return -1; } -gboolean rspamd_symcache_stat_symbol (struct rspamd_symcache *cache, - const gchar *name, - gdouble *frequency, - gdouble *freq_stddev, - gdouble *tm, - guint *nhits) +gboolean rspamd_symcache_stat_symbol(struct rspamd_symcache *cache, + const gchar *name, + gdouble *frequency, + gdouble *freq_stddev, + gdouble *tm, + guint *nhits) { auto *real_cache = C_API_SYMCACHE(cache); @@ -137,9 +137,25 @@ gboolean rspamd_symcache_stat_symbol (struct rspamd_symcache *cache, guint -rspamd_symcache_stats_symbols_count (struct rspamd_symcache *cache) +rspamd_symcache_stats_symbols_count(struct rspamd_symcache *cache) { auto *real_cache = C_API_SYMCACHE(cache); return real_cache->get_stats_symbols_count(); } +guint64 +rspamd_symcache_get_cksum(struct rspamd_symcache *cache) +{ + auto *real_cache = C_API_SYMCACHE(cache); + return real_cache->get_cksum(); +} + +gboolean +rspamd_symcache_validate(struct rspamd_symcache *cache, + struct rspamd_config *cfg, + gboolean strict) +{ + auto *real_cache = C_API_SYMCACHE(cache); + + return real_cache->validate(strict); +} \ No newline at end of file diff --git a/src/libserver/symcache/symcache_impl.cxx b/src/libserver/symcache/symcache_impl.cxx index 2123508e64..74d7db7248 100644 --- a/src/libserver/symcache/symcache_impl.cxx +++ b/src/libserver/symcache/symcache_impl.cxx @@ -36,15 +36,15 @@ auto symcache::init() -> bool } /* Deal with the delayed dependencies */ - for (const auto &delayed_dep : *delayed_deps) { + for (const auto &delayed_dep: *delayed_deps) { auto virt_item = get_item_by_name(delayed_dep.from, false); auto real_item = get_item_by_name(delayed_dep.from, true); if (virt_item == nullptr || real_item == nullptr) { msg_err_cache("cannot register delayed dependency between %s and %s: " - "%s is missing", - delayed_dep.from.data(), - delayed_dep.to.data(), delayed_dep.from.data()); + "%s is missing", + delayed_dep.from.data(), + delayed_dep.to.data(), delayed_dep.from.data()); } else { msg_debug_cache("delayed between %s(%d:%d) -> %s", @@ -61,7 +61,7 @@ auto symcache::init() -> bool /* Deal with the delayed conditions */ - for (const auto &delayed_cond : *delayed_conditions) { + for (const auto &delayed_cond: *delayed_conditions) { auto it = get_item_by_name_mut(delayed_cond.sym, true); if (it == nullptr) { @@ -81,17 +81,17 @@ auto symcache::init() -> bool } delayed_conditions.reset(); - for (auto &it : items_by_id) { + for (auto &it: items_by_id) { it->process_deps(*this); } - for (auto &it : virtual_symbols) { + for (auto &it: virtual_symbols) { it->process_deps(*this); } /* Sorting stuff */ auto postfilters_cmp = [](const auto &it1, const auto &it2) -> int { - if (it1->priority > it2-> priority) { + if (it1->priority > it2->priority) { return 1; } else if (it1->priority == it2->priority) { @@ -101,7 +101,7 @@ auto symcache::init() -> bool return -1; }; auto prefilters_cmp = [](const auto &it1, const auto &it2) -> int { - if (it1->priority > it2-> priority) { + if (it1->priority > it2->priority) { return -1; } else if (it1->priority == it2->priority) { @@ -123,7 +123,7 @@ auto symcache::init() -> bool if (cfg->symbols) { g_hash_table_foreach(cfg->symbols, symcache::metric_connect_cb, - (void *)this); + (void *) this); } return res; @@ -286,7 +286,7 @@ bool symcache::save_items() const auto *top = ucl_object_typed_new(UCL_OBJECT); - for (const auto &it : items_by_symbol) { + for (const auto &it: items_by_symbol) { auto item = it.second; auto elt = ucl_object_typed_new(UCL_OBJECT); ucl_object_insert_key(elt, @@ -322,15 +322,15 @@ bool symcache::save_items() const auto symcache::metric_connect_cb(void *k, void *v, void *ud) -> void { - auto *cache = (symcache *)ud; - const auto *sym = (const char *)k; - auto *s = (struct rspamd_symbol *)v; + auto *cache = (symcache *) ud; + const auto *sym = (const char *) k; + auto *s = (struct rspamd_symbol *) v; auto weight = *s->weight_ptr; auto *item = cache->get_item_by_name_mut(sym, false); if (item) { item->st->weight = weight; - s->cache_item = (void *)item; + s->cache_item = (void *) item; } } @@ -339,7 +339,7 @@ auto symcache::get_item_by_id(int id, bool resolve_parent) const -> const cache_ { if (id < 0 || id >= items_by_id.size()) { msg_err_cache("internal error: requested item with id %d, when we have just %d items in the cache", - id, (int)items_by_id.size()); + id, (int) items_by_id.size()); return nullptr; } @@ -388,9 +388,9 @@ auto symcache::get_item_by_name_mut(std::string_view name, bool resolve_parent) return it->second.get(); } -auto symcache::add_dependency(int id_from, std::string_view to, int virtual_id_from)-> void +auto symcache::add_dependency(int id_from, std::string_view to, int virtual_id_from) -> void { - g_assert (id_from >= 0 && id_from < (gint)items_by_id.size()); + g_assert (id_from >= 0 && id_from < (gint) items_by_id.size()); const auto &source = items_by_id[id_from]; g_assert (source.get() != nullptr); @@ -401,7 +401,7 @@ auto symcache::add_dependency(int id_from, std::string_view to, int virtual_id_f if (virtual_id_from >= 0) { - g_assert (virtual_id_from < (gint)virtual_symbols.size()); + g_assert (virtual_id_from < (gint) virtual_symbols.size()); /* We need that for settings id propagation */ const auto &vsource = virtual_symbols[virtual_id_from]; g_assert (vsource.get() != nullptr); @@ -416,7 +416,7 @@ auto symcache::resort() -> void { auto ord = std::make_shared(filters.size(), cur_order_gen); - for (auto &it : filters) { + for (auto &it: filters) { total_hits += it->st->total_hits; it->order = 0; ord->d.emplace_back(it); @@ -471,7 +471,7 @@ auto symcache::resort() -> void tsort_mark(it, tsort_mask::TEMP); msg_debug_cache("visiting node: %s (%d)", it->symbol.c_str(), cur_order); - for (const auto &dep : it->deps) { + for (const auto &dep: it->deps) { msg_debug_cache ("visiting dep: %s (%d)", dep.item->symbol.c_str(), cur_order + 1); rec(dep.item.get(), cur_order + 1, rec); } @@ -484,7 +484,7 @@ auto symcache::resort() -> void */ total_hits = 0; - for (const auto &it : filters) { + for (const auto &it: filters) { if (it->order == 0) { tsort_visit(it.get(), 0, tsort_visit); } @@ -516,7 +516,8 @@ auto symcache::resort() -> void auto t2 = it2->st->avg_time; w1 = score_functor(weight1, f1, t1); w2 = score_functor(weight2, f2, t2); - } else { + } + else { /* Strict sorting */ w1 = std::abs(it1->priority); w2 = std::abs(it2->priority); @@ -560,7 +561,7 @@ auto symcache::add_symbol_with_callback(std::string_view name, if (real_type_pair.first != symcache_item_type::FILTER) { real_type_pair.second |= SYMBOL_TYPE_NOSTAT; } - if (real_type_pair.second & (SYMBOL_TYPE_GHOST|SYMBOL_TYPE_CALLBACK)) { + if (real_type_pair.second & (SYMBOL_TYPE_GHOST | SYMBOL_TYPE_CALLBACK)) { real_type_pair.second |= SYMBOL_TYPE_NOSTAT; } @@ -572,7 +573,7 @@ auto symcache::add_symbol_with_callback(std::string_view name, std::string static_string_name; if (name.empty()) { - static_string_name = fmt::format("AUTO_{}", (void *)func); + static_string_name = fmt::format("AUTO_{}", (void *) func); } else { static_string_name = name; @@ -596,7 +597,7 @@ auto symcache::add_symbol_with_callback(std::string_view name, if (!(real_type_pair.second & SYMBOL_TYPE_NOSTAT)) { cksum = t1ha(name.data(), name.size(), cksum); - stats_symbols_count ++; + stats_symbols_count++; } return id; @@ -648,7 +649,116 @@ auto symcache::set_peak_cb(int cbref) -> void auto symcache::add_delayed_condition(std::string_view sym, int cbref) -> void { - delayed_conditions->emplace_back(sym, cbref, (lua_State *)cfg->lua_state); + delayed_conditions->emplace_back(sym, cbref, (lua_State *) cfg->lua_state); +} + +auto symcache::validate(bool strict) -> bool +{ + total_weight = 1.0; + + for (auto &pair: items_by_symbol) { + auto &item = pair.second; + auto ghost = item->st->weight == 0 ? true : false; + auto skipped = !ghost; + + if (item->is_scoreable() && g_hash_table_lookup(cfg->symbols, item->symbol.c_str()) == nullptr) { + if (!isnan(cfg->unknown_weight)) { + item->st->weight = cfg->unknown_weight; + auto *s = rspamd_mempool_alloc0_type(static_pool, + struct rspamd_symbol); + /* Legit as we actually never modify this data */ + s->name = (char *) item->symbol.c_str(); + s->weight_ptr = &item->st->weight; + g_hash_table_insert(cfg->symbols, (void *) s->name, (void *) s); + + msg_info_cache ("adding unknown symbol %s with weight: %.2f", + item->symbol.c_str(), cfg->unknown_weight); + ghost = false; + skipped = false; + } + else { + skipped = true; + } + } + else { + skipped = FALSE; + } + + if (!ghost && skipped) { + if (!(item->flags & SYMBOL_TYPE_SKIPPED)) { + item->flags |= SYMBOL_TYPE_SKIPPED; + msg_warn_cache("symbol %s has no score registered, skip its check", + item->symbol.c_str()); + } + } + + if (ghost) { + msg_debug_cache ("symbol %s is registered as ghost symbol, it won't be inserted " + "to any metric", item->symbol.c_str()); + } + + if (item->st->weight < 0 && item->priority == 0) { + item->priority++; + } + + if (item->is_virtual()) { + if (!(item->flags & SYMBOL_TYPE_GHOST)) { + item->resolve_parent(*this); + auto *parent = const_cast(item->get_parent(*this)); + + if (::fabs(parent->st->weight) < ::fabs(item->st->weight)) { + parent->st->weight = item->st->weight; + } + + auto p1 = ::abs(item->priority); + auto p2 = ::abs(parent->priority); + + if (p1 != p2) { + parent->priority = MAX(p1, p2); + item->priority = parent->priority; + } + } + } + + total_weight += fabs(item->st->weight); + } + + /* Now check each metric item and find corresponding symbol in a cache */ + auto ret = true; + GHashTableIter it; + void *k, *v; + g_hash_table_iter_init(&it, cfg->symbols); + + while (g_hash_table_iter_next(&it, &k, &v)) { + auto ignore_symbol = false; + auto sym_def = (struct rspamd_symbol *)v; + + if (sym_def && (sym_def->flags & + (RSPAMD_SYMBOL_FLAG_IGNORE_METRIC | RSPAMD_SYMBOL_FLAG_DISABLED))) { + ignore_symbol = true; + } + + if (!ignore_symbol) { + if (!items_by_symbol.contains((const char *)k)) { + msg_warn_cache ( + "symbol '%s' has its score defined but there is no " + "corresponding rule registered", + k); + if (strict) { + ret = FALSE; + } + } + } + else if (sym_def->flags & RSPAMD_SYMBOL_FLAG_DISABLED) { + auto item = get_item_by_name_mut((const char *)k, false); + + if (item) { + item->enabled = FALSE; + } + } + } + + return ret; } auto cache_item::get_parent(const symcache &cache) const -> const cache_item * @@ -665,9 +775,9 @@ auto cache_item::get_parent(const symcache &cache) const -> const cache_item * auto cache_item::process_deps(const symcache &cache) -> void { /* Allow logging macros to work */ - auto log_tag = [&](){ return cache.log_tag(); }; + auto log_tag = [&]() { return cache.log_tag(); }; - for (auto &dep : deps) { + for (auto &dep: deps) { msg_debug_cache ("process real dependency %s on %s", symbol.c_str(), dep.sym.c_str()); auto *dit = cache.get_item_by_name_mut(dep.sym, true); @@ -743,7 +853,8 @@ auto cache_item::process_deps(const symcache &cache) -> void msg_err_cache ("cannot add dependency on self: %s -> %s " "(resolved to %s)", symbol.c_str(), dep.sym.c_str(), dit->symbol.c_str()); - } else { + } + else { /* Create a reverse dep */ dit->rdeps.emplace_back(getptr(), dep.sym, id, -1); dep.item = dit->getptr(); @@ -764,12 +875,12 @@ auto cache_item::process_deps(const symcache &cache) -> void // Remove empty deps deps.erase(std::remove_if(std::begin(deps), std::end(deps), - [](const auto &dep){ return !dep.item; }), std::end(deps)); + [](const auto &dep) { return !dep.item; }), std::end(deps)); } auto cache_item::resolve_parent(const symcache &cache) -> bool { - auto log_tag = [&](){ return cache.log_tag(); }; + auto log_tag = [&]() { return cache.log_tag(); }; if (is_virtual()) { auto &virt = std::get(specific); @@ -807,7 +918,7 @@ auto virtual_item::resolve_parent(const symcache &cache) -> bool auto item_ptr = cache.get_item_by_id(parent_id, true); if (item_ptr) { - parent = const_cast(item_ptr)->getptr(); + parent = const_cast(item_ptr)->getptr(); return true; } @@ -817,17 +928,18 @@ auto virtual_item::resolve_parent(const symcache &cache) -> bool auto item_type_from_c(enum rspamd_symbol_type type) -> tl::expected, std::string> { - constexpr const auto trivial_types = SYMBOL_TYPE_CONNFILTER|SYMBOL_TYPE_PREFILTER - |SYMBOL_TYPE_POSTFILTER|SYMBOL_TYPE_IDEMPOTENT - |SYMBOL_TYPE_COMPOSITE|SYMBOL_TYPE_CLASSIFIER - |SYMBOL_TYPE_VIRTUAL; + constexpr const auto trivial_types = SYMBOL_TYPE_CONNFILTER | SYMBOL_TYPE_PREFILTER + | SYMBOL_TYPE_POSTFILTER | SYMBOL_TYPE_IDEMPOTENT + | SYMBOL_TYPE_COMPOSITE | SYMBOL_TYPE_CLASSIFIER + | SYMBOL_TYPE_VIRTUAL; constexpr auto all_but_one_ty = [&](int type, int exclude_bit) -> auto { return type & (trivial_types & ~exclude_bit); }; if (type & trivial_types) { - auto check_trivial = [&](auto flag, symcache_item_type ty) -> tl::expected, std::string> { + auto check_trivial = [&](auto flag, + symcache_item_type ty) -> tl::expected, std::string> { if (all_but_one_ty(type, flag)) { return tl::make_unexpected(fmt::format("invalid flags for a symbol: {}", type)); } @@ -866,7 +978,7 @@ auto item_type_from_c(enum rspamd_symbol_type type) -> tl::expected(specific) && (type == symcache_item_type::FILTER); } + auto is_scoreable() const -> bool { + return (type == symcache_item_type::FILTER) || + is_virtual() || + (type == symcache_item_type::COMPOSITE) || + (type == symcache_item_type::CLASSIFIER); + } auto is_ghost() const -> bool { return flags & SYMBOL_TYPE_GHOST; } @@ -519,6 +525,21 @@ public: auto get_stats_symbols_count() const { return stats_symbols_count; } + + /** + * Returns a checksum for the cache + * @return + */ + auto get_cksum() const { + return cksum; + } + + /** + * Validate symbols in the cache + * @param strict + * @return + */ + auto validate(bool strict) -> bool; }; /*