]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] lua_feedback_parsers: reuse lua_util helpers
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 13 Apr 2026 12:27:55 +0000 (13:27 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 13 Apr 2026 12:27:55 +0000 (13:27 +0100)
Replace local trim/split_lines/normalize_eol with lua_util.str_trim and
lua_util.rspamd_str_split. Drop the safe_get_type_full/safe_get_content
pcall wrappers since the mime_part C API does not throw, and remove the
internal _-prefixed exports that nothing consumes.

lualib/lua_feedback_parsers.lua

index 948a50a21b73c291e3579515aa26737a5403abd7..e11f2a0ddf1c35053bac05a7e73ea9f0981158e3 100644 (file)
@@ -30,23 +30,14 @@ limitations under the License.
 --]]
 
 local rspamd_logger = require 'rspamd_logger'
+local lua_util = require 'lua_util'
 
 local N = 'lua_feedback_parsers'
+local str_trim = lua_util.str_trim
+local str_split = lua_util.rspamd_str_split
 
 local exports = {}
 
--- ----------------------------------------------------------------------------
--- Generic helpers
--- ----------------------------------------------------------------------------
-
--- Trim ASCII whitespace from both ends of a string.
-local function trim(s)
-  if not s then
-    return nil
-  end
-  return (s:gsub('^%s+', ''):gsub('%s+$', ''))
-end
-
 -- Strip a single pair of outermost angle brackets, e.g. `<id@example>`.
 local function strip_angles(s)
   if not s then
@@ -56,36 +47,7 @@ local function strip_angles(s)
   if inner then
     return inner
   end
-  return trim(s)
-end
-
--- Normalise line endings: drop CR, keep LF.
-local function normalize_eol(body)
-  if type(body) ~= 'string' then
-    body = tostring(body or '')
-  end
-  return (body:gsub('\r', ''))
-end
-
--- Split a string on LF into a plain array of lines (no trailing empty line
--- duplication). Preserves empty lines in the middle.
-local function split_lines(body)
-  local lines = {}
-  local i = 1
-  local len = #body
-  local start = 1
-  while i <= len do
-    local c = body:sub(i, i)
-    if c == '\n' then
-      lines[#lines + 1] = body:sub(start, i - 1)
-      start = i + 1
-    end
-    i = i + 1
-  end
-  if start <= len then
-    lines[#lines + 1] = body:sub(start)
-  end
-  return lines
+  return str_trim(s)
 end
 
 --[[
@@ -112,7 +74,7 @@ local function parse_field_block(lines, start_line)
 
   local function flush()
     if current_name then
-      local value = trim(table.concat(current_value_parts, ' '))
+      local value = str_trim(table.concat(current_value_parts, ' '))
       fields[current_name] = value
       local list = fields_multi[current_name]
       if not list then
@@ -129,13 +91,12 @@ local function parse_field_block(lines, start_line)
     local line = lines[i]
     if line == '' then
       flush()
-      i = i + 1
-      return fields, fields_multi, i
+      return fields, fields_multi, i + 1
     end
     local first = line:sub(1, 1)
     if first == ' ' or first == '\t' then
       if current_name then
-        current_value_parts[#current_value_parts + 1] = trim(line)
+        current_value_parts[#current_value_parts + 1] = str_trim(line)
       end
       -- else: continuation with no preceding field - ignore
     else
@@ -158,8 +119,14 @@ end
 -- lines. Used for message/delivery-status bodies which consist of 1..N
 -- blocks.
 local function parse_field_blocks(body)
-  local lines = split_lines(normalize_eol(body))
-  -- Skip leading blank lines.
+  if type(body) ~= 'string' then
+    body = tostring(body or '')
+  end
+  -- Normalise line endings (drop CR) then split on LF.
+  local lines = str_split(body:gsub('\r', ''), '\n')
+  if not lines then
+    return {}
+  end
   local i = 1
   while i <= #lines and lines[i] == '' do
     i = i + 1
@@ -167,7 +134,6 @@ local function parse_field_blocks(body)
   local blocks = {}
   while i <= #lines do
     local fields, fields_multi, next_i = parse_field_block(lines, i)
-    -- Only keep non-empty blocks.
     if next(fields) ~= nil then
       blocks[#blocks + 1] = {
         fields = fields,
@@ -178,7 +144,6 @@ local function parse_field_blocks(body)
       break
     end
     i = next_i
-    -- Skip further blank lines between blocks.
     while i <= #lines and lines[i] == '' do
       i = i + 1
     end
@@ -186,35 +151,13 @@ local function parse_field_blocks(body)
   return blocks
 end
 
--- Safely fetch `(type, subtype, params)` for a mime part, falling back to
--- nil on any error. `get_type_full` returns a 3-value tuple in recent
--- Rspamd; older versions may only return type/subtype. We use pcall to be
--- safe regardless.
-local function safe_get_type_full(part)
-  local ok, t, st, params = pcall(part.get_type_full, part)
-  if not ok then
-    return nil, nil, nil
-  end
-  return t, st, params
-end
-
--- Safely fetch plain content as a Lua string.
-local function safe_get_content(part)
-  local ok, content = pcall(part.get_content, part)
-  if not ok or content == nil then
-    return nil
-  end
-  return tostring(content)
-end
-
 -- Find the topmost multipart/report part in a task that matches the given
 -- `report-type` (case-insensitive). Returns the matching mime_part or nil.
 local function find_multipart_report(task, wanted_report_type)
-  local parts = task:get_parts() or {}
-  for _, part in ipairs(parts) do
-    local t, st, params = safe_get_type_full(part)
+  for _, part in ipairs(task:get_parts() or {}) do
+    local t, st, params = part:get_type_full()
     if t == 'multipart' and st == 'report' and type(params) == 'table' then
-      local rt = params['report-type'] or params['Report-Type']
+      local rt = params['report-type']
       if rt and rt:lower() == wanted_report_type then
         return part
       end
@@ -225,11 +168,9 @@ end
 
 -- Find the first sub-part whose Content-Type matches `wanted_type/wanted_subtype`
 -- (case-insensitive). If `wanted_subtype` is nil, only `wanted_type` is
--- matched. Searches all parts on the task (they are already flattened by
--- the mime parser).
+-- matched.
 local function find_part_by_type(task, wanted_type, wanted_subtype)
-  local parts = task:get_parts() or {}
-  for _, part in ipairs(parts) do
+  for _, part in ipairs(task:get_parts() or {}) do
     local t, st = part:get_type()
     if t and t:lower() == wanted_type and
         (not wanted_subtype or (st and st:lower() == wanted_subtype)) then
@@ -239,10 +180,11 @@ local function find_part_by_type(task, wanted_type, wanted_subtype)
   return nil
 end
 
--- Like find_part_by_type but iterates a set of candidate type pairs.
+-- Locate the embedded original message in a report.
+-- Returns (part, kind) where kind is 'full' for message/rfc822|message/global
+-- (headers+body) and 'headers' for text/rfc822-headers (headers only).
 local function find_original_message_part(task)
-  local parts = task:get_parts() or {}
-  for _, part in ipairs(parts) do
+  for _, part in ipairs(task:get_parts() or {}) do
     local t, st = part:get_type()
     if t and st then
       local lt = t:lower()
@@ -258,32 +200,22 @@ local function find_original_message_part(task)
   return nil
 end
 
--- Given a string containing full RFC 822 content (headers + optional body)
--- return a table mapping lowercased header names to values. If the string
--- only contains headers (no trailing blank line), the parser still works.
-local function parse_rfc822_headers(content)
-  if not content or content == '' then
-    return {}
-  end
-  local lines = split_lines(normalize_eol(content))
-  local fields = parse_field_block(lines, 1)
-  return fields or {}
-end
-
--- Extract the standard subset of original-message headers we care about.
-local function extract_original_message(part, kind)
-  local content = safe_get_content(part)
+-- Extract the standard subset of original-message headers we care about from
+-- the content of a message/rfc822 (or text/rfc822-headers) sub-part.
+local function extract_original_message(part)
+  local content = part:get_content()
   if not content then
     return nil
   end
-  local headers
-  if kind == 'headers' then
-    headers = parse_rfc822_headers(content)
-  else
-    -- For message/rfc822 (or message/global) content is the full embedded
-    -- message; the header block is everything up to the first blank line.
-    headers = parse_rfc822_headers(content)
+  content = tostring(content)
+  if content == '' then
+    return nil
   end
+  local lines = str_split(content:gsub('\r', ''), '\n')
+  if not lines then
+    return nil
+  end
+  local headers = parse_field_block(lines, 1)
   if not headers or next(headers) == nil then
     return nil
   end
@@ -294,7 +226,6 @@ local function extract_original_message(part, kind)
     subject = headers['subject'],
     date = headers['date'],
   }
-  -- If every field came back nil, treat as no useful data.
   if not (out.message_id or out.from or out.to or out.subject or out.date) then
     return nil
   end
@@ -317,7 +248,7 @@ end
 -- parsed into at least one non-empty field block, the function still
 -- returns a table (with `recipients = {}`) so that callers can distinguish
 -- "not a DSN" (nil) from "a DSN we couldn't fully parse" (table with
--- mostly-nil fields). Exceptions from the C API are caught and silenced.
+-- mostly-nil fields).
 --
 -- @param {rspamd_task} task message to inspect
 -- @return {table|nil} parsed DSN, see module doc for the shape
@@ -345,9 +276,9 @@ function exports.parse_dsn(task)
   }
 
   if status_part then
-    local body = safe_get_content(status_part)
+    local body = status_part:get_content()
     if body then
-      local blocks = parse_field_blocks(body)
+      local blocks = parse_field_blocks(tostring(body))
       if #blocks > 0 then
         local per_message = blocks[1].fields
         result.reporting_mta = per_message['reporting-mta']
@@ -374,9 +305,9 @@ function exports.parse_dsn(task)
     end
   end
 
-  local orig_part, kind = find_original_message_part(task)
+  local orig_part = find_original_message_part(task)
   if orig_part then
-    result.original_message = extract_original_message(orig_part, kind)
+    result.original_message = extract_original_message(orig_part)
   end
 
   return result
@@ -431,9 +362,9 @@ function exports.parse_arf(task)
     original_message = nil,
   }
 
-  local body = safe_get_content(fb_part)
+  local body = fb_part:get_content()
   if body then
-    local blocks = parse_field_blocks(body)
+    local blocks = parse_field_blocks(tostring(body))
     if #blocks > 0 then
       local f = blocks[1].fields
       local fm = blocks[1].fields_multi
@@ -465,9 +396,9 @@ function exports.parse_arf(task)
     rspamd_logger.debugm(N, task, 'ARF detected but feedback-report part content is empty')
   end
 
-  local orig_part, kind = find_original_message_part(task)
+  local orig_part = find_original_message_part(task)
   if orig_part then
-    local om = extract_original_message(orig_part, kind)
+    local om = extract_original_message(orig_part)
     if om then
       -- RFC 5965 consumers typically only care about Message-ID and From.
       result.original_message = {
@@ -480,9 +411,4 @@ function exports.parse_arf(task)
   return result
 end
 
--- Exposed for tests / introspection.
-exports._parse_field_blocks = parse_field_blocks
-exports._parse_rfc822_headers = parse_rfc822_headers
-exports._strip_angles = strip_angles
-
 return exports