local class_name = KEYS[2]
local conf = cmsgpack.unpack(KEYS[3])
--- Handle backward compatibility for binary values
-if class_name == "1" then
- class_name = "spam"
-elseif class_name == "0" then
- class_name = "ham"
+-- Convert class names to numeric cache values for consistency
+local cache_value
+if class_name == "1" or class_name == "spam" or class_name == "S" then
+ cache_value = "1" -- spam
+elseif class_name == "0" or class_name == "ham" or class_name == "H" then
+ cache_value = "0" -- ham
+else
+ -- For other classes, use a simple byte-sum hash to get a consistent numeric value
+ -- NOTE: lossy - names whose byte sums agree modulo 1000 (e.g. anagrams) collide
+ local hash = 0
+ for i = 1, #class_name do
+ hash = hash + string.byte(class_name, i)
+ end
+ cache_value = tostring(2 + (hash % 1000)) -- Start from 2, avoid 0/1
end
cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
local have = redis.call('HGET', prefix, cache_id)
if have then
- -- Already in cache, but class_name changes when relearning
- redis.call('HSET', prefix, cache_id, class_name)
+ -- Already in cache, but cache_value changes when relearning
+ redis.call('HSET', prefix, cache_id, cache_value)
return false
end
end
if count < lim then
-- We can add it to this prefix
- redis.call('HSET', prefix, cache_id, class_name)
+ redis.call('HSET', prefix, cache_id, cache_value)
added = true
end
end
if exists then
if not expired then
redis.call('DEL', prefix)
- redis.call('HSET', prefix, cache_id, class_name)
+ redis.call('HSET', prefix, cache_id, cache_value)
-- Do not expire anything else
expired = true
-- Determine if this is multi-class (table) or binary (string)
local class_labels = {}
-if type(class_labels_arg) == "table" then
- class_labels = class_labels_arg
+
+-- Check if this is a table serialized as "TABLE:label1,label2,..."
+if string.match(class_labels_arg, "^TABLE:") then
+ local labels_str = string.sub(class_labels_arg, 7) -- Remove "TABLE:" prefix
+ -- Split by comma
+ for label in string.gmatch(labels_str, "([^,]+)") do
+ table.insert(class_labels, label)
+ end
else
-- Binary compatibility: handle old boolean or single string format
if class_labels_arg == "true" then
- class_labels = { "S" } -- spam
+ class_labels = { "S" } -- spam
elseif class_labels_arg == "false" then
- class_labels = { "H" } -- ham
+ class_labels = { "H" } -- ham
else
class_labels = { class_labels_arg } -- single class label
end
return 0;
}
- if (rt->stcf->is_spam) {
- filler_func(rt, L, lua_tointeger(L, 4), 6);
- filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
+ /* Extract values from the result table at position 3 */
+ lua_rawgeti(L, 3, 1); /* learned_ham -> position 4 */
+ lua_rawgeti(L, 3, 2); /* learned_spam -> position 5 */
+ lua_rawgeti(L, 3, 3); /* ham_tokens -> position 6 */
+ lua_rawgeti(L, 3, 4); /* spam_tokens -> position 7 */
+
+ unsigned learned_ham = lua_tointeger(L, 4);
+ unsigned learned_spam = lua_tointeger(L, 5);
+
+ if (rt->stcf->is_spam || (rt->stcf->class_name && strcmp(get_class_label(rt->stcf), "S") == 0)) {
+ /* Current runtime is spam, use spam data */
+ filler_func(rt, L, learned_spam, 7); /* spam_tokens at position 7 */
+ filler_func(opposite_rt_maybe.value(), L, learned_ham, 6); /* ham_tokens at position 6 */
}
else {
- filler_func(rt, L, lua_tointeger(L, 3), 5);
- filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
+ /* Current runtime is ham, use ham data */
+ filler_func(rt, L, learned_ham, 6); /* ham_tokens at position 6 */
+ filler_func(opposite_rt_maybe.value(), L, learned_spam, 7); /* spam_tokens at position 7 */
}
+ /* Clean up the stack - pop the 4 extracted values */
+ lua_pop(L, 4);
+
/* Process all tokens */
g_assert(rt->tokens != nullptr);
rt->process_tokens(rt->tokens);
lua_pushinteger(L, id);
/* Send all class labels for multi-class support */
- if (rt->stcf->clcf && rt->stcf->clcf->class_labels && g_hash_table_size(rt->stcf->clcf->class_labels) > 0) {
+ if (rt->stcf->clcf && rt->stcf->clcf->class_labels &&
+ g_hash_table_size(rt->stcf->clcf->class_labels) > 0) {
/* Multi-class: send array of all class labels */
- lua_newtable(L);
+ lua_createtable(L, g_hash_table_size(rt->stcf->clcf->class_labels), 0);
GHashTableIter iter;
gpointer key, value;
int idx = 1;
}
}
else {
- /* Binary compatibility: send current class label as single string */
- lua_pushstring(L, get_class_label(rt->stcf));
+ /* Binary classification: send both spam and ham labels for optimization */
+ lua_createtable(L, 2, 0);
+ lua_pushstring(L, "H"); /* ham */
+ lua_rawseti(L, -2, 1);
+ lua_pushstring(L, "S"); /* spam */
+ lua_rawseti(L, -2, 2);
}
lua_new_text(L, tokens_buf, tokens_len, false);
bool result = lua_toboolean(L, 2);
if (result) {
- /* TODO: write it */
+ /* Learning successful - no complex data to process like in classification */
+ msg_debug_bayes("learned tokens successfully in Redis for symbol %s, class %s",
+ rt->stcf->symbol, get_class_label(rt->stcf));
+
+ /* Clear any previous error state */
+ rt->err = std::nullopt;
+
+ /* Learning operations don't return data structures to process,
+ * they just update Redis state. Success means the Redis script
+ * completed without errors. */
}
else {
/* Error message is on index 3 */