]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Fix learn error propagation
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 22 Jan 2024 14:36:12 +0000 (14:36 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 22 Jan 2024 14:42:07 +0000 (14:42 +0000)
src/libstat/backends/redis_backend.cxx

index 4d94293608656cb7ca9640fce614da1ecf1cfff1..cff6baf8cd0931146484875f0d04c78194077317 100644 (file)
@@ -21,6 +21,8 @@
 #include "libserver/mempool_vars_internal.h"
 #include "fmt/core.h"
 
+#include "libutil/cxx/error.hxx"
+
 #include <string>
 #include <cstdint>
 #include <vector>
@@ -88,6 +90,7 @@ struct redis_stat_runtime {
        int id;
        std::vector<std::pair<int, T>> *results = nullptr;
        bool need_redis_call = true;
+       std::optional<rspamd::util::error> err;
 
        using result_type = std::vector<std::pair<int, T>>;
 
@@ -864,8 +867,10 @@ rspamd_redis_classified(lua_State *L)
        }
        else {
                /* Error message is on index 3 */
+               const auto *err_msg = lua_tostring(L, 3);
+               rt->err = rspamd::util::error(err_msg, 500);
                msg_err_task("cannot classify task: %s",
-                                        lua_tostring(L, 3));
+                                        err_msg);
        }
 
        return 0;
@@ -935,7 +940,9 @@ gboolean
 rspamd_redis_finalize_process(struct rspamd_task *task, gpointer runtime,
                                                          gpointer ctx)
 {
-       return TRUE;
+       auto *rt = REDIS_RUNTIME(runtime);
+
+       return !rt->err.has_value();
 }
 
 
@@ -959,8 +966,9 @@ rspamd_redis_learned(lua_State *L)
        }
        else {
                /* Error message is on index 3 */
-               msg_err_task("cannot learn task: %s",
-                                        lua_tostring(L, 3));
+               const auto *err_msg = lua_tostring(L, 3);
+               rt->err = rspamd::util::error(err_msg, 500);
+               msg_err_task("cannot learn task: %s", err_msg);
        }
 
        return 0;
@@ -1048,6 +1056,14 @@ gboolean
 rspamd_redis_finalize_learn(struct rspamd_task *task, gpointer runtime,
                                                        gpointer ctx, GError **err)
 {
+       auto *rt = REDIS_RUNTIME(runtime);
+
+       if (rt->err.has_value()) {
+               rt->err->into_g_error_set(rspamd_redis_stat_quark(), err);
+
+               return FALSE;
+       }
+
        return TRUE;
 }