+--- Build the closure that records a learned message in the Redis learn cache.
+-- The configuration is packed to msgpack once and reused by every call.
+-- @param redis_params not referenced in the visible body
+-- @param learn_script_id id of the Redis script passed to exec_redis_script
+-- @param conf configuration table, forwarded to the script as msgpack
+-- @return function(task, cache_id, class_name, class_id)
local function gen_cache_learn_functor(redis_params, learn_script_id, conf)
local packed_conf = ucl.to_format(conf, 'msgpack')
- return function(task, cache_id, class_name)
+ return function(task, cache_id, class_name, class_id)
local function learn_redis_cb(err, data)
lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
end
- -- Handle backward compatibility for boolean values
- local cache_class_name = class_name
- if type(class_name) == "boolean" then
- cache_class_name = class_name and "spam" or "ham"
- elseif class_name == true or class_name == "true" then
- cache_class_name = "spam"
- elseif class_name == false or class_name == "false" then
- cache_class_name = "ham"
- end
-
- lua_util.debugm(N, task, 'try to learn cache: %s as %s', cache_id, cache_class_name)
+    -- A caller using the previous three-argument form (boolean or class name
+    -- only) leaves class_id nil; tostring(nil) would then send the literal
+    -- string "nil" to the script. Derive the id for the binary classes
+    -- (spam = 1, ham = 0) and refuse to write anything else.
+    if class_id == nil then
+      if class_name == true or class_name == 'true'
+          or class_name == 'spam' or class_name == 'S' then
+        class_id = 1
+      elseif class_name == false or class_name == 'false'
+          or class_name == 'ham' or class_name == 'H' then
+        class_id = 0
+      else
+        lua_util.debugm(N, task, 'cannot learn cache %s: no class id for class %s',
+            cache_id, class_name)
+        return
+      end
+    end
+
+ lua_util.debugm(N, task, 'try to learn cache: %s as %s (id=%s)', cache_id, class_name, class_id)
lua_redis.exec_redis_script(learn_script_id,
{ task = task, is_write = true, key = cache_id },
learn_redis_cb,
- { cache_id, cache_class_name, packed_conf })
+ { cache_id, tostring(class_id), packed_conf })
end
end
-- Lua script to perform cache checking for bayes classification (multi-class)
-- This script accepts the following parameters:
-- key1 - cache id
--- key2 - class name (e.g. "spam", "ham", "transactional")
+-- key2 - class id as a decimal string, computed by the C side ("1" = spam, "0" = ham, any other class = numeric hash of its name)
-- key3 - configuration table in message pack
local cache_id = KEYS[1]
-local class_name = KEYS[2]
+local class_id = KEYS[2]
local conf = cmsgpack.unpack(KEYS[3])
--- Convert class names to numeric cache values for consistency
-local cache_value
-if class_name == "1" or class_name == "spam" or class_name == "S" then
- cache_value = "1" -- spam
-elseif class_name == "0" or class_name == "ham" or class_name == "H" then
- cache_value = "0" -- ham
-else
- -- For other classes, use a simple hash to get a consistent numeric value
- -- This ensures cache check can return a number while preserving class info
- local hash = 0
- for i = 1, #class_name do
- hash = hash + string.byte(class_name, i)
- end
- cache_value = tostring(2 + (hash % 1000)) -- Start from 2, avoid 0/1
-end
+-- Use class_id directly as cache value
+local cache_value = tostring(class_id)
cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
-- Try each prefix that is in Redis (as some other instance might have set it)
return (void *) ctx;
}
+/*
+ * Map a class name to the numeric id stored in the learn cache.
+ *
+ * "spam"/"S" map to 1 and "ham"/"H" map to 0; a NULL name also yields 0.
+ * Any other name is hashed with rspamd_cryptobox_fast_hash and folded to
+ * 31 bits, with the results 0 and 1 shifted out of the reserved range.
+ *
+ * The id is pushed to Lua with lua_pushinteger, converted with tostring()
+ * and later read back with lua_tointeger before being compared for
+ * equality. A full 64-bit hash does not survive that round trip: values
+ * above the signed range turn negative, and where Lua numbers are doubles
+ * (Lua 5.1/LuaJIT) large values lose precision. A 31-bit id fits every
+ * lua_Integer width and converts exactly.
+ */
+static uint64_t
+rspamd_stat_cache_get_class_id(const char *class_name)
+{
+	if (!class_name) {
+		return 0;
+	}
+
+	if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+		return 1;
+	}
+
+	if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+		return 0;
+	}
+
+	/* For other classes, use rspamd_cryptobox_fast_hash */
+	uint64_t hash = rspamd_cryptobox_fast_hash(class_name, strlen(class_name), 0);
+
+	/* Keep only the bits that round-trip through Lua unchanged */
+	hash &= 0x7fffffffULL;
+
+	/* Ensure we don't get 0 or 1 (reserved for ham/spam) */
+	if (hash < 2) {
+		hash += 2;
+	}
+
+	return hash;
+}
+
+
static int
rspamd_stat_cache_checked(lua_State *L)
{
if (res) {
auto val = lua_tointeger(L, 3);
- if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
- (val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
- /* Already learned */
- msg_info_task("<%s> has been already "
- "learned as %s, ignore it",
- MESSAGE_FIELD(task, message_id),
- (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
- task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
+ /*
+  * val is the integer found at Lua stack index 3; it is compared below
+  * with the id of the class that is about to be learned.
+  */
+ /* Get the class being learned */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (!autolearn_class) {
+ /* Fallback to binary flags for backward compatibility */
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) {
+ autolearn_class = "spam";
+ }
+ else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
+ autolearn_class = "ham";
+ }
}
- else {
- /* Unlearn flag */
- task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+
+ /* With no class name and no learn flag set, no task flag is changed */
+ if (autolearn_class) {
+ uint64_t expected_id = rspamd_stat_cache_get_class_id(autolearn_class);
+
+ if ((uint64_t) val == expected_id) {
+ /* Already learned */
+ msg_info_task("<%s> has been already "
+ "learned as %s, ignore it",
+ MESSAGE_FIELD(task, message_id),
+ autolearn_class);
+ task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
+ }
+ else {
+ /* Different class learned, unlearn flag */
+ /* rspamd's printf spells 64-bit integers %L / %uL, not libc's %ld / %lu */
+ msg_debug_task("<%s> cached value %L != expected %uL for class %s, will unlearn",
+ MESSAGE_FIELD(task, message_id),
+ (int64_t) val, expected_id, autolearn_class);
+ task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+ }
}
}
- /* Ignore errors for now, as we can do nothing about them at the moment */
-
return 0;
}
lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref);
rspamd_lua_task_push(L, task);
lua_pushstring(L, h);
- lua_pushboolean(L, is_spam);
- if (lua_pcall(L, 3, 0, err_idx) != 0) {
+ /* Get the class being learned - prefer multiclass over binary */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (!autolearn_class) {
+ /* Fallback to binary flag for backward compatibility */
+ autolearn_class = is_spam ? "spam" : "ham";
+ }
+
+ /* Push class name and class ID */
+ lua_pushstring(L, autolearn_class);
+ uint64_t class_id = rspamd_stat_cache_get_class_id(autolearn_class);
+ lua_pushinteger(L, class_id);
+
+ if (lua_pcall(L, 4, 0, err_idx) != 0) {
msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
lua_settop(L, err_idx - 1);
return RSPAMD_LEARN_IGNORE;