]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Move all metatokens to lua_stat from C
author Vsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 16 Nov 2018 16:44:06 +0000 (16:44 +0000)
committer Vsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 16 Nov 2018 16:44:06 +0000 (16:44 +0000)
lualib/lua_stat.lua
src/libstat/stat_config.c
src/libstat/stat_process.c

index 0ced4b428c880079fbe28f943bcedf3ce087755b..cda80cf71e0ffe3f3b3db508f4d8ff7e11e9fa37 100644 (file)
@@ -533,17 +533,273 @@ local function process_stat_config(cfg)
       "X-Mailer",
       "Content-Type",
       "X-MimeOLE",
+      "Organization",
+      "Organisation"
     },
     classify_images = true,
     classify_mime_info = true,
-    classify_headers = true,
+    classify_urls = true,
+    classify_meta = true,
+    classify_max_tlds = 10,
   }
 
   res_config = lua_util.override_defaults(res_config, opts_section)
 
+  -- Postprocess classify_headers
+  local classify_headers_parsed = {}
+
+  for _,v in ipairs(res_config.classify_headers) do
+    local s1, s2 = v:match("^([A-Z])[^%-]+%-([A-Z]).*$")
+
+    local hname
+    if s1 and s2 then
+      hname = string.format('#h:%s-%s', s1, s2)
+    else
+      hname = string.format('#h:%s', v:sub(1, 2):lower())
+    end
+
+    if classify_headers_parsed[hname] then
+      table.insert(classify_headers_parsed[hname], v)
+    else
+      classify_headers_parsed[hname] = {v}
+    end
+  end
+
+  res_config.classify_headers_parsed = classify_headers_parsed
+
   return res_config
 end
 
+--- Append MIME-structure metatokens for `task` to `res`.
+-- For every part returned by `task:get_parts()`:
+--   * `#ps:<floor(ln(size))>` when the part length is greater than zero;
+--   * `#f:<filename>` when the part has a filename.
+-- For every text part additionally:
+--   * `#lang:<language>` and `#cs:<charset>`, with 'unk' substituted when
+--     the corresponding getter returns nil.
+-- Finally one `#m:<structure><flags>` token is appended, where structure is
+-- `#mpth`, `#ho`, `#to` or `#unk` and flags is any combination of
+-- `#ot`, `#ep`, `#eh`.
+-- @param task task object whose parts are inspected
+-- @param res array-like table that receives the tokens
+-- @param i first free index in `res`
+-- @return the next free index in `res`
+local function get_mime_stat_tokens(task, res, i)
+  local parts = task:get_parts() or {}
+  local seen_multipart = false
+  local seen_plain = false
+  local seen_html = false
+  local empty_plain = false
+  local empty_html = false
+  local online_text = false
+
+  for _,part in ipairs(parts) do
+    local fname = part:get_filename()
+
+    local sz = part:get_length()
+
+    if sz > 0 then
+      -- Only the order of magnitude of the size is kept in the token
+      rawset(res, i, string.format("#ps:%d",
+          math.floor(math.log(sz))))
+      lua_util.debugm("bayes", task, "part size: %s",
+          res[i])
+      i = i + 1
+    end
+
+    if fname then
+      rawset(res, i, "#f:" .. fname)
+      i = i + 1
+
+      lua_util.debugm("bayes", task, "added attachment: #f:%s",
+          fname)
+    end
+
+    if part:is_text() then
+      local tp = part:get_text()
+
+      if tp:is_html() then
+        seen_html = true
+
+        if tp:get_length() == 0 then
+          empty_html = true
+        end
+      else
+        seen_plain = true
+
+        if tp:get_length() == 0 then
+          empty_plain = true
+        end
+      end
+
+      if tp:get_lines_count() < 2 then
+        online_text = true
+      end
+
+      -- The fallback must be parenthesised: `..` binds tighter than `or`,
+      -- so without the parentheses a nil language raises on concatenation
+      -- instead of yielding 'unk'.
+      rawset(res, i, "#lang:" .. (tp:get_language() or 'unk'))
+      lua_util.debugm("bayes", task, "added language: %s",
+          res[i])
+      i = i + 1
+
+      -- Same precedence concern as for the language token above
+      rawset(res, i, "#cs:" .. (tp:get_charset() or 'unk'))
+      lua_util.debugm("bayes", task, "added charset: %s",
+          res[i])
+      i = i + 1
+
+    elseif part:is_multipart() then
+      seen_multipart = true
+    end
+  end
+
+  -- Create a special token depending on parts structure; the three
+  -- conditions are mutually exclusive, anything else stays "#unk"
+  local st_tok = "#unk"
+  if seen_multipart and seen_html and seen_plain then
+    st_tok = '#mpth'
+  elseif seen_html and not seen_plain then
+    st_tok = "#ho"
+  elseif seen_plain and not seen_html then
+    st_tok = "#to"
+  end
+
+  -- Flags are concatenated in a fixed order: #ot, #ep, #eh
+  local spec_tok = ""
+  if online_text then
+    spec_tok = "#ot"
+  end
+
+  if empty_plain then
+    spec_tok = spec_tok .. "#ep"
+  end
+
+  if empty_html then
+    spec_tok = spec_tok .. "#eh"
+  end
+
+  rawset(res, i, string.format("#m:%s%s", st_tok, spec_tok))
+  lua_util.debugm("bayes", task, "added mime token: %s",
+      res[i])
+  i = i + 1
+
+  return i
+end
+
+--- Append header-derived metatokens for `task` to `res`.
+-- Emits, in order:
+--   * `#hh:<first 7 characters>` of the "headers_hash" mempool variable,
+--     when that variable is set;
+--   * one `<group key>:<value>` token for every header listed in
+--     `cf.classify_headers_parsed` for which `task:get_header` returns a
+--     value (groups are visited with `pairs`, so their order is unspecified);
+--   * `#F:<name>` when the first MIME From address has a `name` field.
+-- @param task task object to read headers from
+-- @param cf config table holding `classify_headers_parsed`
+-- @param res array-like table that receives the tokens
+-- @param i first free index in `res`
+-- @return the next free index in `res`
+local function get_headers_stat_tokens(task, cf, res, i)
+  local hdrs_cksum = task:get_mempool():get_variable("headers_hash")
+
+  if hdrs_cksum then
+    rawset(res, i, string.format("#hh:%s", hdrs_cksum:sub(1, 7)))
+    lua_util.debugm("bayes", task, "added hdrs hash token: %s",
+        res[i])
+    i = i + 1
+  end
+
+  for k,hdrs in pairs(cf.classify_headers_parsed) do
+    for _,hname in ipairs(hdrs) do
+      local value = task:get_header(hname)
+
+      if value then
+        -- Group keys are already built with the "#h:" prefix when the
+        -- config is postprocessed, so it must not be prepended again
+        -- (that produced "#h:#h:<key>:<value>").
+        rawset(res, i, string.format("%s:%s", k, value))
+        lua_util.debugm("bayes", task, "added hdrs token: %s",
+            res[i])
+        i = i + 1
+      end
+    end
+  end
+
+  -- Only the first MIME From address is considered
+  local from = (task:get_from('mime') or {})[1]
+
+  if from and from.name then
+    rawset(res, i, string.format("#F:%s", from.name))
+    lua_util.debugm("bayes", task, "added from name token: %s",
+        res[i])
+    i = i + 1
+  end
+
+  return i
+end
+
+--- Append network and policy metatokens for `task` to `res`.
+-- When the mempool variables "asn", "country" and "ipnet" are all set,
+-- emits `#asn:<asn>`, `#cnt:<country>` and `#ipn:<ipnet>`.
+-- Always emits one `#pol:<flags>` token, where flags contains 'D' if the
+-- task has symbol R_DKIM_ALLOW and 'S' if it has R_SPF_ALLOW (the token is
+-- `#pol:` with no flags when neither symbol is present).
+-- @param task task object to inspect
+-- @param res array-like table that receives the tokens
+-- @param i first free index in `res`
+-- @return the next free index in `res`
+local function get_meta_stat_tokens(task, res, i)
+  local pool = task:get_mempool()
+  local asn = pool:get_variable("asn")
+  local country = pool:get_variable("country")
+  local ipnet = pool:get_variable("ipnet")
+
+  -- The three tokens are emitted together or not at all
+  if asn and country and ipnet then
+    rawset(res, i, string.format("#asn:%s", asn))
+    lua_util.debugm("bayes", task, "added asn token: %s",
+        res[i])
+    i = i + 1
+    rawset(res, i, string.format("#cnt:%s", country))
+    lua_util.debugm("bayes", task, "added country token: %s",
+        res[i])
+    i = i + 1
+    -- Use the ipnet value here; formatting `country` again duplicated the
+    -- country token under the "#ipn:" prefix and dropped the network.
+    rawset(res, i, string.format("#ipn:%s", ipnet))
+    lua_util.debugm("bayes", task, "added ipnet token: %s",
+        res[i])
+    i = i + 1
+  end
+
+  local pol = {}
+  if task:has_symbol('R_DKIM_ALLOW') then
+    table.insert(pol, 'D')
+  end
+  if task:has_symbol('R_SPF_ALLOW') then
+    table.insert(pol, 'S')
+  end
+
+  rawset(res, i, string.format("#pol:%s", table.concat(pol, '')))
+  lua_util.debugm("bayes", task, "added policies token: %s",
+      res[i])
+  i = i + 1
+
+  return i
+end
+
+--- Build the array of metatokens for `task` according to config `cf`.
+-- Each section is gated by a config field:
+--   * `classify_images`    - "image", height, width, type and (if present)
+--     filename of every image returned by `task:get_images()`;
+--   * `classify_mime_info` - tokens from `get_mime_stat_tokens`;
+--   * `classify_headers`   - tokens from `get_headers_stat_tokens`, only
+--     when the list is non-empty;
+--   * `classify_urls`      - `#u:<tld>` for every URL returned by
+--     `lua_util.extract_specific_urls`;
+--   * `classify_meta`      - tokens from `get_meta_stat_tokens`.
+-- @param task task object to extract tokens from
+-- @param cf processed stat config table
+-- @return array-like table of string tokens
+local function get_stat_tokens(task, cf)
+  local res = {}
+  local i = 1
+
+  if cf.classify_images then
+    local images = task:get_images() or {}
+
+    for _,img in ipairs(images) do
+      rawset(res, i, "image")
+      i = i + 1
+      rawset(res, i, tostring(img:get_height()))
+      i = i + 1
+      rawset(res, i, tostring(img:get_width()))
+      i = i + 1
+      rawset(res, i, tostring(img:get_type()))
+      i = i + 1
+
+      local fname = img:get_filename()
+
+      if fname then
+        -- Reuse the value fetched above instead of calling the getter again
+        rawset(res, i, tostring(fname))
+        i = i + 1
+      end
+
+      lua_util.debugm("bayes", task, "added image: %s",
+          fname)
+    end
+  end
+
+  if cf.classify_mime_info then
+    i = get_mime_stat_tokens(task, res, i)
+  end
+
+  if cf.classify_headers and #cf.classify_headers > 0 then
+    i = get_headers_stat_tokens(task, cf, res, i)
+  end
+
+  if cf.classify_urls then
+    -- NOTE(review): the limit is hard-coded to 5 and `cf.classify_max_tlds`
+    -- is not consulted here - confirm whether that is intended.
+    local urls = lua_util.extract_specific_urls{task = task, limit = 5}
+
+    if urls then
+      for _,u in ipairs(urls) do
+        rawset(res, i, string.format("#u:%s", u:get_tld()))
+        lua_util.debugm("bayes", task, "added url token: %s",
+            res[i])
+        i = i + 1
+      end
+    end
+  end
+
+  if cf.classify_meta then
+    i = get_meta_stat_tokens(task, res, i)
+  end
+
+  return res
+end
 
 exports.gen_stat_tokens = function(cfg)
   local stat_config = process_stat_config(cfg)
index b5ddf143caf8d6ab8b4445e9b38c58a8ed100741..101db4fe640da41d7b20605f5064b8b4afdf4f71 100644 (file)
@@ -202,7 +202,7 @@ rspamd_stat_init (struct rspamd_config *cfg, struct event_base *ev_base)
 
                                if ((ret = lua_pcall (L, 1, 1, err_idx)) != 0) {
                                        tb = lua_touserdata (L, -1);
-                                       msg_err_config ("call to cleanup_rules lua "
+                                       msg_err_config ("call to gen_stat_tokens lua "
                                                                        "script failed (%d): %v", ret, tb);
 
                                        if (tb) {
index d07e241562169951bd6d37a4f9911a729acbd9f3..0465f0c3c8693b796cf0a5125cf216052782e258 100644 (file)
 
 static const gdouble similarity_treshold = 80.0;
 
-static void
-rspamd_stat_tokenize_header (struct rspamd_task *task,
-               const gchar *name, const gchar *prefix, GArray *ar)
-{
-       struct rspamd_mime_header *cur;
-       GPtrArray *hdrs;
-       guint i;
-       rspamd_stat_token_t str;
-
-       hdrs = g_hash_table_lookup (task->raw_headers, name);
-       str.flags = RSPAMD_STAT_TOKEN_FLAG_META;
-
-       if (hdrs != NULL) {
-
-               PTR_ARRAY_FOREACH (hdrs, i, cur) {
-                       if (cur->name != NULL) {
-                               str.begin = cur->name;
-                               str.len = strlen (cur->name);
-                               g_array_append_val (ar, str);
-                       }
-                       if (cur->decoded != NULL) {
-                               str.begin = cur->decoded;
-                               str.len = strlen (cur->decoded);
-                               g_array_append_val (ar, str);
-                       }
-                       else if (cur->value != NULL) {
-                               str.begin = cur->value;
-                               str.len = strlen (cur->value);
-                               g_array_append_val (ar, str);
-                       }
-               }
-
-               msg_debug_bayes ("added stat tokens for header '%s'", name);
-       }
-}
-
 static void
 rspamd_stat_tokenize_parts_metadata (struct rspamd_stat_ctx *st_ctx,
                struct rspamd_task *task)
 {
-       struct rspamd_image *img;
-       struct rspamd_mime_part *part;
-       struct rspamd_mime_text_part *tp;
-       GList *cur;
        GArray *ar;
        rspamd_stat_token_t elt;
        guint i;
-       gchar tmpbuf[128];
        lua_State *L = task->cfg->lua_state;
-       const gchar *headers_hash;
-       struct rspamd_mime_header *hdr;
 
        ar = g_array_sized_new (FALSE, FALSE, sizeof (elt), 16);
        elt.flags = RSPAMD_STAT_TOKEN_FLAG_META;
 
-       /* Insert images */
-       for (i = 0; i < task->parts->len; i ++) {
-               part = g_ptr_array_index (task->parts, i);
-
-               if ((part->flags & RSPAMD_MIME_PART_IMAGE) && part->specific.img) {
-                       img = part->specific.img;
-
-                       /* If an image has a linked HTML part, then we push its details to the stat */
-                       if (img->html_image) {
-                               elt.begin = (gchar *)"image";
-                               elt.len = 5;
-                               g_array_append_val (ar, elt);
-                               elt.begin = (gchar *)&img->html_image->height;
-                               elt.len = sizeof (img->html_image->height);
-                               g_array_append_val (ar, elt);
-                               elt.begin = (gchar *)&img->html_image->width;
-                               elt.len = sizeof (img->html_image->width);
-                               g_array_append_val (ar, elt);
-                               elt.begin = (gchar *)&img->type;
-                               elt.len = sizeof (img->type);
-                               g_array_append_val (ar, elt);
-
-                               if (img->filename) {
-                                       elt.begin = (gchar *)img->filename;
-                                       elt.len = strlen (elt.begin);
-                                       g_array_append_val (ar, elt);
-                               }
+       if (st_ctx->lua_stat_tokens_ref != -1) {
+               gint err_idx, ret;
+               GString *tb;
+               struct rspamd_task **ptask;
 
-                               msg_debug_bayes ("added stat tokens for image '%s'", img->html_image->src);
-                       }
-               }
-               else if (part->cd && part->cd->filename.len > 0) {
-                       elt.begin = (gchar *)part->cd->filename.begin;
-                       elt.len = part->cd->filename.len;
-                       g_array_append_val (ar, elt);
-               }
-       }
+               lua_pushcfunction (L, &rspamd_lua_traceback);
+               err_idx = lua_gettop (L);
+               lua_rawgeti (L, LUA_REGISTRYINDEX, st_ctx->lua_stat_tokens_ref);
 
-       /* Process mime parts */
-       for (i = 0; i < task->parts->len; i ++) {
-               part = g_ptr_array_index (task->parts, i);
+               ptask = lua_newuserdata (L, sizeof (*ptask));
+               *ptask = task;
+               rspamd_lua_setclass (L, "rspamd{task}", -1);
 
-               if (IS_CT_MULTIPART (part->ct)) {
-                       elt.begin = (gchar *)part->ct->boundary.begin;
-                       elt.len = part->ct->boundary.len;
+               if ((ret = lua_pcall (L, 1, 1, err_idx)) != 0) {
+                       tb = lua_touserdata (L, -1);
+                       msg_err_task ("call to stat_tokens lua "
+                                                       "script failed (%d): %v", ret, tb);
 
-                       if (elt.len) {
-                               msg_debug_bayes ("added stat tokens for mime boundary '%*s'",
-                                               (gint)elt.len, elt.begin);
-                               g_array_append_val (ar, elt);
-                       }
-
-                       if (part->parsed_data.len > 1) {
-                               rspamd_snprintf (tmpbuf, sizeof (tmpbuf), "mime%d:%dlog",
-                                               i, (gint)log2 (part->parsed_data.len));
-                               elt.begin = rspamd_mempool_strdup (task->task_pool, tmpbuf);
-                               elt.len = strlen (elt.begin);
-                               g_array_append_val (ar, elt);
+                       if (tb) {
+                               g_string_free (tb, TRUE);
                        }
                }
-       }
+               else {
+                       if (lua_type (L, -1) != LUA_TTABLE) {
+                               msg_err_task ("stat_tokens invocation must return "
+                                                               "table and not %s",
+                                               lua_typename (L, lua_type (L, -1)));
+                       }
+                       else {
+                               guint vlen;
+                               rspamd_ftok_t tok;
 
-       /* Process text parts metadata */
-       for (i = 0; i < task->text_parts->len; i ++) {
-               tp = g_ptr_array_index (task->text_parts, i);
+                               vlen = rspamd_lua_table_size (L, -1);
 
-               if (tp->language != NULL && tp->language[0] != '\0') {
-                       elt.begin = (gchar *)tp->language;
-                       elt.len = strlen (elt.begin);
-                       msg_debug_bayes ("added stat tokens for part language '%s'", elt.begin);
-                       g_array_append_val (ar, elt);
-               }
-               if (tp->real_charset != NULL) {
-                       elt.begin = (gchar *)tp->real_charset;
-                       elt.len = strlen (elt.begin);
-                       msg_debug_bayes ("added stat tokens for part charset '%s'", elt.begin);
-                       g_array_append_val (ar, elt);
-               }
-       }
+                               for (i = 0; i < vlen; i ++) {
+                                       lua_rawgeti (L, -1, i + 1);
+                                       tok.begin = lua_tolstring (L, -1, &tok.len);
+
+                                       if (tok.begin && tok.len > 0) {
+                                               elt.begin = rspamd_mempool_ftokdup (task->task_pool, &tok);
+                                               elt.len = tok.len;
 
-       cur = g_list_first (task->cfg->classify_headers);
+                                               g_array_append_val (ar, elt);
+                                       }
 
-       while (cur) {
-               rspamd_stat_tokenize_header (task, cur->data, "UA:", ar);
+                                       lua_pop (L, 1);
+                               }
+                       }
+               }
 
-               cur = g_list_next (cur);
+               lua_settop (L, 0);
        }
 
-       /* Use headers order */
-       headers_hash = rspamd_mempool_get_variable (task->task_pool,
-                       RSPAMD_MEMPOOL_HEADERS_HASH);
 
-       if (headers_hash) {
-               elt.begin = (gchar *)headers_hash;
-               elt.len = 16;
-               g_array_append_val (ar, elt);
+       if (ar->len > 0) {
+               st_ctx->tokenizer->tokenize_func (st_ctx,
+                               task,
+                               ar,
+                               TRUE,
+                               "M",
+                               task->tokens);
        }
 
        rspamd_mempool_add_destructor (task->task_pool,