]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Rework stat runtime
authorVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 6 Dec 2023 14:46:45 +0000 (14:46 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 6 Dec 2023 14:46:45 +0000 (14:46 +0000)
src/libstat/backends/redis_backend.cxx

index 973e60671860f843f35d83e0c004638c8735f842..46b27cb155e5acb3fcc625849f519312aba8987b 100644 (file)
 #include "stat_internal.h"
 #include "upstream.h"
 #include "libserver/mempool_vars_internal.h"
+#include "fmt/core.h"
+
+#include <string>
+#include <cstdint>
+#include <vector>
 
 #define msg_debug_stat_redis(...) rspamd_conditional_debug_fast(nullptr, nullptr,                                                 \
                                                                                                                                rspamd_stat_redis_log_id, "stat_redis", task->task_pool->tag.uid, \
@@ -28,7 +33,7 @@
 INIT_LOG_MODULE(stat_redis)
 
 #define REDIS_CTX(p) (reinterpret_cast<struct redis_stat_ctx *>(p))
-#define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime *>(p))
+#define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime<float> *>(p))
 #define REDIS_DEFAULT_OBJECT "%s%l"
 #define REDIS_DEFAULT_USERS_OBJECT "%s%l%r"
 #define REDIS_DEFAULT_TIMEOUT 0.5
@@ -38,31 +43,68 @@ INIT_LOG_MODULE(stat_redis)
 struct redis_stat_ctx {
        lua_State *L;
        struct rspamd_statfile_config *stcf;
-       gint conf_ref;
        struct rspamd_stat_async_elt *stat_elt;
-       const char *redis_object;
-       gboolean enable_users;
-       gboolean store_tokens;
-       gboolean new_schema;
-       gboolean enable_signatures;
-       guint expiry;
-       guint max_users;
-       gint cbref_user;
-
-       gint cbref_classify;
-       gint cbref_learn;
+       const char *redis_object = REDIS_DEFAULT_OBJECT;
+       bool enable_users = false;
+       bool store_tokens = false;
+       bool enable_signatures = false;
+       unsigned expiry;
+       unsigned max_users = REDIS_MAX_USERS;
+       int cbref_user = -1;
+
+       int cbref_classify = -1;
+       int cbref_learn = -1;
+       int conf_ref = -1;
 };
 
 
+template<class T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>
 struct redis_stat_runtime {
        struct redis_stat_ctx *ctx;
        struct rspamd_task *task;
        struct rspamd_statfile_config *stcf;
        GPtrArray *tokens;
-       gchar *redis_object_expanded;
-       guint64 learned;
-       gint id;
-       GError *err;
+       const char *redis_object_expanded;
+       std::uint64_t learned = 0;
+       int id;
+       std::vector<std::pair<int, T>> *results = nullptr;
+
+       using result_type = std::vector<std::pair<int, T>>;
+
+       explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded)
+               : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded)
+       {
+       }
+
+       void init()
+       {
+       }
+
+       void set_results(std::vector<std::pair<int, T>> *_results)
+       {
+               results = _results;
+       }
+
+       ~redis_stat_runtime()
+       {
+               g_ptr_array_unref(tokens);
+               delete results;
+       }
+
+       /* Propagate results from internal representation to the tokens array */
+       auto process_tokens(GPtrArray *tokens) const -> bool
+       {
+               rspamd_token_t *tok;
+
+               if (!results) {
+                       return false;
+               }
+
+               for (auto [idx, val]: *results) {
+                       tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx);
+                       tok->values[id] = val;
+               }
+       }
 };
 
 /* Used to get statistics from redis */
@@ -217,14 +259,7 @@ gsize rspamd_redis_expand_object(const gchar *pattern,
                                /* Label miss is OK */
                                break;
                        case 's':
-                               if (ctx->new_schema) {
-                                       tlen += sizeof("RS") - 1;
-                               }
-                               else {
-                                       if (stcf->symbol) {
-                                               tlen += strlen(stcf->symbol);
-                                       }
-                               }
+                               tlen += sizeof("RS") - 1;
                                break;
                        default:
                                state = just_char;
@@ -306,14 +341,7 @@ gsize rspamd_redis_expand_object(const gchar *pattern,
                                }
                                break;
                        case 's':
-                               if (ctx->new_schema) {
-                                       d += rspamd_strlcpy(d, "RS", end - d);
-                               }
-                               else {
-                                       if (stcf->symbol) {
-                                               d += rspamd_strlcpy(d, stcf->symbol, end - d);
-                                       }
-                               }
+                               d += rspamd_strlcpy(d, "RS", end - d);
                                break;
                        default:
                                state = just_char;
@@ -1071,15 +1099,9 @@ rspamd_redis_async_stat_fin(struct rspamd_stat_async_elt *elt, gpointer d)
 static void
 rspamd_redis_fin(gpointer data)
 {
-       struct redis_stat_runtime *rt = REDIS_RUNTIME(data);
-
-       if (rt->err) {
-               g_error_free(rt->err);
-       }
+       auto *rt = REDIS_RUNTIME(data);
 
-       if (rt->tokens) {
-               g_ptr_array_unref(rt->tokens);
-       }
+       delete rt;
 }
 
 
@@ -1260,7 +1282,6 @@ rspamd_redis_runtime(struct rspamd_task *task,
                                         gboolean learn, gpointer c, gint _id)
 {
        struct redis_stat_ctx *ctx = REDIS_CTX(c);
-       struct redis_stat_runtime *rt;
        char *object_expanded = nullptr;
 
        g_assert(ctx != nullptr);
@@ -1275,16 +1296,18 @@ rspamd_redis_runtime(struct rspamd_task *task,
                return nullptr;
        }
 
-       /* Look for the cached results */
-
+       auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+       rspamd_mempool_add_destructor(task->task_pool, rspamd_redis_fin, rt);
 
-       rt = (struct redis_stat_runtime *) rspamd_mempool_alloc0(task->task_pool, sizeof(*rt));
-       rt->task = task;
-       rt->ctx = ctx;
-       rt->redis_object_expanded = object_expanded;
-       rt->stcf = stcf;
+       /* Look for the cached results */
+       if (!learn) {
+               auto var_name = fmt::format("{}_{}", object_expanded, stcf->is_spam ? "S" : "H");
+               auto *res = rspamd_mempool_steal_variable(task->task_pool, var_name.c_str());
 
-       rspamd_mempool_add_destructor(task->task_pool, rspamd_redis_fin, rt);
+               if (res) {
+                       rt->set_results(reinterpret_cast<redis_stat_runtime<float>::result_type *>(res));
+               }
+       }
 
        return rt;
 }
@@ -1348,9 +1371,9 @@ rspamd_redis_serialize_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize
 static gint
 rspamd_redis_classified(lua_State *L)
 {
-       const gchar *cookie = lua_tostring(L, lua_upvalueindex(1));
-       struct rspamd_task *task = lua_check_task(L, 1);
-       struct redis_stat_runtime *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
+       const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
+       auto *task = lua_check_task(L, 1);
+       auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
        /* TODO: write it */
 
        if (rt == nullptr) {
@@ -1374,8 +1397,8 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
                                                        GPtrArray *tokens,
                                                        gint id, gpointer p)
 {
-       struct redis_stat_runtime *rt = REDIS_RUNTIME(p);
-       lua_State *L = rt->ctx->L;
+       auto *rt = REDIS_RUNTIME(p);
+       auto *L = rt->ctx->L;
 
        if (rspamd_session_blocked(task->s)) {
                return FALSE;
@@ -1385,7 +1408,12 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
                return FALSE;
        }
 
-       /* TODO: check if we have tokens for that particular id for this class */
+       if (rt->results) {
+               /* No need to do anything, we have results ready */
+               rt->process_tokens(tokens);
+
+               return TRUE;
+       }
 
        gsize tokens_len;
        gchar *tokens_buf = rspamd_redis_serialize_tokens(task, tokens, &tokens_len);
@@ -1429,19 +1457,6 @@ gboolean
 rspamd_redis_finalize_process(struct rspamd_task *task, gpointer runtime,
                                                          gpointer ctx)
 {
-       struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
-
-       if (rt->err) {
-               msg_info_task("cannot retrieve stat tokens from Redis: %e", rt->err);
-               g_error_free(rt->err);
-               rt->err = nullptr;
-               rspamd_redis_fin(rt);
-
-               return FALSE;
-       }
-
-       rspamd_redis_fin(rt);
-
        return TRUE;
 }
 
@@ -1449,7 +1464,7 @@ gboolean
 rspamd_redis_learn_tokens(struct rspamd_task *task, GPtrArray *tokens,
                                                  gint id, gpointer p)
 {
-       struct redis_stat_runtime *rt = REDIS_RUNTIME(p);
+       auto *rt = REDIS_RUNTIME(p);
 
        /* TODO: write learn function */
 
@@ -1461,18 +1476,6 @@ gboolean
 rspamd_redis_finalize_learn(struct rspamd_task *task, gpointer runtime,
                                                        gpointer ctx, GError **err)
 {
-       struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
-
-       if (rt->err) {
-               g_propagate_error(err, rt->err);
-               rt->err = nullptr;
-               rspamd_redis_fin(rt);
-
-               return FALSE;
-       }
-
-       rspamd_redis_fin(rt);
-
        return TRUE;
 }
 
@@ -1480,7 +1483,7 @@ gulong
 rspamd_redis_total_learns(struct rspamd_task *task, gpointer runtime,
                                                  gpointer ctx)
 {
-       struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+       auto *rt = REDIS_RUNTIME(runtime);
 
        return rt->learned;
 }
@@ -1489,7 +1492,7 @@ gulong
 rspamd_redis_inc_learns(struct rspamd_task *task, gpointer runtime,
                                                gpointer ctx)
 {
-       struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+       auto *rt = REDIS_RUNTIME(runtime);
 
        /* XXX: may cause races */
        return rt->learned + 1;
@@ -1499,7 +1502,7 @@ gulong
 rspamd_redis_dec_learns(struct rspamd_task *task, gpointer runtime,
                                                gpointer ctx)
 {
-       struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+       auto *rt = REDIS_RUNTIME(runtime);
 
        /* XXX: may cause races */
        return rt->learned + 1;
@@ -1509,7 +1512,7 @@ gulong
 rspamd_redis_learns(struct rspamd_task *task, gpointer runtime,
                                        gpointer ctx)
 {
-       struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+       auto *rt = REDIS_RUNTIME(runtime);
 
        return rt->learned;
 }
@@ -1518,7 +1521,7 @@ ucl_object_t *
 rspamd_redis_get_stat(gpointer runtime,
                                          gpointer ctx)
 {
-       struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+       auto *rt = REDIS_RUNTIME(runtime);
        struct rspamd_redis_stat_elt *st;
        redisAsyncContext *redis;