git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Rework] Re-implement cache sorting
author Vsevolod Stakhov <vsevolod@rspamd.com>
Sun, 10 Apr 2022 10:09:51 +0000 (11:09 +0100)
committer Vsevolod Stakhov <vsevolod@rspamd.com>
Sun, 10 Apr 2022 10:09:51 +0000 (11:09 +0100)
src/libserver/cfg_file.h
src/libserver/symcache/symcache_impl.cxx
src/libserver/symcache/symcache_internal.hxx

index 7532639a7e718895fbd47657d5f0b7161490c439..18524af8dda6b3ebf450ff31c52b3f8533356b19 100644 (file)
@@ -146,7 +146,7 @@ struct rspamd_symbol {
        struct rspamd_symbols_group *gr; /* Main group */
        GPtrArray *groups; /* Other groups */
        guint flags;
-       struct rspamd_symcache_item *cache_item;
+       void *cache_item;
        gint nshots;
 };
 
index 11fd7b0e61d89ac9aa2ba0879b12fee5277d8441..aadd53b8f2d7af9888fd321ab50d9bfa6a17c9c9 100644 (file)
@@ -18,6 +18,8 @@
 #include "unix-std.h"
 #include "libutil/cxx/locked_file.hxx"
 
+#include <cmath>
+
 namespace rspamd::symcache {
 
 INIT_LOG_MODULE_PUBLIC(symcache)
@@ -113,13 +115,13 @@ auto symcache::init() -> bool
        std::stable_sort(std::begin(postfilters), std::end(postfilters), postfilters_cmp);
        std::stable_sort(std::begin(idempotent), std::end(idempotent), postfilters_cmp);
 
-       rspamd_symcache_resort(cache);
+       resort();
 
        /* Connect metric symbols with symcache symbols */
        if (cfg->symbols) {
                g_hash_table_foreach(cfg->symbols,
-                               rspamd_symcache_metric_connect_cb,
-                               this);
+                               symcache::metric_connect_cb,
+                               (void *)this);
        }
 
        return res;
@@ -316,6 +318,20 @@ bool symcache::save_items() const
        return ret;
 }
 
+auto symcache::metric_connect_cb(void *k, void *v, void *ud) -> void
+{
+       auto *cache = (symcache *)ud;
+       const auto *sym = (const char *)k;
+       auto *s = (struct rspamd_symbol *)v;
+       auto weight = *s->weight_ptr;
+       auto *item = cache->get_item_by_name_mut(sym, false);
+
+       if (item) {
+               item->st->weight = weight;
+               s->cache_item = (void *)item;
+       }
+}
+
 
 auto symcache::get_item_by_id(int id, bool resolve_parent) const -> const cache_item *
 {
@@ -394,6 +410,134 @@ auto symcache::add_dependency(int id_from, std::string_view to, int virtual_id_f
        }
 }
 
+auto symcache::resort() -> void
+{
+       auto ord = std::make_shared<order_generation>(filters.size(), cur_order_gen);
+
+       for (auto &it : filters) {
+               total_hits += it->st->total_hits;
+               it->order = 0;
+               ord->d.emplace_back(it);
+       }
+
+       enum class tsort_mask {
+               PERM,
+               TEMP
+       };
+
+       constexpr auto tsort_unmask = [](cache_item *it) -> auto {
+               return (it->order & ~((1u << 31) | (1u << 30)));
+       };
+
+       /* Recursive topological sort helper */
+       const auto tsort_visit = [&](cache_item *it, unsigned cur_order, auto &&rec) {
+               constexpr auto tsort_mark = [](cache_item *it, tsort_mask how) {
+                       switch (how) {
+                       case tsort_mask::PERM:
+                               it->order |= (1u << 31);
+                               break;
+                       case tsort_mask::TEMP:
+                               it->order |= (1u << 30);
+                               break;
+                       }
+               };
+               constexpr auto tsort_is_marked = [](cache_item *it, tsort_mask how) {
+                       switch (how) {
+                       case tsort_mask::PERM:
+                               return (it->order & (1u << 31));
+                       case tsort_mask::TEMP:
+                               return (it->order & (1u << 30));
+                       }
+               };
+
+               if (tsort_is_marked(it, tsort_mask::PERM)) {
+                       if (cur_order > tsort_unmask(it)) {
+                               /* Need to recalculate the whole chain */
+                               it->order = cur_order; /* That also removes all masking */
+                       }
+                       else {
+                               /* We are fine, stop DFS */
+                               return;
+                       }
+               }
+               else if (tsort_is_marked(it, tsort_mask::TEMP)) {
+                       msg_err_cache("cyclic dependencies found when checking '%s'!",
+                                       it->symbol.c_str());
+                       return;
+               }
+
+               tsort_mark(it, tsort_mask::TEMP);
+               msg_debug_cache("visiting node: %s (%d)", it->symbol.c_str(), cur_order);
+
+               for (const auto &dep : it->deps) {
+                       msg_debug_cache ("visiting dep: %s (%d)", dep.item->symbol.c_str(), cur_order + 1);
+                       rec(dep.item.get(), cur_order + 1, rec);
+               }
+
+               it->order = cur_order;
+               tsort_mark(it, tsort_mask::PERM);
+       };
+       /*
+        * Topological sort
+        */
+       total_hits = 0;
+
+       for (const auto &it : filters) {
+               if (it->order == 0) {
+                       tsort_visit(it.get(), 0, tsort_visit);
+               }
+       }
+
+
+       /* Main sorting comparator */
+       constexpr auto score_functor = [](auto w, auto f, auto t) -> auto {
+               auto time_alpha = 1.0, weight_alpha = 0.1, freq_alpha = 0.01;
+
+               return ((w > 0.0 ? w : weight_alpha) * (f > 0.0 ? f : freq_alpha) /
+                               (t > time_alpha ? t : time_alpha));
+       };
+
+       auto cache_order_cmp = [&](const auto &it1, const auto &it2) -> auto {
+               auto o1 = tsort_unmask(it1.get()), o2 = tsort_unmask(it2.get());
+               double w1 = 0., w2 = 0.;
+
+               if (o1 == o2) {
+                       /* No topological order */
+                       if (it1->priority == it2->priority) {
+                               auto avg_freq = ((double) total_hits / used_items);
+                               auto avg_weight = (total_weight / used_items);
+                               auto f1 = (double) it1->st->total_hits / avg_freq;
+                               auto f2 = (double) it2->st->total_hits / avg_freq;
+                               auto weight1 = std::fabs(it1->st->weight) / avg_weight;
+                               auto weight2 = std::fabs(it2->st->weight) / avg_weight;
+                               auto t1 = it1->st->avg_time;
+                               auto t2 = it2->st->avg_time;
+                               w1 = score_functor(weight1, f1, t1);
+                               w2 = score_functor(weight2, f2, t2);
+                       } else {
+                               /* Strict sorting */
+                               w1 = std::abs(it1->priority);
+                               w2 = std::abs(it2->priority);
+                       }
+               }
+               else {
+                       w1 = o1;
+                       w2 = o2;
+               }
+
+               if (w2 > w1) {
+                       return 1;
+               }
+               else if (w2 < w1) {
+                       return -1;
+               }
+
+               return 0;
+       };
+
+       std::stable_sort(std::begin(ord->d), std::end(ord->d), cache_order_cmp);
+       std::swap(ord, items_by_order);
+}
 
 
 auto cache_item::get_parent(const symcache &cache) const -> const cache_item *
index a2b852c1924a0395783978e05b4d3c06585f6c4e..7dd664e5ca95ca43c6f1cf096fa63056393af97e 100644 (file)
@@ -78,13 +78,16 @@ using cache_item_ptr = std::shared_ptr<cache_item>;
 using cache_item_weak_ptr = std::weak_ptr<cache_item>;
 
 struct order_generation { /* A list of cache items together with the id of the generation it belongs to */
-       std::vector<cache_item_weak_ptr> d;
+       std::vector<cache_item_ptr> d; /* The items, held through cache_item_ptr (std::shared_ptr<cache_item>) */
        unsigned int generation_id; /* Initialised from the `id` constructor argument */
+
+       explicit order_generation(std::size_t nelts, unsigned id) : generation_id(id) { /* Stores id; d starts empty */
+               d.reserve(nelts); /* Only reserves capacity for nelts items, no elements are created */
+       }
 };
 
 using order_generation_ptr = std::shared_ptr<order_generation>;
 
-
 class symcache;
 
 struct item_condition {
@@ -269,6 +272,9 @@ private:
        /* Internal methods */
        auto load_items() -> bool;
        auto save_items() const -> bool;
+       auto resort() -> void;
+       /* Helper for g_hash_table_foreach */
+       static auto metric_connect_cb(void *k, void *v, void *ud) -> void;
 
 public:
        explicit symcache(struct rspamd_config *cfg) : cfg(cfg) {