]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Fix various issues
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 21 Jul 2025 09:07:27 +0000 (10:07 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 21 Jul 2025 09:07:27 +0000 (10:07 +0100)
lualib/lua_bayes_redis.lua
src/libstat/backends/redis_backend.cxx
src/libstat/classifiers/bayes.c

index 59952131ab0d9a0c7869090bb2eab04e8d397b62..c086669fed13e4bf26398de48d6b16e2c029d439 100644 (file)
@@ -31,24 +31,9 @@ local function gen_classify_functor(redis_params, classify_script_id)
       if err then
         callback(task, false, err)
       else
-        -- Handle both binary and multi-class results
-        if type(data[1]) == "table" then
-          -- Multi-class format: [learned_counts_table, outputs_table]
-          -- Convert to binary format for backward compatibility if needed
-          local learned_counts = data[1]
-          local outputs = data[2]
-
-          -- For now, return ham/spam data if available for backward compatibility
-          local learned_ham = learned_counts["H"] or learned_counts["ham"] or 0
-          local learned_spam = learned_counts["S"] or learned_counts["spam"] or 0
-          local output_ham = outputs["H"] or outputs["ham"] or {}
-          local output_spam = outputs["S"] or outputs["spam"] or {}
-
-          callback(task, true, learned_ham, learned_spam, output_ham, output_spam)
-        else
-          -- Binary format: [learned_ham, learned_spam, output_ham, output_spam]
-          callback(task, true, data[1], data[2], data[3], data[4])
-        end
+        -- Pass the raw data table to the C++ callback for processing
+        -- The C++ callback will handle both binary and multi-class formats
+        callback(task, true, data)
       end
     end
 
index 01ed818c47376bc4aac2127a626f33732094ab59..0fe738de50f9e852d08e31f6c0938e8870dd6955 100644 (file)
@@ -866,9 +866,20 @@ rspamd_redis_classified(lua_State *L)
        bool result = lua_toboolean(L, 2);
 
        if (result) {
-               /* Check if this is binary format [learned_ham, learned_spam, ham_tokens, spam_tokens]
-                * or multi-class format [learned_counts_table, outputs_table]
+               /* Check we have enough arguments and the result data is a table */
+               if (lua_gettop(L) < 3 || !lua_istable(L, 3)) {
+                       msg_err_task("internal error: expected table result from Redis script, got %s",
+                                                lua_typename(L, lua_type(L, 3)));
+                       rt->err = rspamd::util::error("invalid Redis script result format", 500);
+                       return 0;
+               }
+
+               /* Check the array length to determine format:
+                * - Length 4: binary format [learned_ham, learned_spam, ham_tokens, spam_tokens]
+                * - Length 2: multi-class format [learned_counts_table, outputs_table]
                 */
+               size_t result_len = rspamd_lua_table_size(L, 3);
+               msg_debug_bayes("Redis result array length: %uz", result_len);
 
                auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
                        rt->learned = learned;
@@ -891,10 +902,7 @@ rspamd_redis_classified(lua_State *L)
                        rt->set_results(res);
                };
 
-               /* Check if result[3] is a number (binary) or table (multi-class) */
-               lua_rawgeti(L, 3, 1); /* Get first element of result array */
-               bool is_binary_format = lua_isnumber(L, -1);
-               lua_pop(L, 1);
+               bool is_binary_format = (result_len == 4);
 
                if (is_binary_format) {
                        /* Binary format: [learned_ham, learned_spam, ham_tokens, spam_tokens] */
@@ -935,7 +943,7 @@ rspamd_redis_classified(lua_State *L)
                else {
                        /* Multi-class format: [learned_counts_table, outputs_table] */
 
-                       /* Get learned counts table (index 3) and outputs table (index 4) */
+                       /* Get learned counts table (index 1) and outputs table (index 2) */
                        lua_rawgeti(L, 3, 1); /* learned_counts */
                        lua_rawgeti(L, 3, 2); /* outputs */
 
@@ -984,10 +992,18 @@ rspamd_redis_classified(lua_State *L)
        }
        else {
                /* Error message is on index 3 */
-               const auto *err_msg = lua_tostring(L, 3);
-               rt->err = rspamd::util::error(err_msg, 500);
-               msg_err_task("cannot classify task: %s",
-                                        err_msg);
+               const char *err_msg = nullptr;
+               if (lua_gettop(L) >= 3 && lua_isstring(L, 3)) {
+                       err_msg = lua_tostring(L, 3);
+               }
+               if (err_msg) {
+                       rt->err = rspamd::util::error(err_msg, 500);
+                       msg_err_task("cannot classify task: %s", err_msg);
+               }
+               else {
+                       rt->err = rspamd::util::error("unknown Redis script error", 500);
+                       msg_err_task("cannot classify task: unknown Redis script error");
+               }
        }
 
        return 0;
@@ -1102,9 +1118,18 @@ rspamd_redis_learned(lua_State *L)
        }
        else {
                /* Error message is on index 3 */
-               const auto *err_msg = lua_tostring(L, 3);
-               rt->err = rspamd::util::error(err_msg, 500);
-               msg_err_task("cannot learn task: %s", err_msg);
+               const char *err_msg = nullptr;
+               if (lua_gettop(L) >= 3 && lua_isstring(L, 3)) {
+                       err_msg = lua_tostring(L, 3);
+               }
+               if (err_msg) {
+                       rt->err = rspamd::util::error(err_msg, 500);
+                       msg_err_task("cannot learn task: %s", err_msg);
+               }
+               else {
+                       rt->err = rspamd::util::error("unknown Redis script error", 500);
+                       msg_err_task("cannot learn task: unknown Redis script error");
+               }
        }
 
        return 0;
index 4a1b0cf32add147bc81b47d51588ce6c5b2d8610..836c27c88d42307eacebb141bb2d45458ffd8510 100644 (file)
@@ -1,11 +1,11 @@
-/*-
- * Copyright 2016 Vsevolod Stakhov
+/*
+ * Copyright 2025 Vsevolod Stakhov
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
- *   http://www.apache.org/licenses/LICENSE-2.0
+ *    http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -422,7 +422,7 @@ bayes_classify_multiclass(struct rspamd_classifier *ctx,
                for (i = 0; i < cl.num_classes; i++) {
                        if (cl.class_learns[i] < ctx->cfg->min_learns) {
                                msg_info_task("not classified as %s. The class needs more "
-                                                         "training samples. Currently: %ul; minimum %ud required",
+                                                         "training samples. Currently: %uL; minimum %ud required",
                                                          cl.class_names[i], cl.class_learns[i], ctx->cfg->min_learns);
                                return TRUE;
                        }
@@ -602,7 +602,7 @@ bayes_classify(struct rspamd_classifier *ctx,
                }
 
                if (has_class_names) {
-                       msg_debug_bayes("using multiclass classification with %u classes",
+                       msg_debug_bayes("using multiclass classification with %ud classes",
                                                        (unsigned int) ctx->cfg->class_names->len);
                        return bayes_classify_multiclass(ctx, tokens, task);
                }
@@ -617,14 +617,14 @@ bayes_classify(struct rspamd_classifier *ctx,
        if (ctx->cfg->min_learns > 0) {
                if (ctx->ham_learns < ctx->cfg->min_learns) {
                        msg_info_task("not classified as ham. The ham class needs more "
-                                                 "training samples. Currently: %ul; minimum %ud required",
+                                                 "training samples. Currently: %uL; minimum %ud required",
                                                  ctx->ham_learns, ctx->cfg->min_learns);
 
                        return TRUE;
                }
                if (ctx->spam_learns < ctx->cfg->min_learns) {
                        msg_info_task("not classified as spam. The spam class needs more "
-                                                 "training samples. Currently: %ul; minimum %ud required",
+                                                 "training samples. Currently: %uL; minimum %ud required",
                                                  ctx->spam_learns, ctx->cfg->min_learns);
 
                        return TRUE;
@@ -705,7 +705,7 @@ bayes_classify(struct rspamd_classifier *ctx,
                final_prob = (s + 1.0 - h) / 2.;
                msg_debug_bayes(
                        "got ham probability %.2f -> %.2f and spam probability %.2f -> %.2f,"
-                       " %L tokens processed of %ud total tokens;"
+                       " %uL tokens processed of %ud total tokens;"
                        " %uL text tokens found of %ud text tokens)",
                        cl.ham_prob,
                        h,