]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add task:get_html_urls() for async URL rewriting
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 13 Oct 2025 10:46:09 +0000 (11:46 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 13 Oct 2025 11:37:57 +0000 (12:37 +0100)
Introduce a two-phase API for HTML URL rewriting that separates URL
extraction from the rewriting step. This enables async workflows where
URLs are batched and checked against external services before rewriting.

Changes:
- Add rspamd_html_enumerate_urls() C wrapper to extract URL candidates
- Add task:get_html_urls() Lua method returning URL info per HTML part
- Include comprehensive unit tests covering edge cases
- Provide async usage examples (HTTP, Redis, simple patterns)

The new API complements the existing task:rewrite_html_urls() method,
allowing users to extract URLs, perform async operations, then apply
rewrites using a lookup table callback.

src/libserver/html/html_url_rewrite.cxx
src/libserver/html/html_url_rewrite_c.cxx
src/libserver/html/html_url_rewrite_c.h
src/lua/lua_task.c
test/lua/unit/get_html_urls.lua [new file with mode: 0644]
test/lua/unit/get_html_urls_async_example.lua [new file with mode: 0644]
test/lua/unit/url_rewrite.lua

index 387490206eb5e1693c538601828abfd1b6a34cea..5375f92964cbaaaf39b1ea9d50cc85c534fc12a2 100644 (file)
@@ -123,10 +123,8 @@ auto enumerate_rewrite_candidates(const html_content *hc, struct rspamd_task *ta
                }
 
                // Skip data: and cid: schemes by default
-               if (url_value.size() >= 5) {
-                       if (url_value.substr(0, 5) == "data:" || url_value.substr(0, 4) == "cid:") {
-                               return true;// Continue to next
-                       }
+               if (url_value.starts_with("data:") || url_value.starts_with("cid:")) {
+                       return true;// Continue to next
                }
 
                // Build absolute URL (already done by parser, but we have it in url_value)
index f3e66672fe8ce4036b01b62cd469e5e5f5ede0a0..5f06a459f1fa557b67976cc8d86c1efa357555d9 100644 (file)
 
 extern "C" {
 
+int rspamd_html_enumerate_urls(struct rspamd_task *task,
+                                                          void *html_content,
+                                                          int part_id,
+                                                          struct rspamd_html_url_candidate **candidates,
+                                                          gsize *n_candidates)
+{
+       if (!task || !html_content || !candidates || !n_candidates) {
+               return -1;
+       }
+
+       auto *hc = static_cast<const rspamd::html::html_content *>(html_content);
+
+       // Enumerate candidates using C++ function
+       auto cpp_candidates = rspamd::html::enumerate_rewrite_candidates(hc, task, part_id);
+
+       if (cpp_candidates.empty()) {
+               *candidates = nullptr;
+               *n_candidates = 0;
+               return 0;
+       }
+
+       // Allocate C-style array from task pool
+       *n_candidates = cpp_candidates.size();
+       *candidates = (struct rspamd_html_url_candidate *) rspamd_mempool_alloc(
+               task->task_pool,
+               sizeof(struct rspamd_html_url_candidate) * cpp_candidates.size());
+
+       // Convert C++ candidates to C candidates
+       for (size_t i = 0; i < cpp_candidates.size(); i++) {
+               const auto &cpp_cand = cpp_candidates[i];
+
+               // Allocate strings from task pool
+               char *url_str = (char *) rspamd_mempool_alloc(
+                       task->task_pool,
+                       cpp_cand.absolute_url.size() + 1);
+               memcpy(url_str, cpp_cand.absolute_url.data(), cpp_cand.absolute_url.size());
+               url_str[cpp_cand.absolute_url.size()] = '\0';
+
+               char *attr_str = (char *) rspamd_mempool_alloc(
+                       task->task_pool,
+                       cpp_cand.attr_name.size() + 1);
+               memcpy(attr_str, cpp_cand.attr_name.data(), cpp_cand.attr_name.size());
+               attr_str[cpp_cand.attr_name.size()] = '\0';
+
+               // Get tag name
+               const char *tag_name = "unknown";
+               gsize tag_len = 7;
+               if (cpp_cand.tag) {
+                       // Use rspamd_html_tag_by_id which returns const char*
+                       extern const char *rspamd_html_tag_by_id(int id);
+                       tag_name = rspamd_html_tag_by_id(cpp_cand.tag->id);
+                       if (tag_name) {
+                               tag_len = strlen(tag_name);
+                       }
+                       else {
+                               tag_name = "unknown";
+                               tag_len = 7;
+                       }
+               }
+
+               (*candidates)[i].url = url_str;
+               (*candidates)[i].url_len = cpp_cand.absolute_url.size();
+               (*candidates)[i].attr = attr_str;
+               (*candidates)[i].attr_len = cpp_cand.attr_name.size();
+               (*candidates)[i].tag = tag_name;
+               (*candidates)[i].tag_len = tag_len;
+       }
+
+       return 0;
+}
+
 int rspamd_html_url_rewrite(struct rspamd_task *task,
                                                        struct lua_State *L,
                                                        void *html_content,
index 798c8b39872f2f43517daa26e8c615a4051d49a1..c0906e00a7d15f435d045dfc83d4fbc405c00116 100644 (file)
@@ -27,6 +27,33 @@ struct rspamd_task;
 
 struct lua_State;
 
+/**
+ * URL candidate info for C interface
+ */
+struct rspamd_html_url_candidate {
+       const char *url; // Absolute URL string (NUL-terminated)
+       const char *attr;// Attribute name: "href" or "src" (NUL-terminated)
+       const char *tag; // Tag name (NUL-terminated)
+       gsize url_len;   // Length of URL string
+       gsize attr_len;  // Length of attr string
+       gsize tag_len;   // Length of tag string
+};
+
+/**
+ * C wrapper for enumerating HTML URL rewrite candidates
+ * @param task Rspamd task
+ * @param html_content HTML content pointer (void* cast of html_content*)
+ * @param part_id MIME part ID
+ * @param candidates Output array of candidates (allocated from task pool if successful)
+ * @param n_candidates Output count of candidates
+ * @return 0 on success, -1 on error
+ */
+int rspamd_html_enumerate_urls(struct rspamd_task *task,
+                                                          void *html_content,
+                                                          int part_id,
+                                                          struct rspamd_html_url_candidate **candidates,
+                                                          gsize *n_candidates);
+
 /**
  * C wrapper for HTML URL rewriting
  * @param task Rspamd task
index a111ef97299273bb708f1b199853fb9f4bf21a94..5f0295268cbf74ddc9725b2661b7ce18a23afc8b 100644 (file)
@@ -1285,6 +1285,14 @@ LUA_FUNCTION_DEF(task, add_timer);
  */
 LUA_FUNCTION_DEF(task, rewrite_html_urls);
 
+/***
+ * @method task:get_html_urls()
+ * Extracts all URLs from HTML parts without rewriting.
+ * Useful for async URL checking workflows where URLs need to be batched.
+ * @return {table|nil} table indexed by part number, each containing an array of URL info tables with keys: url, attr, tag
+ */
+LUA_FUNCTION_DEF(task, get_html_urls);
+
 static const struct luaL_reg tasklib_f[] = {
        LUA_INTERFACE_DEF(task, create),
        LUA_INTERFACE_DEF(task, load_from_file),
@@ -1416,6 +1424,7 @@ static const struct luaL_reg tasklib_m[] = {
        LUA_INTERFACE_DEF(task, topointer),
        LUA_INTERFACE_DEF(task, add_timer),
        LUA_INTERFACE_DEF(task, rewrite_html_urls),
+       LUA_INTERFACE_DEF(task, get_html_urls),
        {"__tostring", rspamd_lua_class_tostring},
        {NULL, NULL}};
 
@@ -7875,6 +7884,88 @@ lua_task_rewrite_html_urls(lua_State *L)
        return 1;
 }
 
+static int
+lua_task_get_html_urls(lua_State *L)
+{
+       struct rspamd_task *task = lua_check_task(L, 1);
+
+       if (!task || !MESSAGE_FIELD_CHECK(task, text_parts)) {
+               lua_pushnil(L);
+               return 1;
+       }
+
+       /* Create result table */
+       lua_newtable(L);
+       int results = 0;
+       unsigned int i;
+       void *part;
+
+       /* Iterate through text parts */
+       PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part)
+       {
+               struct rspamd_mime_text_part *text_part = (struct rspamd_mime_text_part *) part;
+
+               /* Only process HTML parts */
+               if (!IS_TEXT_PART_HTML(text_part) || !text_part->html) {
+                       continue;
+               }
+
+               /* Skip if no UTF-8 content available */
+               if (!text_part->utf_raw_content || text_part->utf_raw_content->len == 0) {
+                       continue;
+               }
+
+               struct rspamd_html_url_candidate *candidates = NULL;
+               gsize n_candidates = 0;
+
+               /* Enumerate URLs using C wrapper */
+               int ret = rspamd_html_enumerate_urls(
+                       task,
+                       text_part->html,
+                       text_part->mime_part->part_number,
+                       &candidates,
+                       &n_candidates);
+
+               if (ret == 0 && candidates && n_candidates > 0) {
+                       /* Create array for this part: table[part_number] = {url_info_1, url_info_2, ...} */
+                       lua_pushinteger(L, text_part->mime_part->part_number);
+                       lua_newtable(L); /* URLs array for this part */
+
+                       for (gsize j = 0; j < n_candidates; j++) {
+                               lua_pushinteger(L, j + 1); /* 1-indexed array */
+                               lua_newtable(L);           /* URL info table */
+
+                               /* url field */
+                               lua_pushstring(L, "url");
+                               lua_pushstring(L, candidates[j].url);
+                               lua_settable(L, -3);
+
+                               /* attr field */
+                               lua_pushstring(L, "attr");
+                               lua_pushstring(L, candidates[j].attr);
+                               lua_settable(L, -3);
+
+                               /* tag field */
+                               lua_pushstring(L, "tag");
+                               lua_pushstring(L, candidates[j].tag);
+                               lua_settable(L, -3);
+
+                               lua_settable(L, -3); /* Add url info to URLs array */
+                       }
+
+                       lua_settable(L, -3); /* Add part to main table */
+                       results++;
+               }
+       }
+
+       if (results == 0) {
+               lua_pop(L, 1);
+               lua_pushnil(L);
+       }
+
+       return 1;
+}
+
 /* Init part */
 
 static int
diff --git a/test/lua/unit/get_html_urls.lua b/test/lua/unit/get_html_urls.lua
new file mode 100644 (file)
index 0000000..862fea4
--- /dev/null
@@ -0,0 +1,338 @@
+context("HTML URL extraction", function()
+  local rspamd_task = require("rspamd_task")
+  local logger = require("rspamd_logger")
+
+  test("Basic URL extraction from simple HTML", function()
+    local msg = [[
+From: test@example.com
+To: nobody@example.com
+Subject: test
+Content-Type: text/html
+
+<html>
+<body>
+<a href="http://example.com/test">Click here</a>
+</body>
+</html>
+]]
+    local res, task = rspamd_task.load_from_string(msg, rspamd_config)
+    assert_true(res, "failed to load message")
+
+    task:process_message()
+
+    local urls = task:get_html_urls()
+
+    assert_not_nil(urls, "should extract URLs")
+
+    -- Check structure
+    local found_url = false
+    for part_id, url_list in pairs(urls) do
+      assert_true(type(url_list) == "table", "URL list should be a table")
+      for i, url_info in ipairs(url_list) do
+        assert_not_nil(url_info.url, "should have url field")
+        assert_not_nil(url_info.attr, "should have attr field")
+        assert_not_nil(url_info.tag, "should have tag field")
+
+        if url_info.url == "http://example.com/test" then
+          assert_equal(url_info.attr, "href", "should be href attribute")
+          assert_equal(url_info.tag, "a", "should be <a> tag")
+          found_url = true
+        end
+      end
+    end
+
+    assert_true(found_url, "should find the expected URL")
+
+    task:destroy()
+  end)
+
+  test("Multiple URLs in same HTML part", function()
+    local msg = [[
+From: test@example.com
+To: nobody@example.com
+Subject: test
+Content-Type: text/html
+
+<html>
+<body>
+<a href="http://example.com/link1">Link 1</a>
+<a href="http://example.com/link2">Link 2</a>
+<img src="http://example.com/image.jpg">
+</body>
+</html>
+]]
+    local res, task = rspamd_task.load_from_string(msg, rspamd_config)
+    assert_true(res, "failed to load message")
+
+    task:process_message()
+
+    local urls = task:get_html_urls()
+
+    assert_not_nil(urls, "should extract URLs")
+
+    -- Count URLs
+    local url_count = 0
+    local found_urls = {}
+    for part_id, url_list in pairs(urls) do
+      for i, url_info in ipairs(url_list) do
+        url_count = url_count + 1
+        found_urls[url_info.url] = url_info
+      end
+    end
+
+    assert_equal(url_count, 3, "should have found 3 URLs")
+    assert_not_nil(found_urls["http://example.com/link1"], "should find link1")
+    assert_not_nil(found_urls["http://example.com/link2"], "should find link2")
+    assert_not_nil(found_urls["http://example.com/image.jpg"], "should find image")
+
+    -- Check attributes
+    assert_equal(found_urls["http://example.com/link1"].attr, "href")
+    assert_equal(found_urls["http://example.com/link1"].tag, "a")
+    assert_equal(found_urls["http://example.com/image.jpg"].attr, "src")
+    assert_equal(found_urls["http://example.com/image.jpg"].tag, "img")
+
+    task:destroy()
+  end)
+
+  test("Non-HTML parts return nil", function()
+    local msg = [[
+From: test@example.com
+To: nobody@example.com
+Subject: test
+Content-Type: text/plain
+
+This is plain text with http://example.com/test
+]]
+    local res, task = rspamd_task.load_from_string(msg, rspamd_config)
+    assert_true(res, "failed to load message")
+
+    task:process_message()
+
+    local urls = task:get_html_urls()
+
+    -- Should return nil for plain text
+    assert_nil(urls, "should return nil for non-HTML parts")
+
+    task:destroy()
+  end)
+
+  test("Empty HTML returns nil", function()
+    local msg = [[
+From: test@example.com
+To: nobody@example.com
+Subject: test
+Content-Type: text/html
+
+]]
+    local res, task = rspamd_task.load_from_string(msg, rspamd_config)
+    assert_true(res, "failed to load message")
+
+    task:process_message()
+
+    local urls = task:get_html_urls()
+
+    -- Should return nil for empty HTML
+    assert_nil(urls, "should return nil for empty HTML")
+
+    task:destroy()
+  end)
+
+  test("HTML without URLs returns nil", function()
+    local msg = [[
+From: test@example.com
+To: nobody@example.com
+Subject: test
+Content-Type: text/html
+
+<html>
+<body>
+<p>Just some text without any links</p>
+</body>
+</html>
+]]
+    local res, task = rspamd_task.load_from_string(msg, rspamd_config)
+    assert_true(res, "failed to load message")
+
+    task:process_message()
+
+    local urls = task:get_html_urls()
+
+    -- Should return nil when no URLs found
+    assert_nil(urls, "should return nil when no URLs found")
+
+    task:destroy()
+  end)
+
+  test("Data URI scheme is skipped", function()
+    local msg = [[
+From: test@example.com
+To: nobody@example.com
+Subject: test
+Content-Type: text/html
+
+<html>
+<body>
+<img src="">
+<a href="http://example.com/test">Real link</a>
+</body>
+</html>
+]]
+    local res, task = rspamd_task.load_from_string(msg, rspamd_config)
+    assert_true(res, "failed to load message")
+
+    task:process_message()
+
+    local urls = task:get_html_urls()
+
+    assert_not_nil(urls, "should extract non-data URLs")
+
+    -- Check that data: URIs are skipped
+    local found_data_uri = false
+    local found_http_url = false
+    for part_id, url_list in pairs(urls) do
+      for i, url_info in ipairs(url_list) do
+        if url_info.url:find("^data:", 1, false) then
+          found_data_uri = true
+        end
+        if url_info.url == "http://example.com/test" then
+          found_http_url = true
+        end
+      end
+    end
+
+    assert_false(found_data_uri, "data: URIs should be skipped")
+    assert_true(found_http_url, "should have found the http URL")
+
+    task:destroy()
+  end)
+
+  test("CID scheme is skipped", function()
+    local msg = [[
+From: test@example.com
+To: nobody@example.com
+Subject: test
+Content-Type: text/html
+
+<html>
+<body>
+<img src="cid:image001@example.com">
+<a href="http://example.com/test">Real link</a>
+</body>
+</html>
+]]
+    local res, task = rspamd_task.load_from_string(msg, rspamd_config)
+    assert_true(res, "failed to load message")
+
+    task:process_message()
+
+    local urls = task:get_html_urls()
+
+    assert_not_nil(urls, "should extract non-cid URLs")
+
+    -- Check that cid: URIs are skipped
+    local found_cid_uri = false
+    local found_http_url = false
+    for part_id, url_list in pairs(urls) do
+      for i, url_info in ipairs(url_list) do
+        if url_info.url:find("^cid:", 1, false) then
+          found_cid_uri = true
+        end
+        if url_info.url == "http://example.com/test" then
+          found_http_url = true
+        end
+      end
+    end
+
+    assert_false(found_cid_uri, "cid: URIs should be skipped")
+    assert_true(found_http_url, "should have found the http URL")
+
+    task:destroy()
+  end)
+
+  test("Multipart message with multiple HTML parts", function()
+    local msg = [[
+From: test@example.com
+To: nobody@example.com
+Subject: test
+Content-Type: multipart/alternative; boundary="boundary123"
+
+--boundary123
+Content-Type: text/plain
+
+Plain text part
+
+--boundary123
+Content-Type: text/html
+
+<html><body><a href="http://example.com/part1">Part 1</a></body></html>
+
+--boundary123
+Content-Type: text/html
+
+<html><body><a href="http://example.com/part2">Part 2</a></body></html>
+
+--boundary123--
+]]
+    local res, task = rspamd_task.load_from_string(msg, rspamd_config)
+    assert_true(res, "failed to load message")
+
+    task:process_message()
+
+    local urls = task:get_html_urls()
+
+    assert_not_nil(urls, "should extract URLs from multipart HTML")
+
+    -- Should have processed at least one HTML part
+    local part_count = 0
+    local total_urls = 0
+    for part_id, url_list in pairs(urls) do
+      part_count = part_count + 1
+      total_urls = total_urls + #url_list
+    end
+
+    assert_true(part_count >= 1, "should have URLs from at least one HTML part")
+    assert_true(total_urls >= 1, "should have found at least one URL")
+
+    task:destroy()
+  end)
+
+  test("URL with special characters", function()
+    local msg = [[
+From: test@example.com
+To: nobody@example.com
+Subject: test
+Content-Type: text/html
+
+<html>
+<body>
+<a href="http://example.com/path?param=value&other=123#anchor">Link</a>
+</body>
+</html>
+]]
+    local res, task = rspamd_task.load_from_string(msg, rspamd_config)
+    assert_true(res, "failed to load message")
+
+    task:process_message()
+
+    local urls = task:get_html_urls()
+
+    assert_not_nil(urls, "should handle URLs with special chars")
+
+    local found_url = false
+    for part_id, url_list in pairs(urls) do
+      for i, url_info in ipairs(url_list) do
+        if url_info.url:find("example.com/path", 1, true) then
+          found_url = true
+          -- URL should contain the query parameters
+          assert_true(url_info.url:find("param=value", 1, true) ~= nil,
+                     "should preserve query parameters")
+        end
+      end
+    end
+
+    assert_true(found_url, "should have found the URL with special chars")
+
+    task:destroy()
+  end)
+
+end)
diff --git a/test/lua/unit/get_html_urls_async_example.lua b/test/lua/unit/get_html_urls_async_example.lua
new file mode 100644 (file)
index 0000000..90d44b6
--- /dev/null
@@ -0,0 +1,273 @@
+--[[
+  Async HTML URL Rewriting Example
+
+  This is an example demonstrating how to use task:get_html_urls() with
+  async operations to batch-check URLs against an external service before
+  rewriting them.
+
+  Usage pattern:
+  1. Extract all URLs from HTML parts using task:get_html_urls()
+  2. Send all URLs to external service via async HTTP/Redis/etc
+  3. Receive URL replacements from service
+  4. Apply rewrites using task:rewrite_html_urls() with lookup table
+]]
+
+-- Example rule implementation
+local function register_async_url_rewriter(rspamd_config)
+  rspamd_config:register_symbol({
+    name = 'ASYNC_URL_REWRITER',
+    type = 'postfilter',
+    callback = function(task)
+      -- Step 1: Extract all URLs from HTML parts
+      local urls_by_part = task:get_html_urls()
+
+      if not urls_by_part then
+        return -- No HTML URLs to process
+      end
+
+      -- Flatten URLs for batched API request
+      local all_urls = {}
+      local url_to_info = {}
+
+      for part_id, url_list in pairs(urls_by_part) do
+        for _, url_info in ipairs(url_list) do
+          table.insert(all_urls, url_info.url)
+          url_to_info[url_info.url] = url_info
+        end
+      end
+
+      if #all_urls == 0 then
+        return
+      end
+
+      rspamd_logger.infox(task, "Found %s HTML URLs to check", #all_urls)
+
+      -- Step 2: Make async request to URL checking service
+      local http = require "rspamd_http"
+      local ucl = require "ucl"
+
+      http.request({
+        task = task,
+        url = 'http://url-checker.example.com/api/check-batch',
+        callback = function(err, code, body)
+          if err then
+            rspamd_logger.errx(task, 'URL check failed: %s', err)
+            return
+          end
+
+          if code ~= 200 then
+            rspamd_logger.errx(task, 'URL check service returned HTTP %s', code)
+            return
+          end
+
+          -- Step 3: Parse response containing URL replacements
+          local parser = ucl.parser()
+          local ok, parse_err = parser:parse_string(body)
+
+          if not ok then
+            rspamd_logger.errx(task, 'Failed to parse response: %s', parse_err)
+            return
+          end
+
+          local response = parser:get_object()
+
+          -- Build replacement map: original_url -> new_url
+          local replacements = {}
+
+          for original_url, result in pairs(response.urls or {}) do
+            if result.action == 'rewrite' and result.new_url then
+              replacements[original_url] = result.new_url
+              rspamd_logger.infox(task, "Will rewrite %s -> %s",
+                                 original_url, result.new_url)
+            elseif result.action == 'block' then
+              -- Redirect blocked URLs to warning page
+              replacements[original_url] = 'https://warning.example.com/blocked'
+              rspamd_logger.infox(task, "Blocking URL %s", original_url)
+
+              -- Optionally set a symbol
+              task:insert_result('BLOCKED_URL', 1.0, original_url)
+            end
+          end
+
+          -- Step 4: Apply rewrites using lookup table callback
+          if next(replacements) then
+            local rewritten = task:rewrite_html_urls(function(task, url)
+              -- Simple lookup - returns nil if URL shouldn't be rewritten
+              return replacements[url]
+            end)
+
+            if rewritten then
+              rspamd_logger.infox(task, 'Rewritten URLs in parts: %s',
+                                 table.concat(table_keys(rewritten), ', '))
+
+              -- Optionally set a symbol to track rewrites
+              task:insert_result('URL_REWRITTEN', 1.0,
+                                string.format('%d URLs', count_rewrites(replacements)))
+            end
+          end
+        end,
+
+        -- Request configuration
+        headers = {
+          ['Content-Type'] = 'application/json',
+          ['Authorization'] = 'Bearer YOUR_API_TOKEN'
+        },
+        body = ucl.to_format({
+          urls = all_urls,
+          -- Include additional context if needed
+          message_id = task:get_message_id(),
+          from = (task:get_from('smtp') or {})[1]
+        }, 'json'),
+        timeout = 5.0
+      })
+    end,
+    priority = 10 -- Postfilter priority
+  })
+end
+
+-- Helper functions
+local function table_keys(t)
+  local keys = {}
+  for k, _ in pairs(t) do
+    table.insert(keys, tostring(k))
+  end
+  return keys
+end
+
+local function count_rewrites(replacements)
+  local count = 0
+  for _, _ in pairs(replacements) do
+    count = count + 1
+  end
+  return count
+end
+
+--[[
+  Alternative: Using Redis for caching URL check results
+]]
+
+local function register_redis_cached_url_rewriter(rspamd_config)
+  rspamd_config:register_symbol({
+    name = 'REDIS_CACHED_URL_REWRITER',
+    type = 'postfilter',
+    callback = function(task)
+      local redis = require "rspamd_redis"
+      local urls_by_part = task:get_html_urls()
+
+      if not urls_by_part then
+        return
+      end
+
+      -- Collect all URLs
+      local all_urls = {}
+      for part_id, url_list in pairs(urls_by_part) do
+        for _, url_info in ipairs(url_list) do
+          table.insert(all_urls, url_info.url)
+        end
+      end
+
+      if #all_urls == 0 then
+        return
+      end
+
+      -- Build Redis MGET command to check all URLs at once
+      local redis_keys = {}
+      for _, url in ipairs(all_urls) do
+        table.insert(redis_keys, 'url:rewrite:' .. url)
+      end
+
+      redis.make_request({
+        task = task,
+        cmd = 'MGET',
+        args = redis_keys,
+        callback = function(err, data)
+          if err then
+            rspamd_logger.errx(task, 'Redis error: %s', err)
+            return
+          end
+
+          -- Build replacement map from Redis results
+          local replacements = {}
+          for i, url in ipairs(all_urls) do
+            if data[i] and data[i] ~= '' then
+              replacements[url] = data[i]
+            end
+          end
+
+          -- Apply rewrites
+          if next(replacements) then
+            local rewritten = task:rewrite_html_urls(function(task, url)
+              return replacements[url]
+            end)
+
+            if rewritten then
+              rspamd_logger.infox(task, 'Applied %d URL rewrites from Redis',
+                                 count_rewrites(replacements))
+            end
+          end
+        end
+      })
+    end
+  })
+end
+
+--[[
+  Simpler example: Rewrite specific domains without external service
+]]
+
+local function register_simple_domain_rewriter(rspamd_config)
+  -- Mapping of domains to redirect targets
+  local domain_redirects = {
+    ['evil.com'] = 'https://warning.example.com/blocked?domain=evil.com',
+    ['phishing.net'] = 'https://warning.example.com/blocked?domain=phishing.net',
+  }
+
+  rspamd_config:register_symbol({
+    name = 'SIMPLE_DOMAIN_REWRITER',
+    type = 'postfilter',
+    callback = function(task)
+      local urls_by_part = task:get_html_urls()
+
+      if not urls_by_part then
+        return
+      end
+
+      -- Check if any URLs match blocked domains
+      local needs_rewrite = false
+      for part_id, url_list in pairs(urls_by_part) do
+        for _, url_info in ipairs(url_list) do
+          for blocked_domain, _ in pairs(domain_redirects) do
+            if url_info.url:find(blocked_domain, 1, true) then
+              needs_rewrite = true
+              break
+            end
+          end
+        end
+      end
+
+      if not needs_rewrite then
+        return
+      end
+
+      -- Apply rewrites
+      local rewritten = task:rewrite_html_urls(function(task, url)
+        for blocked_domain, redirect_url in pairs(domain_redirects) do
+          if url:find(blocked_domain, 1, true) then
+            return redirect_url
+          end
+        end
+        return nil -- Don't rewrite
+      end)
+
+      if rewritten then
+        task:insert_result('DOMAIN_REWRITTEN', 1.0)
+      end
+    end
+  })
+end
+
+return {
+  register_async_url_rewriter = register_async_url_rewriter,
+  register_redis_cached_url_rewriter = register_redis_cached_url_rewriter,
+  register_simple_domain_rewriter = register_simple_domain_rewriter,
+}
index bada63a92d80d4bf835f76f000d5b35aa3b43d25..099f9d41ea8bd4196a673141bdb0edcd0999a945 100644 (file)
@@ -460,4 +460,41 @@ Content-Type: text/html
     task:destroy()
   end)
 
+  test("Edge case: bare cid: and data: schemes", function()
+    local msg = [[
+From: test@example.com
+To: nobody@example.com
+Subject: test
+Content-Type: text/html
+
+<html>
+<body>
+<img src="cid:">
+<img src="data:">
+<a href="http://example.com/test">Real link</a>
+</body>
+</html>
+]]
+    local res, task = rspamd_task.load_from_string(msg, rspamd_config)
+    assert_true(res, "failed to load message")
+
+    task:process_message()
+
+    local urls_seen = {}
+    local function rewrite_callback(task, url)
+      table.insert(urls_seen, url)
+      return "http://safe.com/redirect"
+    end
+
+    local result = task:rewrite_html_urls(rewrite_callback)
+
+    assert_not_nil(result, "should rewrite non-special scheme URLs")
+
+    -- Should only see the http URL, not bare cid: or data:
+    assert_equal(#urls_seen, 1, "should see exactly 1 URL (the http one)")
+    assert_equal(urls_seen[1], "http://example.com/test", "should see the http URL")
+
+    task:destroy()
+  end)
+
 end)