if err then
callback(task, false, err)
else
- -- Handle both binary and multi-class results
- if type(data[1]) == "table" then
- -- Multi-class format: [learned_counts_table, outputs_table]
- -- Convert to binary format for backward compatibility if needed
- local learned_counts = data[1]
- local outputs = data[2]
-
- -- For now, return ham/spam data if available for backward compatibility
- local learned_ham = learned_counts["H"] or learned_counts["ham"] or 0
- local learned_spam = learned_counts["S"] or learned_counts["spam"] or 0
- local output_ham = outputs["H"] or outputs["ham"] or {}
- local output_spam = outputs["S"] or outputs["spam"] or {}
-
- callback(task, true, learned_ham, learned_spam, output_ham, output_spam)
- else
- -- Binary format: [learned_ham, learned_spam, output_ham, output_spam]
- callback(task, true, data[1], data[2], data[3], data[4])
- end
+ -- Pass the raw data table to the C++ callback for processing
+ -- The C++ callback will handle both binary and multi-class formats
+ callback(task, true, data)
end
end
bool result = lua_toboolean(L, 2);
if (result) {
- /* Check if this is binary format [learned_ham, learned_spam, ham_tokens, spam_tokens]
- * or multi-class format [learned_counts_table, outputs_table]
+ /* Check we have enough arguments and the result data is a table */
+ if (lua_gettop(L) < 3 || !lua_istable(L, 3)) {
+ msg_err_task("internal error: expected table result from Redis script, got %s",
+ lua_typename(L, lua_type(L, 3)));
+ rt->err = rspamd::util::error("invalid Redis script result format", 500);
+ return 0;
+ }
+
+ /* Check the array length to determine format:
+ * - Length 4: binary format [learned_ham, learned_spam, ham_tokens, spam_tokens]
+ * - Length 2: multi-class format [learned_counts_table, outputs_table]
*/
+ size_t result_len = rspamd_lua_table_size(L, 3);
+ msg_debug_bayes("Redis result array length: %uz", result_len);
auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
rt->learned = learned;
rt->set_results(res);
};
- /* Check if result[3] is a number (binary) or table (multi-class) */
- lua_rawgeti(L, 3, 1); /* Get first element of result array */
- bool is_binary_format = lua_isnumber(L, -1);
- lua_pop(L, 1);
+ bool is_binary_format = (result_len == 4);
if (is_binary_format) {
/* Binary format: [learned_ham, learned_spam, ham_tokens, spam_tokens] */
else {
/* Multi-class format: [learned_counts_table, outputs_table] */
- /* Get learned counts table (index 3) and outputs table (index 4) */
+ /* Get learned counts table (index 1) and outputs table (index 2) */
lua_rawgeti(L, 3, 1); /* learned_counts */
lua_rawgeti(L, 3, 2); /* outputs */
}
else {
/* Error message is on index 3 */
- const auto *err_msg = lua_tostring(L, 3);
- rt->err = rspamd::util::error(err_msg, 500);
- msg_err_task("cannot classify task: %s",
- err_msg);
+ const char *err_msg = nullptr;
+ if (lua_gettop(L) >= 3 && lua_isstring(L, 3)) {
+ err_msg = lua_tostring(L, 3);
+ }
+ if (err_msg) {
+ rt->err = rspamd::util::error(err_msg, 500);
+ msg_err_task("cannot classify task: %s", err_msg);
+ }
+ else {
+ rt->err = rspamd::util::error("unknown Redis script error", 500);
+ msg_err_task("cannot classify task: unknown Redis script error");
+ }
}
return 0;
}
else {
/* Error message is on index 3 */
- const auto *err_msg = lua_tostring(L, 3);
- rt->err = rspamd::util::error(err_msg, 500);
- msg_err_task("cannot learn task: %s", err_msg);
+ const char *err_msg = nullptr;
+ if (lua_gettop(L) >= 3 && lua_isstring(L, 3)) {
+ err_msg = lua_tostring(L, 3);
+ }
+ if (err_msg) {
+ rt->err = rspamd::util::error(err_msg, 500);
+ msg_err_task("cannot learn task: %s", err_msg);
+ }
+ else {
+ rt->err = rspamd::util::error("unknown Redis script error", 500);
+ msg_err_task("cannot learn task: unknown Redis script error");
+ }
}
return 0;
-/*-
- * Copyright 2016 Vsevolod Stakhov
+/*
+ * Copyright 2025 Vsevolod Stakhov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
for (i = 0; i < cl.num_classes; i++) {
if (cl.class_learns[i] < ctx->cfg->min_learns) {
msg_info_task("not classified as %s. The class needs more "
- "training samples. Currently: %ul; minimum %ud required",
+ "training samples. Currently: %uL; minimum %ud required",
cl.class_names[i], cl.class_learns[i], ctx->cfg->min_learns);
return TRUE;
}
}
if (has_class_names) {
- msg_debug_bayes("using multiclass classification with %u classes",
+ msg_debug_bayes("using multiclass classification with %ud classes",
(unsigned int) ctx->cfg->class_names->len);
return bayes_classify_multiclass(ctx, tokens, task);
}
if (ctx->cfg->min_learns > 0) {
if (ctx->ham_learns < ctx->cfg->min_learns) {
msg_info_task("not classified as ham. The ham class needs more "
- "training samples. Currently: %ul; minimum %ud required",
+ "training samples. Currently: %uL; minimum %ud required",
ctx->ham_learns, ctx->cfg->min_learns);
return TRUE;
}
if (ctx->spam_learns < ctx->cfg->min_learns) {
msg_info_task("not classified as spam. The spam class needs more "
- "training samples. Currently: %ul; minimum %ud required",
+ "training samples. Currently: %uL; minimum %ud required",
ctx->spam_learns, ctx->cfg->min_learns);
return TRUE;
final_prob = (s + 1.0 - h) / 2.;
msg_debug_bayes(
"got ham probability %.2f -> %.2f and spam probability %.2f -> %.2f,"
- " %L tokens processed of %ud total tokens;"
+ " %uL tokens processed of %ud total tokens;"
" %uL text tokens found of %ud text tokens)",
cl.ham_prob,
h,