table.insert(custom_filters, filter_func)
end
+---
+-- Clear all custom filters (mainly for testing).
+-- Rebinds `custom_filters` to a fresh empty table, so filters previously
+-- appended with table.insert(custom_filters, ...) are no longer in the list.
+-- Takes no arguments and returns nothing.
+function exports.clear_filters()
+ custom_filters = {}
+end
+
---
-- Main entry point called from C during URL parsing
-- @param url_text rspamd_text - URL string as text object
return exports.REJECT -- Overly long URL
end
- -- Convert to string for pattern matching
- -- This is acceptable since we're called rarely (only on suspicious patterns)
- local url_str = url_text:str()
-
- -- Check for control characters (0x00-0x1F except tab/newline, and 0x7F)
- -- Using string.find with byte patterns
- for i = 0, 31 do
- if i ~= 9 and i ~= 10 then -- Allow tab (\t) and newline (\n)
- if url_str:find(string.char(i), 1, true) then
- return exports.REJECT -- Control character found
- end
- end
- end
- if url_str:find(string.char(127), 1, true) then -- DEL
- return exports.REJECT
+ -- Build control character set: 0x00-0x08, 0x0B-0x1F, 0x7F
+ -- (excluding \t=0x09 and \n=0x0A)
+ local control_chars = "\000\001\002\003\004\005\006\007\008" .. -- 0x00-0x08
+ "\011\012\013\014\015\016\017\018\019\020" .. -- 0x0B-0x14
+ "\021\022\023\024\025\026\027\028\029\030\031" .. -- 0x15-0x1F
+ "\127" -- 0x7F (DEL)
+
+ -- Check for control characters using memcspn
+ local span = url_text:memcspn(control_chars)
+ if span < url_len then
+ return exports.REJECT -- Control character found
end
- -- UTF-8 validation using rspamd_util
- if not rspamd_util.is_valid_utf8(url_str) then
+ -- UTF-8 validation (rspamd_util.is_valid_utf8 accepts both text and string)
+ if not rspamd_util.is_valid_utf8(url_text) then
return exports.REJECT -- Invalid UTF-8
end
- -- Count @ signs for suspicious patterns
- local _, at_count = url_str:gsub("@", "")
- if at_count > 20 then
- return exports.REJECT -- Way too many @ signs
+ -- Count @ signs and check user field using rspamd_text methods only
+ local at_count = 0
+ local first_at_pos = nil
+ local search_from = 1
+
+ -- Count @ signs using memchr
+ while search_from <= url_len do
+ local substr = url_text:sub(search_from)
+ local found = substr:memchr(string.byte('@'), false)
+
+ if not found or found == -1 then
+ break
+ end
+
+ at_count = at_count + 1
+ -- Adjust found position to be relative to start of url_text
+ local absolute_pos = search_from + found - 1
+ if at_count == 1 then
+ first_at_pos = absolute_pos
+ end
+ search_from = absolute_pos + 1 -- Move past the @ we just found
+
+ if at_count > 20 then
+ return exports.REJECT -- Way too many @ signs
+ end
end
-- Check user field length (if @ present)
- if at_count > 0 then
- -- Find first @
- local first_at = url_str:find("@", 1, true)
- if first_at then
- -- Check what comes before it (could be schema://user@host)
- -- Look for :// to find start of user field
- local schema_end = url_str:find("://", 1, true)
- local user_start = schema_end and (schema_end + 3) or 1
- local user_len = first_at - user_start
-
- if user_len > 512 then
- return exports.REJECT -- Extremely long user field
- elseif user_len > 128 then
- return exports.SUSPICIOUS -- Long user field, mark for inspection
- end
+ if first_at_pos then
+ -- Find :// to determine start of user field
+ local schema_pos = url_text:find("://")
+ local user_start = schema_pos and (schema_pos + 3) or 1
+ local user_len = first_at_pos - user_start
+
+ if user_len > 512 then
+ return exports.REJECT -- Extremely long user field
+ elseif user_len > 64 then
+ return exports.SUSPICIOUS -- Long user field, mark for inspection
end
-- Multiple @ signs is suspicious
return exports.ACCEPT
end
- -- Get URL as text
- local url_text = url:get_text()
+ -- Get URL as rspamd_text (pass true to get_text)
+ local url_text = url:get_text(true)
if not url_text then
return exports.ACCEPT
end
- -- Get flags from URL object
- local flags = 0
- local url_table = url:to_table()
- if url_table and url_table.flags then
- flags = url_table.flags
- end
+ -- Get flags directly from URL object (no table conversion)
+ local flags = url:get_flags_num() or 0
return exports.filter_url_string(url_text, flags)
end
if (*p == '<' && len > 1) {
/* Angled address */
- addr->addr_len = rspamd_memcspn(p + 1, ">", len - 1);
+ addr->addr_len = rspamd_memcspn(p + 1, len - 1, ">", 1);
addr->addr = p + 1;
addr->raw = p;
addr->raw_len = len;
switch (state) {
case comp_state_read_symbol:
- clen = rspamd_memcspn(p, "[; \t()><!|&\n", len);
+ clen = rspamd_memcspn(p, len, "[; \t()><!|&\n", 12);
p += clen;
if (*p == '[') {
switch (state) {
case comp_state_read_symbol: {
- clen = rspamd_memcspn(p, "[; \t()><!|&\n", len);
+ clen = rspamd_memcspn(p, len, "[; \t()><!|&\n", 12);
p += clen;
if (*p == '[') {
iov_ctx->iov[0].iov_base = tmpbuf;
iov_ctx->iov[0].iov_len = r;
/* TODO: is it possible to have other 'bad' symbols here? */
- if (rspamd_memcspn(message, "\"\\\r\n\b\t\v", mlen) == mlen) {
+ /* Take the reject-set length from the literal itself: it holds 7 bytes
+  * (", \, \r, \n, \b, \t, \v); a hand-counted 6 silently dropped \v. */
+ if (rspamd_memcspn(message, mlen, "\"\\\r\n\b\t\v", sizeof("\"\\\r\n\b\t\v") - 1) == mlen) {
iov_ctx->iov[1].iov_base = (void *) message;
iov_ctx->iov[1].iov_len = mlen;
}
if (log_json) {
long now = rspamd_get_calendar_ticks();
- if (rspamd_memcspn(message, "\"\\\r\n\b\t\v", mlen) == mlen) {
+ /* Take the reject-set length from the literal itself: it holds 7 bytes
+  * (", \, \r, \n, \b, \t, \v); a hand-counted 6 silently dropped \v. */
+ if (rspamd_memcspn(message, mlen, "\"\\\r\n\b\t\v", sizeof("\"\\\r\n\b\t\v") - 1) == mlen) {
/* Fast path */
syslog(syslog_level, "{\"ts\": %ld, "
"\"pid\": %d, "
p += sizeof("begin ") - 1;
remain -= sizeof("begin ") - 1;
- pos = rspamd_memcspn(p, nline, remain);
+ pos = rspamd_memcspn(p, remain, nline, strlen(nline));
}
else if (memcmp(p, "begin-base64 ", sizeof("begin-base64 ") - 1) == 0) {
base64 = TRUE;
p += sizeof("begin-base64 ") - 1;
remain -= sizeof("begin-base64 ") - 1;
- pos = rspamd_memcspn(p, nline, remain);
+ pos = rspamd_memcspn(p, remain, nline, strlen(nline));
}
else {
/* Crap */
const char *eol;
int i, ch;
- pos = rspamd_memcspn(p, nline, remain);
+ pos = rspamd_memcspn(p, remain, nline, strlen(nline));
if (pos == 0) {
/* Skip empty lines */
((a)[(gsize) (b) / (8 * sizeof *(a))] op(gsize) 1 << ((gsize) (b) % (8 * sizeof *(a))))
-gsize rspamd_memcspn(const char *s, const char *e, gsize len)
+/*
+ * Length of the leading run of `data` that contains no byte from `reject`.
+ * Both buffers are length-delimited, so the reject set may contain NUL bytes.
+ * Returns `dlen` when no rejected byte occurs (including when `rlen` is 0).
+ */
+gsize rspamd_memcspn(const void *data, gsize dlen, const void *reject, gsize rlen)
{
+ /* 32 bytes of storage = 256 bits: one membership bit per possible byte value */
gsize byteset[32 / sizeof(gsize)];
- const char *p = s, *end = s + len;
+ const unsigned char *s = (const unsigned char *) data;
+ const unsigned char *r = (const unsigned char *) reject;
+ const unsigned char *p = s, *end = s + dlen;
- if (!e[1]) {
- for (; p < end && *p != *e; p++);
- return p - s;
- }
+ memset(byteset, 0, sizeof(byteset));
- memset(byteset, 0, sizeof byteset);
+ /* Build bitset from reject set */
+ for (gsize i = 0; i < rlen; i++) {
+ BITOP(byteset, r[i], |=);
+ }
- for (; *e && BITOP(byteset, *(unsigned char *) e, |=); e++);
- for (; p < end && !BITOP(byteset, *(unsigned char *) p, &); p++);
+ /* Scan for first character in reject set */
+ while (p < end && !BITOP(byteset, *p, &)) {
+ p++;
+ }
+ /* p stopped either on the first rejected byte or at end */
return p - s;
}
}
else {
if (end - o >= remain) {
- processed = rspamd_memcspn(p, "=_", remain);
+ processed = rspamd_memcspn(p, remain, "=_", 2);
memcpy(o, p, processed);
o += processed;
char **res;
/* Detect number of elements */
+ gsize spill_len = strlen(spill);
while (p < end) {
- gsize cur_fragment = rspamd_memcspn(p, spill, end - p);
+ gsize cur_fragment = rspamd_memcspn(p, end - p, spill, spill_len);
if (cur_fragment > 0) {
detected_elts++;
p = in;
while (p < end) {
- gsize cur_fragment = rspamd_memcspn(p, spill, end - p);
+ gsize cur_fragment = rspamd_memcspn(p, end - p, spill, spill_len);
if (cur_fragment > 0) {
char *elt;
#endif
/**
- * Return length of memory segment starting in `s` that contains no chars from `e`
- * @param s any input
- * @param e zero terminated string of exceptions
- * @param len length of `s`
- * @return segment size
- */
-gsize rspamd_memcspn(const char *s, const char *e, gsize len);
+ * Return length of memory segment starting in `data` that contains no bytes from `reject`
+ * @param data input data
+ * @param dlen length of data
+ * @param reject set of bytes to reject (can contain binary data including nulls)
+ * @param rlen length of reject set
+ * @return length of initial segment with no rejected bytes
+ */
+gsize rspamd_memcspn(const void *data, gsize dlen, const void *reject, gsize rlen);
/**
* Return length of memory segment starting in `s` that contains only chars from `e`
}
while (p < end) {
- span_len = rspamd_memcspn(p, separators, end - p);
+ span_len = rspamd_memcspn(p, end - p, separators, strlen(separators));
if (span_len > 0) {
tmp = g_malloc(span_len + 1);
* @return {rspamd_text} modified or copied text
*/
LUA_FUNCTION_DEF(text, exclude_chars);
+/***
+ * @method rspamd_text:memcspn(reject_chars)
+ * Returns the length of the initial segment of text that consists entirely
+ * of characters NOT in reject_chars (like C's strcspn, but for binary data)
+ * @param {string} reject_chars set of characters to reject
+ * @return {integer} length of the initial segment
+ */
+LUA_FUNCTION_DEF(text, memcspn);
/***
* @method rspamd_text:oneline([always_copy])
* Returns a text (if owned, then the original text is modified, if not, then it is copied and owned)
LUA_INTERFACE_DEF(text, bytes),
LUA_INTERFACE_DEF(text, lower),
LUA_INTERFACE_DEF(text, exclude_chars),
+ LUA_INTERFACE_DEF(text, memcspn),
LUA_INTERFACE_DEF(text, oneline),
LUA_INTERFACE_DEF(text, base32),
LUA_INTERFACE_DEF(text, base64),
return 1;
}
+/*
+ * Lua binding for rspamd_text:memcspn(reject_chars).
+ * Stack: 1 = rspamd_text, 2 = string holding the set of bytes to reject.
+ * Pushes one integer: the length of the leading segment of the text that
+ * contains none of the reject bytes (the full text length if the set is empty).
+ * Raises a Lua error when argument 1 is not a text or argument 2 is not a string.
+ */
+static int
+lua_text_memcspn(lua_State *L)
+{
+    LUA_TRACE_POINT;
+    struct rspamd_lua_text *t = lua_check_text(L, 1);
+    const char *reject_chars;
+    gsize reject_len, span_len;
+
+    if (t == NULL) {
+        return luaL_error(L, "invalid arguments");
+    }
+
+    /* luaL_checklstring raises a Lua error on a bad argument and never
+     * returns NULL, so only the empty-set case needs handling here. */
+    reject_chars = luaL_checklstring(L, 2, &reject_len);
+    if (reject_len == 0) {
+        /* No reject chars - return full length */
+        lua_pushinteger(L, t->len);
+        return 1;
+    }
+
+    /* Use rspamd_memcspn from str_util */
+    span_len = rspamd_memcspn(t->start, t->len, reject_chars, reject_len);
+    lua_pushinteger(L, span_len);
+
+    return 1;
+}
+
+
static int
lua_text_oneline(lua_State *L)
{
}
/***
- * @method url:get_text()
+ * @method url:get_text([as_text])
* Get full content of the url
- * @return {string} url string
+ * @param {boolean} as_text if true, return as rspamd_text, otherwise as string
+ * @return {string|rspamd_text} url string or text object
*/
static int
lua_url_get_text(lua_State *L)
{
LUA_TRACE_POINT;
struct rspamd_lua_url *url = lua_check_url(L, 1);
+ gboolean as_text = FALSE;
+
+ if (lua_isboolean(L, 2)) {
+ as_text = lua_toboolean(L, 2);
+ }
if (url != NULL) {
- lua_pushlstring(L, url->url->string, url->url->urllen);
+ if (as_text) {
+ struct rspamd_lua_text *t;
+ t = lua_newuserdata(L, sizeof(*t));
+ rspamd_lua_setclass(L, rspamd_text_classname, -1);
+ t->start = url->url->string;
+ t->len = url->url->urllen;
+ t->flags = 0; /* Read-only, not owned */
+ }
+ else {
+ lua_pushlstring(L, url->url->string, url->url->urllen);
+ }
}
else {
lua_pushnil(L);
local mpool = require("rspamd_mempool")
local test_helper = require("rspamd_test_helper")
local logger = require("rspamd_logger")
+ local rspamd_text = require("rspamd_text")
test_helper.init_url_parser()
for i, c in ipairs(filter_cases) do
test("filter_url_string: " .. c[4], function()
- local result = lua_url_filter.filter_url_string(c[1], c[2])
+ local url_text = rspamd_text.fromstring(c[1])
+ local result = lua_url_filter.filter_url_string(url_text, c[2])
assert_equal(c[3], result,
logger.slog('expected result %s, but got %s for "%s"',
c[3], result, c[4]))
for i, c in ipairs(utf8_cases) do
test("UTF-8 validation: " .. c[3], function()
- local result = lua_url_filter.filter_url_string(c[1], 0)
+ local url_text = rspamd_text.fromstring(c[1])
+ local result = lua_url_filter.filter_url_string(url_text, 0)
assert_equal(c[2], result,
logger.slog('expected result %s, but got %s for "%s"',
c[2], result, c[3]))
-- Test custom filter registration
test("register custom filter", function()
+ lua_url_filter.clear_filters() -- Clear any previously registered filters
+
local called = false
- local custom_filter = function(url_str, flags)
+ local custom_filter = function(url_text, flags)
called = true
- if url_str:match("blocked") then
+ -- Custom filters receive rspamd_text, use :find instead of :match
+ if url_text:find("blocked") then
return REJECT
end
return ACCEPT
lua_url_filter.register_filter(custom_filter)
- local result = lua_url_filter.filter_url_string("http://blocked.example.com", 0)
+ local url_text = rspamd_text.fromstring("http://blocked.example.com")
+ local result = lua_url_filter.filter_url_string(url_text, 0)
assert_true(called, "custom filter was not called")
assert_equal(REJECT, result, "custom filter did not reject")
+
+ lua_url_filter.clear_filters() -- Clean up after test
end)
-- Test filter chaining
test("filter chaining stops on REJECT", function()
+ lua_url_filter.clear_filters() -- Clear any previously registered filters
+
local filter1_called = false
local filter2_called = false
return ACCEPT
end)
- lua_url_filter.filter_url_string("http://example.com", 0)
+ local url_text = rspamd_text.fromstring("http://example.com")
+ lua_url_filter.filter_url_string(url_text, 0)
assert_true(filter1_called, "first filter not called")
assert_false(filter2_called, "second filter called despite REJECT")
+
+ lua_url_filter.clear_filters() -- Clean up after test
end)
-- Test oversized user field (issue #5731)
local long_user = string.rep("a", 80)
local url_str = "http://" .. long_user .. ":password@example.com/path"
- local result = lua_url_filter.filter_url_string(url_str, 0)
+ local url_text = rspamd_text.fromstring(url_str)
+ local result = lua_url_filter.filter_url_string(url_text, 0)
-- Should be SUSPICIOUS, not REJECT, allowing C parser to continue
assert_equal(SUSPICIOUS, result,