git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
Implement redis classification
author Vsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 7 Jan 2016 18:19:53 +0000 (18:19 +0000)
committer Vsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 7 Jan 2016 18:19:53 +0000 (18:19 +0000)
src/libstat/backends/redis_backend.c
src/libstat/stat_internal.h
src/libstat/stat_process.c

index e60ed3826e3cfb52cd7eaf3e5836a070adfa935d..7a3d59e2e3de35dc722f0c66562f0d6f63f0d94d 100644 (file)
@@ -46,6 +46,12 @@ struct redis_stat_ctx {
        gdouble timeout;
 };
 
+enum rspamd_redis_connection_state { /* Connection status stored in redis_stat_runtime.conn_state */
+       RSPAMD_REDIS_DISCONNECTED = 0, /* Value assigned when rspamd_redis_runtime () sets up the runtime */
+       RSPAMD_REDIS_CONNECTED, /* Set in the success path of rspamd_redis_connected () */
+       RSPAMD_REDIS_TIMEDOUT /* Set by rspamd_redis_timeout () before it removes the session event */
+};
+
 struct redis_stat_runtime {
        struct redis_stat_ctx *ctx;
        struct rspamd_task *task;
@@ -55,7 +61,8 @@ struct redis_stat_runtime {
        gchar *redis_object_expanded;
        redisAsyncContext *redis;
        guint64 learned;
-       gboolean connected;
+       gint id;
+       enum rspamd_redis_connection_state conn_state;
 };
 
 #define GET_TASK_ELT(task, elt) (task == NULL ? NULL : (task)->elt)
@@ -269,13 +276,48 @@ rspamd_redis_expand_object (const gchar *pattern,
        return tlen;
 }
 
+/*
+ * Serialise a command of the form `arg0 arg1 <token> <token> ...` into a
+ * single preformatted buffer (the visible caller hands it to
+ * redisAsyncFormattedCommand () with arg0 = "HMGET").
+ *
+ * task   - its task_pool gets a destructor that frees the returned buffer
+ * tokens - array of rspamd_token_t pointers; must not be NULL (asserted)
+ * arg0   - first argument of the command
+ * arg1   - second argument of the command
+ *
+ * Returns the buffer; it is already registered for destruction with the
+ * task pool, so the caller does not free it.
+ */
+static rspamd_fstring_t *
+rspamd_redis_tokens_to_query (struct rspamd_task *task, GPtrArray *tokens,
+               const gchar *arg0, const gchar *arg1)
+{
+       rspamd_fstring_t *query;
+       rspamd_token_t *cur;
+       gchar digits[64];
+       guint idx = 0, cmd_len, key_len, digits_len;
+       guint64 hash;
+
+       g_assert (tokens != NULL);
+
+       cmd_len = strlen (arg0);
+       key_len = strlen (arg1);
+
+       /* Header: total argument count, then arg0 and arg1 with their lengths */
+       query = rspamd_fstring_sized_new (1024);
+       rspamd_printf_fstring (&query, "*%d\r\n$%d\r\n%s\r\n$%d\r\n%s\r\n",
+                       tokens->len + 2, cmd_len, arg0, key_len, arg1);
+
+       /* One length-prefixed argument per token */
+       while (idx < tokens->len) {
+               cur = g_ptr_array_index (tokens, idx);
+               /* Leading sizeof (guint64) bytes of the token data, rendered as text */
+               memcpy (&hash, cur->data, sizeof (hash));
+               digits_len = rspamd_snprintf (digits, sizeof (digits), "%uL", hash);
+               rspamd_printf_fstring (&query, "$%d\r\n%s\r\n", digits_len, digits);
+               idx ++;
+       }
+
+       /* Registered only now: appending above may have replaced the pointer */
+       rspamd_mempool_add_destructor (task->task_pool,
+                       (rspamd_mempool_destruct_t)rspamd_fstring_free, query);
+
+       return query;
+}
+
 /* Called on connection termination */
 static void
 rspamd_redis_fin (gpointer data)
 {
        struct redis_stat_runtime *rt = REDIS_RUNTIME (data);
 
-       redisAsyncFree (rt->redis);
+       if (rt->conn_state != RSPAMD_REDIS_CONNECTED) {
+               redisAsyncFree (rt->redis);
+       }
        event_del (&rt->timeout_event);
 }
 
@@ -290,6 +332,7 @@ rspamd_redis_timeout (gint fd, short what, gpointer d)
        msg_err_task ("connection to redis server %s timed out",
                        rspamd_upstream_name (rt->selected));
        rspamd_upstream_fail (rt->selected);
+       rt->conn_state = RSPAMD_REDIS_TIMEDOUT;
        rspamd_session_remove_event (task->s, rspamd_redis_fin, d);
 }
 
@@ -312,10 +355,68 @@ rspamd_redis_connected (redisAsyncContext *c, gpointer r, gpointer priv)
                                rt->learned = 0;
                        }
 
-                       rt->connected = TRUE;
+                       rt->conn_state = RSPAMD_REDIS_CONNECTED;
 
                        msg_debug_task ("connected to redis server, tokens learned for %s: %d",
                                        rt->redis_object_expanded, rt->learned);
+                       rspamd_upstream_ok (rt->selected);
+                       rspamd_session_remove_event (task->s, rspamd_redis_fin, rt);
+               }
+               else {
+                       msg_err_task ("error getting reply from redis server %s: %s",
+                                       rspamd_upstream_name (rt->selected), c->errstr);
+                       rspamd_upstream_fail (rt->selected);
+                       rspamd_session_remove_event (task->s, rspamd_redis_fin, rt);
+               }
+       }
+       else {
+               msg_err_task ("error getting reply from redis server %s: %s",
+                               rspamd_upstream_name (rt->selected), c->errstr);
+               rspamd_upstream_fail (rt->selected);
+               rspamd_session_remove_event (task->s, rspamd_redis_fin, rt);
+       }
+}
+
+/* Called when we have received tokens values from redis */
+static void
+rspamd_redis_processed (redisAsyncContext *c, gpointer r, gpointer priv)
+{
+       struct redis_stat_runtime *rt = REDIS_RUNTIME (priv);
+       redisReply *reply = r, *elt;
+       struct rspamd_task *task;
+       rspamd_token_t *tok;
+       guint i, processed = 0, found = 0;
+
+       task = rt->task;
+
+       if (c->err == 0) {
+               if (r != NULL) {
+                       if (reply->type == REDIS_REPLY_ARRAY) {
+
+                               if (reply->elements == task->tokens->len) {
+                                       for (i = 0; i < reply->elements; i ++) {
+                                               elt = reply->element[i];
+
+                                               if (elt->type == REDIS_REPLY_INTEGER) {
+                                                       tok = g_ptr_array_index (task->tokens, i);
+                                                       tok->values[rt->id] = elt->integer;
+                                                       found ++;
+                                               }
+                                               else {
+                                                       tok->values[rt->id] = 0;
+                                               }
+
+                                               processed ++;
+                                       }
+                               }
+                       }
+                       else {
+                       }
+
+                       msg_debug_task ("received tokens for %s: %d processed, %d found",
+                                       rt->redis_object_expanded, processed, found);
+                       rspamd_upstream_ok (rt->selected);
+                       rspamd_session_remove_event (task->s, rspamd_redis_fin, rt);
                }
                else {
                        msg_err_task ("error getting reply from redis server %s: %s",
@@ -446,6 +547,8 @@ rspamd_redis_runtime (struct rspamd_task *task,
                        &rt->redis_object_expanded);
        rt->selected = up;
        rt->task = task;
+       rt->ctx = ctx;
+       rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
 
        addr = rspamd_upstream_addr (up);
        g_assert (addr != NULL);
@@ -462,6 +565,7 @@ rspamd_redis_runtime (struct rspamd_task *task,
        event_base_set (task->ev_base, &rt->timeout_event);
        double_to_tv (ctx->timeout, &tv);
        event_add (&rt->timeout_event, &tv);
+
        redisAsyncCommand (rt->redis, rspamd_redis_connected, rt, "HGET %s %s",
                        rt->redis_object_expanded, "learned");
 
@@ -490,6 +594,35 @@ rspamd_redis_process_tokens (struct rspamd_task *task,
                gint id, gpointer p)
 {
        struct redis_stat_runtime *rt = REDIS_RUNTIME (p);
+       rspamd_fstring_t *query;
+       struct timeval tv;
+       gint ret;
+
+       if (tokens == NULL || tokens->len == 0 || rt->redis == NULL) {
+               return FALSE;
+       }
+
+       rt->id = id;
+       query = rspamd_redis_tokens_to_query (task, tokens,
+                       "HMGET", rt->redis_object_expanded);
+       g_assert (query != NULL);
+
+       ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_processed, rt,
+                       query->str, query->len);
+       if (ret == REDIS_OK) {
+               rspamd_session_add_event (task->s, rspamd_redis_fin, rt,
+                               rspamd_redis_stat_quark ());
+               /* Reset timeout */
+               event_del (&rt->timeout_event);
+               double_to_tv (rt->ctx->timeout, &tv);
+               event_add (&rt->timeout_event, &tv);
+
+               return TRUE;
+       }
+       else {
+               msg_err_task ("call to redis failed: %s", rt->redis->errstr);
+               g_assert (0);
+       }
 
        return FALSE;
 }
index 31257938d629e8b2a9a0933dde72d605b5550e6b..787323fbc02244396aca69365120ef685d2cb7b1 100644 (file)
@@ -57,7 +57,7 @@ struct rspamd_statfile {
        gpointer bkcf;
 };
 
-#define RSPAMD_MAX_TOKEN_LEN 16
+#define RSPAMD_MAX_TOKEN_LEN 8
 typedef struct token_node_s {
        guchar data[RSPAMD_MAX_TOKEN_LEN];
        guint window_idx;
index 9eeaeaf10528ed6d0202517c925670b603229703..fe64eb65bdf2ee34e0372ff4bf95367314f3f211 100644 (file)
@@ -362,6 +362,8 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, guint stage,
                rspamd_stat_classifiers_process (st_ctx, task);
        }
 
+       task->processed_stages |= stage;
+
        return ret;
 }
 
@@ -640,6 +642,8 @@ rspamd_stat_learn (struct rspamd_task *task,
                }
        }
 
+       task->processed_stages |= stage;
+
        return ret;
 }