]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Fix strings processing
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 15 Nov 2025 13:54:49 +0000 (13:54 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 15 Nov 2025 13:54:49 +0000 (13:54 +0000)
lualib/lua_url_filter.lua
src/libmime/email_addr.c
src/libserver/composites/composites.cxx
src/libserver/logger/logger.c
src/libserver/logger/logger_syslog.c
src/libutil/str_util.c
src/libutil/str_util.h
src/libutil/upstream.c
src/lua/lua_text.c
src/lua/lua_url.c
test/lua/unit/lua_url_filter.lua

index 7752d816589d3108bb5ca5cef954682080be3cde..1e8b0ba2afcf3876e474e88fa482b165f8dd193e 100644 (file)
@@ -36,6 +36,12 @@ function exports.register_filter(filter_func)
   table.insert(custom_filters, filter_func)
 end
 
+---
+-- Clear all custom filters (mainly for testing)
+function exports.clear_filters()
+  custom_filters = {}
+end
+
 ---
 -- Main entry point called from C during URL parsing
 -- @param url_text rspamd_text - URL string as text object
@@ -48,50 +54,62 @@ function exports.filter_url_string(url_text, flags)
     return exports.REJECT -- Overly long URL
   end
 
-  -- Convert to string for pattern matching
-  -- This is acceptable since we're called rarely (only on suspicious patterns)
-  local url_str = url_text:str()
-
-  -- Check for control characters (0x00-0x1F except tab/newline, and 0x7F)
-  -- Using string.find with byte patterns
-  for i = 0, 31 do
-    if i ~= 9 and i ~= 10 then -- Allow tab (\t) and newline (\n)
-      if url_str:find(string.char(i), 1, true) then
-        return exports.REJECT -- Control character found
-      end
-    end
-  end
-  if url_str:find(string.char(127), 1, true) then -- DEL
-    return exports.REJECT
+  -- Build control character set: 0x00-0x08, 0x0B-0x1F, 0x7F
+  -- (excluding \t=0x09 and \n=0x0A)
+  local control_chars = "\000\001\002\003\004\005\006\007\008" .. -- 0x00-0x08
+      "\011\012\013\014\015\016\017\018\019\020" .. -- 0x0B-0x14
+      "\021\022\023\024\025\026\027\028\029\030\031" .. -- 0x15-0x1F
+      "\127" -- 0x7F (DEL)
+
+  -- Check for control characters using memcspn
+  local span = url_text:memcspn(control_chars)
+  if span < url_len then
+    return exports.REJECT -- Control character found
   end
 
-  -- UTF-8 validation using rspamd_util
-  if not rspamd_util.is_valid_utf8(url_str) then
+  -- UTF-8 validation (rspamd_util.is_valid_utf8 accepts both text and string)
+  if not rspamd_util.is_valid_utf8(url_text) then
     return exports.REJECT -- Invalid UTF-8
   end
 
-  -- Count @ signs for suspicious patterns
-  local _, at_count = url_str:gsub("@", "")
-  if at_count > 20 then
-    return exports.REJECT -- Way too many @ signs
+  -- Count @ signs and check user field using rspamd_text methods only
+  local at_count = 0
+  local first_at_pos = nil
+  local search_from = 1
+
+  -- Count @ signs using memchr
+  while search_from <= url_len do
+    local substr = url_text:sub(search_from)
+    local found = substr:memchr(string.byte('@'), false)
+
+    if not found or found == -1 then
+      break
+    end
+
+    at_count = at_count + 1
+    -- Adjust found position to be relative to start of url_text
+    local absolute_pos = search_from + found - 1
+    if at_count == 1 then
+      first_at_pos = absolute_pos
+    end
+    search_from = absolute_pos + 1 -- Move past the @ we just found
+
+    if at_count > 20 then
+      return exports.REJECT -- Way too many @ signs
+    end
   end
 
   -- Check user field length (if @ present)
-  if at_count > 0 then
-    -- Find first @
-    local first_at = url_str:find("@", 1, true)
-    if first_at then
-      -- Check what comes before it (could be schema://user@host)
-      -- Look for :// to find start of user field
-      local schema_end = url_str:find("://", 1, true)
-      local user_start = schema_end and (schema_end + 3) or 1
-      local user_len = first_at - user_start
-
-      if user_len > 512 then
-        return exports.REJECT -- Extremely long user field
-      elseif user_len > 128 then
-        return exports.SUSPICIOUS -- Long user field, mark for inspection
-      end
+  if first_at_pos then
+    -- Find :// to determine start of user field
+    local schema_pos = url_text:find("://")
+    local user_start = schema_pos and (schema_pos + 3) or 1
+    local user_len = first_at_pos - user_start
+
+    if user_len > 512 then
+      return exports.REJECT -- Extremely long user field
+    elseif user_len > 64 then
+      return exports.SUSPICIOUS -- Long user field, mark for inspection
     end
 
     -- Multiple @ signs is suspicious
@@ -122,18 +140,14 @@ function exports.filter_url(url)
     return exports.ACCEPT
   end
 
-  -- Get URL as text
-  local url_text = url:get_text()
+  -- Get URL as rspamd_text (pass true to get_text)
+  local url_text = url:get_text(true)
   if not url_text then
     return exports.ACCEPT
   end
 
-  -- Get flags from URL object
-  local flags = 0
-  local url_table = url:to_table()
-  if url_table and url_table.flags then
-    flags = url_table.flags
-  end
+  -- Get flags directly from URL object (no table conversion)
+  local flags = url:get_flags_num() or 0
 
   return exports.filter_url_string(url_text, flags)
 end
index da145305079fee09b3c0cf8ef1ed7759b6f26c36..c3af21f9f2b8b65ff2a3e240793fc50c535a406f 100644 (file)
@@ -162,7 +162,7 @@ rspamd_email_address_parse_heuristic(const char *data, size_t len,
 
        if (*p == '<' && len > 1) {
                /* Angled address */
-               addr->addr_len = rspamd_memcspn(p + 1, ">", len - 1);
+               addr->addr_len = rspamd_memcspn(p + 1, len - 1, ">", 1);
                addr->addr = p + 1;
                addr->raw = p;
                addr->raw_len = len;
index c9e11649aa30c606a2c2c45b41a7e80bee9e3db0..6e7e435a4e4f5a9d487a630671e4eec1f5ff6b1c 100644 (file)
@@ -265,7 +265,7 @@ rspamd_composite_expr_parse(const char *line, gsize len,
 
                switch (state) {
                case comp_state_read_symbol:
-                       clen = rspamd_memcspn(p, "[; \t()><!|&\n", len);
+                       clen = rspamd_memcspn(p, len, "[; \t()><!|&\n", 12);
                        p += clen;
 
                        if (*p == '[') {
@@ -362,7 +362,7 @@ rspamd_composite_expr_parse(const char *line, gsize len,
 
                switch (state) {
                case comp_state_read_symbol: {
-                       clen = rspamd_memcspn(p, "[; \t()><!|&\n", len);
+                       clen = rspamd_memcspn(p, len, "[; \t()><!|&\n", 12);
                        p += clen;
 
                        if (*p == '[') {
index 600b7f1e15251f69a24d1d8b37c153d68862ce4e..37741743cb69d57bb196ef6d28916e56ff8d59f5 100644 (file)
@@ -1198,7 +1198,7 @@ void rspamd_log_fill_iov(struct rspamd_logger_iov_ctx *iov_ctx,
                iov_ctx->iov[0].iov_base = tmpbuf;
                iov_ctx->iov[0].iov_len = r;
                /* TODO: is it possible to have other 'bad' symbols here? */
-               if (rspamd_memcspn(message, "\"\\\r\n\b\t\v", mlen) == mlen) {
+               if (rspamd_memcspn(message, mlen, "\"\\\r\n\b\t\v", 6) == mlen) {
                        iov_ctx->iov[1].iov_base = (void *) message;
                        iov_ctx->iov[1].iov_len = mlen;
                }
index ba46df0859eb30e6fe11f1d7119cd1d4208ece62..a4fe04c940e1eac77cd732e3a6c6d65d9d959b4f 100644 (file)
@@ -102,7 +102,7 @@ bool rspamd_log_syslog_log(const char *module, const char *id,
 
        if (log_json) {
                long now = rspamd_get_calendar_ticks();
-               if (rspamd_memcspn(message, "\"\\\r\n\b\t\v", mlen) == mlen) {
+               if (rspamd_memcspn(message, mlen, "\"\\\r\n\b\t\v", 6) == mlen) {
                        /* Fast path */
                        syslog(syslog_level, "{\"ts\": %ld, "
                                                                 "\"pid\": %d, "
index 040298bedaee0ab6ca3c0166ddf4626473146da3..eaaeaccfe060c49a0e3646a6d1dbbf0724684edb 100644 (file)
@@ -2810,13 +2810,13 @@ rspamd_decode_uue_buf(const char *in, gsize inlen,
                p += sizeof("begin ") - 1;
                remain -= sizeof("begin ") - 1;
 
-               pos = rspamd_memcspn(p, nline, remain);
+               pos = rspamd_memcspn(p, remain, nline, strlen(nline));
        }
        else if (memcmp(p, "begin-base64 ", sizeof("begin-base64 ") - 1) == 0) {
                base64 = TRUE;
                p += sizeof("begin-base64 ") - 1;
                remain -= sizeof("begin-base64 ") - 1;
-               pos = rspamd_memcspn(p, nline, remain);
+               pos = rspamd_memcspn(p, remain, nline, strlen(nline));
        }
        else {
                /* Crap */
@@ -2857,7 +2857,7 @@ rspamd_decode_uue_buf(const char *in, gsize inlen,
                const char *eol;
                int i, ch;
 
-               pos = rspamd_memcspn(p, nline, remain);
+               pos = rspamd_memcspn(p, remain, nline, strlen(nline));
 
                if (pos == 0) {
                        /* Skip empty lines */
@@ -2936,20 +2936,24 @@ rspamd_decode_uue_buf(const char *in, gsize inlen,
        ((a)[(gsize) (b) / (8 * sizeof *(a))] op(gsize) 1 << ((gsize) (b) % (8 * sizeof *(a))))
 
 
-gsize rspamd_memcspn(const char *s, const char *e, gsize len)
+gsize rspamd_memcspn(const void *data, gsize dlen, const void *reject, gsize rlen)
 {
        gsize byteset[32 / sizeof(gsize)];
-       const char *p = s, *end = s + len;
+       const unsigned char *s = (const unsigned char *) data;
+       const unsigned char *r = (const unsigned char *) reject;
+       const unsigned char *p = s, *end = s + dlen;
 
-       if (!e[1]) {
-               for (; p < end && *p != *e; p++);
-               return p - s;
-       }
+       memset(byteset, 0, sizeof(byteset));
 
-       memset(byteset, 0, sizeof byteset);
+       /* Build bitset from reject set */
+       for (gsize i = 0; i < rlen; i++) {
+               BITOP(byteset, r[i], |=);
+       }
 
-       for (; *e && BITOP(byteset, *(unsigned char *) e, |=); e++);
-       for (; p < end && !BITOP(byteset, *(unsigned char *) p, &); p++);
+       /* Scan for first character in reject set */
+       while (p < end && !BITOP(byteset, *p, &)) {
+               p++;
+       }
 
        return p - s;
 }
@@ -3044,7 +3048,7 @@ rspamd_decode_qp2047_buf(const char *in, gsize inlen,
                }
                else {
                        if (end - o >= remain) {
-                               processed = rspamd_memcspn(p, "=_", remain);
+                               processed = rspamd_memcspn(p, remain, "=_", 2);
                                memcpy(o, p, processed);
                                o += processed;
 
@@ -3764,8 +3768,9 @@ rspamd_string_len_split(const char *in, gsize len, const char *spill,
        char **res;
 
        /* Detect number of elements */
+       gsize spill_len = strlen(spill);
        while (p < end) {
-               gsize cur_fragment = rspamd_memcspn(p, spill, end - p);
+               gsize cur_fragment = rspamd_memcspn(p, end - p, spill, spill_len);
 
                if (cur_fragment > 0) {
                        detected_elts++;
@@ -3787,7 +3792,7 @@ rspamd_string_len_split(const char *in, gsize len, const char *spill,
        p = in;
 
        while (p < end) {
-               gsize cur_fragment = rspamd_memcspn(p, spill, end - p);
+               gsize cur_fragment = rspamd_memcspn(p, end - p, spill, spill_len);
 
                if (cur_fragment > 0) {
                        char *elt;
index 5d59785f36af34f2a775542da1d722800a76daae..e0324a126455e464dadbf0ab8b3866f5ceaf12b3 100644 (file)
@@ -451,13 +451,14 @@ void *rspamd_memrchr(const void *m, int c, gsize len);
 #endif
 
 /**
- * Return length of memory segment starting in `s` that contains no chars from `e`
- * @param s any input
- * @param e zero terminated string of exceptions
- * @param len length of `s`
- * @return segment size
- */
-gsize rspamd_memcspn(const char *s, const char *e, gsize len);
+ * Return length of memory segment starting in `data` that contains no bytes from `reject`
+ * @param data input data
+ * @param dlen length of data
+ * @param reject set of bytes to reject (can contain binary data including nulls)
+ * @param rlen length of reject set
+ * @return length of initial segment with no rejected bytes
+ */
+gsize rspamd_memcspn(const void *data, gsize dlen, const void *reject, gsize rlen);
 
 /**
  * Return length of memory segment starting in `s` that contains only chars from `e`
index acbc05736efc158c349efda276b158cad1c9629d..04d645a9c8aac4c31a1b8eab4c351efe33104095 100644 (file)
@@ -1526,7 +1526,7 @@ rspamd_upstreams_parse_line_len(struct upstream_list *ups,
        }
 
        while (p < end) {
-               span_len = rspamd_memcspn(p, separators, end - p);
+               span_len = rspamd_memcspn(p, end - p, separators, strlen(separators));
 
                if (span_len > 0) {
                        tmp = g_malloc(span_len + 1);
index b45ee1743da571cb80839b723e884829fcc1d1f8..4a62bdfe9874fe90986e33fa4cd17d755adb04ea 100644 (file)
@@ -180,6 +180,14 @@ LUA_FUNCTION_DEF(text, take_ownership);
  * @return {rspamd_text} modified or copied text
  */
 LUA_FUNCTION_DEF(text, exclude_chars);
+/***
+ * @method rspamd_text:memcspn(reject_chars)
+ * Returns the length of the initial segment of text that consists entirely
+ * of characters NOT in reject_chars (like C's strcspn, but for binary data)
+ * @param {string} reject_chars set of characters to reject
+ * @return {integer} length of the initial segment
+ */
+LUA_FUNCTION_DEF(text, memcspn);
 /***
  * @method rspamd_text:oneline([always_copy])
  * Returns a text (if owned, then the original text is modified, if not, then it is copied and owned)
@@ -266,6 +274,7 @@ static const struct luaL_reg textlib_m[] = {
        LUA_INTERFACE_DEF(text, bytes),
        LUA_INTERFACE_DEF(text, lower),
        LUA_INTERFACE_DEF(text, exclude_chars),
+       LUA_INTERFACE_DEF(text, memcspn),
        LUA_INTERFACE_DEF(text, oneline),
        LUA_INTERFACE_DEF(text, base32),
        LUA_INTERFACE_DEF(text, base64),
@@ -1575,6 +1584,32 @@ lua_text_exclude_chars(lua_State *L)
        return 1;
 }
 
+static int
+lua_text_memcspn(lua_State *L)
+{
+       LUA_TRACE_POINT;
+       struct rspamd_lua_text *t = lua_check_text(L, 1);
+       const char *reject_chars;
+       gsize reject_len, span_len;
+
+       if (t == NULL) {
+               return luaL_error(L, "invalid arguments");
+       }
+
+       reject_chars = luaL_checklstring(L, 2, &reject_len);
+       if (reject_chars == NULL || reject_len == 0) {
+               /* No reject chars - return full length */
+               lua_pushinteger(L, t->len);
+               return 1;
+       }
+
+       /* Use rspamd_memcspn from str_util */
+       span_len = rspamd_memcspn(t->start, t->len, reject_chars, reject_len);
+       lua_pushinteger(L, span_len);
+
+       return 1;
+}
+
 static int
 lua_text_oneline(lua_State *L)
 {
index b6c42b0c7b94b43876b3fd3a30fd6c6e8712167f..f23a6323734dde6be9a824775edd60d9fef20c44 100644 (file)
@@ -290,18 +290,34 @@ lua_url_get_fragment(lua_State *L)
 }
 
 /***
- * @method url:get_text()
+ * @method url:get_text([as_text])
  * Get full content of the url
- * @return {string} url string
+ * @param {boolean} as_text if true, return as rspamd_text, otherwise as string
+ * @return {string|rspamd_text} url string or text object
  */
 static int
 lua_url_get_text(lua_State *L)
 {
        LUA_TRACE_POINT;
        struct rspamd_lua_url *url = lua_check_url(L, 1);
+       gboolean as_text = FALSE;
+
+       if (lua_isboolean(L, 2)) {
+               as_text = lua_toboolean(L, 2);
+       }
 
        if (url != NULL) {
-               lua_pushlstring(L, url->url->string, url->url->urllen);
+               if (as_text) {
+                       struct rspamd_lua_text *t;
+                       t = lua_newuserdata(L, sizeof(*t));
+                       rspamd_lua_setclass(L, rspamd_text_classname, -1);
+                       t->start = url->url->string;
+                       t->len = url->url->urllen;
+                       t->flags = 0; /* Read-only, not owned */
+               }
+               else {
+                       lua_pushlstring(L, url->url->string, url->url->urllen);
+               }
        }
        else {
                lua_pushnil(L);
index b8a646e808e7812b87a0a1374fab68c3df83f964..d56ad81ed0c7a19a0501fe413a93e00167d60adb 100644 (file)
@@ -6,6 +6,7 @@ context("URL filter functions", function()
   local mpool = require("rspamd_mempool")
   local test_helper = require("rspamd_test_helper")
   local logger = require("rspamd_logger")
+  local rspamd_text = require("rspamd_text")
 
   test_helper.init_url_parser()
 
@@ -42,7 +43,8 @@ context("URL filter functions", function()
 
   for i, c in ipairs(filter_cases) do
     test("filter_url_string: " .. c[4], function()
-      local result = lua_url_filter.filter_url_string(c[1], c[2])
+      local url_text = rspamd_text.fromstring(c[1])
+      local result = lua_url_filter.filter_url_string(url_text, c[2])
       assert_equal(c[3], result,
           logger.slog('expected result %s, but got %s for "%s"',
               c[3], result, c[4]))
@@ -77,7 +79,8 @@ context("URL filter functions", function()
 
   for i, c in ipairs(utf8_cases) do
     test("UTF-8 validation: " .. c[3], function()
-      local result = lua_url_filter.filter_url_string(c[1], 0)
+      local url_text = rspamd_text.fromstring(c[1])
+      local result = lua_url_filter.filter_url_string(url_text, 0)
       assert_equal(c[2], result,
           logger.slog('expected result %s, but got %s for "%s"',
               c[2], result, c[3]))
@@ -86,10 +89,13 @@ context("URL filter functions", function()
 
   -- Test custom filter registration
   test("register custom filter", function()
+    lua_url_filter.clear_filters() -- Clear any previously registered filters
+
     local called = false
-    local custom_filter = function(url_str, flags)
+    local custom_filter = function(url_text, flags)
       called = true
-      if url_str:match("blocked") then
+      -- Custom filters receive rspamd_text, use :find instead of :match
+      if url_text:find("blocked") then
         return REJECT
       end
       return ACCEPT
@@ -97,13 +103,18 @@ context("URL filter functions", function()
 
     lua_url_filter.register_filter(custom_filter)
 
-    local result = lua_url_filter.filter_url_string("http://blocked.example.com", 0)
+    local url_text = rspamd_text.fromstring("http://blocked.example.com")
+    local result = lua_url_filter.filter_url_string(url_text, 0)
     assert_true(called, "custom filter was not called")
     assert_equal(REJECT, result, "custom filter did not reject")
+
+    lua_url_filter.clear_filters() -- Clean up after test
   end)
 
   -- Test filter chaining
   test("filter chaining stops on REJECT", function()
+    lua_url_filter.clear_filters() -- Clear any previously registered filters
+
     local filter1_called = false
     local filter2_called = false
 
@@ -117,10 +128,13 @@ context("URL filter functions", function()
       return ACCEPT
     end)
 
-    lua_url_filter.filter_url_string("http://example.com", 0)
+    local url_text = rspamd_text.fromstring("http://example.com")
+    lua_url_filter.filter_url_string(url_text, 0)
 
     assert_true(filter1_called, "first filter not called")
     assert_false(filter2_called, "second filter called despite REJECT")
+
+    lua_url_filter.clear_filters() -- Clean up after test
   end)
 
   -- Test oversized user field (issue #5731)
@@ -128,7 +142,8 @@ context("URL filter functions", function()
     local long_user = string.rep("a", 80)
     local url_str = "http://" .. long_user .. ":password@example.com/path"
 
-    local result = lua_url_filter.filter_url_string(url_str, 0)
+    local url_text = rspamd_text.fromstring(url_str)
+    local result = lua_url_filter.filter_url_string(url_text, 0)
 
     -- Should be SUSPICIOUS, not REJECT, allowing C parser to continue
     assert_equal(SUSPICIOUS, result,