]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Rework] Refactor extract_specific_urls to prevent DoS and use hash-based deduplication 5732/head
authorVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 7 Nov 2025 16:17:20 +0000 (16:17 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 7 Nov 2025 16:17:20 +0000 (16:17 +0000)
Replace tostring() with url:get_hash() throughout URL extraction to avoid
filling the Lua string interning table. Critical for handling malicious
messages with 100k+ URLs where each tostring() would create an interned
string causing memory exhaustion.

Key changes:
- Use dual data structure: array for results + hash set for O(1) dedup
- Add max_urls_to_process=50000 limit with warning for DoS protection
- Track url_index for stable sorting when priorities are equal
- Fix CTA priority preservation: prevent generic phished handling from
  overwriting CTA priorities which include phished/subject bonuses
- Add verbose flag to test suite for debugging

This ensures memory usage is strictly bounded regardless of malicious
input while maintaining correct URL prioritization for spam detection.

lualib/lua_util.lua
test/rspamd_test_suite.c

index 4851e566fb8cd3b2342c5da0a151b4ccef5d6423..a8d40c876384889bbfccc045d776d379738b99ab 100644 (file)
@@ -875,12 +875,30 @@ exports.filter_specific_urls = function(urls, params)
     urls = fun.totable(fun.filter(params.filter, urls))
   end
 
+  -- Memory bounds protection against DoS: absolute maximum of URLs to process
+  -- Even if we have 100k URLs of 1KB each in input, we process at most first 50k
+  -- This prevents excessive memory usage and CPU time for malicious messages
+  local max_urls_to_process = 50000
+  if #urls > max_urls_to_process then
+    local logger = require "rspamd_logger"
+    logger.warnx(params.task, 'too many URLs to process: %d, limiting to %d',
+      #urls, max_urls_to_process)
+    -- Truncate the urls table
+    local truncated = {}
+    for i = 1, max_urls_to_process do
+      truncated[i] = urls[i]
+    end
+    urls = truncated
+  end
+
   -- Filter by tld:
   local tlds = {}
   local eslds = {}
   local ntlds, neslds = 0, 0
 
-  local res = {}
+  -- Use two structures: hash set for deduplication, array for results
+  local res = {}        -- array of URLs (maintains order)
+  local seen = {}       -- hash set for deduplication (hash -> true)
   local nres = 0
 
   local cta_priority_map
@@ -896,17 +914,22 @@ exports.filter_specific_urls = function(urls, params)
             for _, entry in ipairs(entries) do
               if entry and entry.url then
                 local url = entry.url
-                local str = tostring(url)
+                -- Use hash instead of tostring to avoid string interning
+                local hash = url:get_hash()
                 local weight = entry.weight or 0
                 local score = 6 + math.floor(weight * 10 + 0.5)
-                if not cta_priority_map[str] or score > cta_priority_map[str] then
-                  cta_priority_map[str] = score
+                if not cta_priority_map[hash] or score > cta_priority_map[hash] then
+                  cta_priority_map[hash] = score
                 end
                 local redir = url:get_redirected()
                 if redir then
-                  local rstr = tostring(redir)
-                  if not cta_priority_map[rstr] or score > cta_priority_map[rstr] then
-                    cta_priority_map[rstr] = score
+                  -- Skip display-only URLs (phishing bait text) - only include real redirects
+                  local redir_flags = redir:get_flags()
+                  if not redir_flags.html_displayed then
+                    local rhash = redir:get_hash()
+                    if not cta_priority_map[rhash] or score > cta_priority_map[rhash] then
+                      cta_priority_map[rhash] = score
+                    end
                   end
                 end
               end
@@ -921,9 +944,10 @@ exports.filter_specific_urls = function(urls, params)
     end
   end
 
-  local function insert_url(str, u)
-    if not res[str] then
-      res[str] = u
+  local function insert_url(hash, u)
+    if not seen[hash] then
+      seen[hash] = true
+      table.insert(res, u)
       nres = nres + 1
 
       return true
@@ -932,7 +956,9 @@ exports.filter_specific_urls = function(urls, params)
     return false
   end
 
+  local url_index = 0  -- Track URL processing order for stable sorting
   local function process_single_url(u, default_priority)
+    url_index = url_index + 1
     local priority = default_priority or 1 -- Normal priority
     local flags = u:get_flags()
     if params.ignore_ip and flags.numeric then
@@ -963,44 +989,58 @@ exports.filter_specific_urls = function(urls, params)
     end
 
     local esld = u:get_tld()
-    local str_hash = tostring(u)
+    -- Use fast hash instead of tostring to avoid string interning
+    -- This is critical for DoS protection with 100k+ URLs
+    local url_hash = u:get_hash()
 
     if cta_priority_map then
-      local cta_pr = cta_priority_map[str_hash]
+      local cta_pr = cta_priority_map[url_hash]
       if not cta_pr and flags.redirected then
         local redir_url = u:get_redirected()
         if redir_url then
-          cta_pr = cta_priority_map[tostring(redir_url)]
+          cta_pr = cta_priority_map[redir_url:get_hash()]
         end
       end
 
       if cta_pr then
-        priority = math.max(priority, cta_pr)
+        -- Use CTA priority as base, but add phished/subject bonuses on top
+        priority = cta_pr
+        if flags.phished then
+          priority = priority + 5  -- Phished URLs get significant extra priority (security critical)
+        end
+        if flags.subject then
+          priority = priority + 2  -- Subject URLs get extra priority
+        end
       end
     end
 
     if esld then
-      -- Special cases
-      if (u:get_protocol() ~= 'mailto') and (not flags.html_displayed) then
-        if flags.obscured then
-          priority = 3
-        else
-          if (flags.has_user or flags.has_port) then
-            priority = 2
-          elseif (flags.subject or flags.phished) then
-            priority = 2
+      -- Special cases (only apply if CTA priority wasn't set)
+      -- CTA priority takes precedence as it's explicitly calculated with bonuses
+      local had_cta = priority ~= (default_priority or 1)
+
+      if not had_cta then
+        if (u:get_protocol() ~= 'mailto') and (not flags.html_displayed) then
+          if flags.obscured then
+            priority = 3
+          else
+            if (flags.has_user or flags.has_port) then
+              priority = 2
+            elseif (flags.subject or flags.phished) then
+              priority = 2
+            end
           end
+        elseif flags.html_displayed then
+          priority = 0
         end
-      elseif flags.html_displayed then
-        priority = 0
       end
 
       if not eslds[esld] then
-        eslds[esld] = { { str_hash, u, priority } }
+        eslds[esld] = { { url_hash, u, priority, url_index } }
         neslds = neslds + 1
       else
         if #eslds[esld] < params.esld_limit then
-          table.insert(eslds[esld], { str_hash, u, priority })
+          table.insert(eslds[esld], { url_hash, u, priority, url_index })
         end
       end
 
@@ -1010,10 +1050,10 @@ exports.filter_specific_urls = function(urls, params)
       local tld = table.concat(fun.totable(fun.tail(parts)), '.')
 
       if not tlds[tld] then
-        tlds[tld] = { { str_hash, u, priority } }
+        tlds[tld] = { { url_hash, u, priority, url_index } }
         ntlds = ntlds + 1
       else
-        table.insert(tlds[tld], { str_hash, u, priority })
+        table.insert(tlds[tld], { url_hash, u, priority, url_index })
       end
     end
   end
@@ -1029,7 +1069,6 @@ exports.filter_specific_urls = function(urls, params)
   end
 
   if limit == 0 then
-    res = exports.values(res)
     if params.task and not params.no_cache then
       params.task:cache_set(cache_key, res)
     end
@@ -1040,20 +1079,32 @@ exports.filter_specific_urls = function(urls, params)
   local function sort_stuff(tbl)
     -- Sort according to max priority
     table.sort(tbl, function(e1, e2)
-      -- Sort by priority so max priority is at the end
+      -- Sort by priority (desc) then by url_index (asc) for stable sorting
       table.sort(e1, function(tr1, tr2)
-        return tr1[3] < tr2[3]
+        if tr1[3] ~= tr2[3] then
+          return tr1[3] < tr2[3]  -- Lower priority first
+        else
+          return tr1[4] < tr2[4]  -- Earlier URL index first (stable)
+        end
       end)
       table.sort(e2, function(tr1, tr2)
-        return tr1[3] < tr2[3]
+        if tr1[3] ~= tr2[3] then
+          return tr1[3] < tr2[3]
+        else
+          return tr1[4] < tr2[4]
+        end
       end)
 
       if e1[#e1][3] ~= e2[#e2][3] then
         -- Sort by priority so max priority is at the beginning
         return e1[#e1][3] > e2[#e2][3]
       else
-        -- Prefer less urls to more urls per esld
-        return #e1 < #e2
+        -- Prefer less urls to more urls per esld, then earlier index
+        if #e1 ~= #e2 then
+          return #e1 < #e2
+        else
+          return e1[#e1][4] < e2[#e2][4]  -- Use index as final tiebreaker
+        end
       end
     end)
 
@@ -1078,7 +1129,6 @@ exports.filter_specific_urls = function(urls, params)
       end
     until limit <= 0 or not item_found
 
-    res = exports.values(res)
     if params.task and not params.no_cache then
       params.task:cache_set(cache_key, res)
     end
@@ -1102,7 +1152,6 @@ exports.filter_specific_urls = function(urls, params)
     end
   end
 
-  res = exports.values(res)
   if params.task and not params.no_cache then
     params.task:cache_set(cache_key, res)
   end
index 7648cddd457a44260e93adb046701b8567a61048..e8264045d9eb7590b5eac11d4f691db3f17a181e 100644 (file)
@@ -20,6 +20,8 @@ static GOptionEntry entries[] =
                 "Lua test to run (i.e. selectors.lua)", NULL},
                {"test-case", 'c', 0, G_OPTION_ARG_STRING, &lua_test_case,
                 "Lua test to run, lua pattern i.e. \"case .* rcpts\"", NULL},
+               {"verbose", 'v', 0, G_OPTION_ARG_NONE, &verbose,
+                "Enable verbose output", NULL},
                {NULL, 0, 0, G_OPTION_ARG_NONE, NULL, NULL, NULL}};
 
 int main(int argc, char **argv)