]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Add defensive checks to PDF parser for malformed input
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 12 Jan 2026 12:17:32 +0000 (12:17 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 12 Jan 2026 12:17:32 +0000 (12:17 +0000)
Add pcall wrappers and type checks throughout pdf.lua to handle
malformed PDFs from untrusted sources without crashing:

- Add nil checks for stream objects before accessing fields
- Wrap grammar matches in pcall to catch parsing errors
- Add type validation before ipairs calls on trie match results
- Wrap span extractions in pcall to handle invalid offsets
- Add defensive checks in processor functions (trailer, suspicious)
- Wrap URL creation in pcall for malformed URI strings

Errors are logged via debugm for diagnosis while allowing
processing to continue gracefully.

lualib/lua_content/pdf.lua

index 70aa82021129d94f312ac331f43cb4487cfc2c91..2e4dbf01ca893e47ae9213533763e060af591702 100644 (file)
@@ -689,19 +689,37 @@ local function maybe_extract_object_stream(obj, pdf, task)
     -- TODO add decryption some day
     return nil
   end
-  local dict = obj.dict
+  if not obj.stream then
+    return nil
+  end
+  -- Defensive checks for stream structure
+  if not obj.stream.data or not obj.stream.len then
+    lua_util.debugm(N, task, 'malformed stream in object %s:%s',
+        obj.major, obj.minor)
+    return nil
+  end
+  local dict = obj.dict or {}
   local len = obj.stream.len
-  local decl_len = maybe_dereference_object(dict.Length, pdf, task)
-
-  if decl_len then
-    local nlen = tonumber(decl_len)
-    if nlen then
-      len = math.min(len, nlen)
+  if dict.Length then
+    local decl_len = maybe_dereference_object(dict.Length, pdf, task)
+    if decl_len then
+      local nlen = tonumber(decl_len)
+      if nlen then
+        len = math.min(len, nlen)
+      end
     end
   end
 
   if len > 0 then
-    local real_stream = obj.stream.data:span(1, len)
+    -- Wrap stream extraction in pcall to handle malformed data
+    local ret, real_stream = pcall(function()
+      return obj.stream.data:span(1, len)
+    end)
+    if not ret or not real_stream then
+      lua_util.debugm(N, task, 'cannot extract stream span from object %s:%s: %s',
+          obj.major, obj.minor, real_stream)
+      return nil
+    end
 
     local uncompressed, filter_err = maybe_apply_filter(dict, real_stream, pdf, task)
 
@@ -928,8 +946,8 @@ process_dict = function(task, pdf, obj, dict)
     end
 
     if not obj.type then
-
-      if obj.dict.S and obj.dict.JS then
+      -- Defensive: check obj.dict exists before accessing its fields
+      if obj.dict and type(obj.dict) == 'table' and obj.dict.S and obj.dict.JS then
         obj.type = 'Javascript'
         lua_util.debugm(N, task, 'implicit type for JavaScript object %s:%s',
             obj.major, obj.minor)
@@ -1070,38 +1088,61 @@ local function pdf_compound_object_unpack(_, uncompressed, pdf, task, first)
   -- First, we need to parse data line by line likely to find a line
   -- that consists of pairs of numbers
   compound_obj_grammar_gen()
-  local elts = compound_obj_grammar:match(uncompressed)
-  if elts and #elts > 0 then
+  -- Wrap grammar match in pcall for safety
+  local match_ok, elts = pcall(compound_obj_grammar.match, compound_obj_grammar, uncompressed)
+  if not match_ok then
+    lua_util.debugm(N, task, 'compound object grammar match failed: %s', elts)
+    return
+  end
+  if elts and type(elts) == 'table' and #elts > 0 then
     lua_util.debugm(N, task, 'compound elts (chunk length %s): %s',
         #uncompressed, elts)
 
     for i, pair in ipairs(elts) do
-      local obj_number, offset = pair[1], pair[2]
-
-      offset = offset + first
-      if offset < #uncompressed then
-        local span_len
-        if i == #elts then
-          span_len = #uncompressed - offset
-        else
-          span_len = (elts[i + 1][2] + first) - offset
-        end
+      -- Defensive: check pair is a valid table
+      if type(pair) ~= 'table' or not pair[1] or not pair[2] then
+        lua_util.debugm(N, task, 'invalid pair in compound object at index %s', i)
+      else
+        local obj_number, offset = pair[1], pair[2]
 
-        if span_len > 0 and offset + span_len <= #uncompressed then
-          local obj = {
-            major = obj_number,
-            minor = 0, -- Implicit
-            data = uncompressed:span(offset + 1, span_len),
-            ref = obj_ref(obj_number, 0)
-          }
-          parse_object_grammar(obj, task, pdf)
+        offset = offset + first
+        if offset < #uncompressed then
+          local span_len
+          if i == #elts then
+            span_len = #uncompressed - offset
+          else
+            local next_pair = elts[i + 1]
+            if type(next_pair) == 'table' and next_pair[2] then
+              span_len = (next_pair[2] + first) - offset
+            else
+              span_len = #uncompressed - offset
+            end
+          end
 
-          if obj.dict then
-            pdf.objects[#pdf.objects + 1] = obj
+          if span_len > 0 and offset + span_len <= #uncompressed then
+            -- Wrap span extraction in pcall
+            local span_ok, span_data = pcall(function()
+              return uncompressed:span(offset + 1, span_len)
+            end)
+            if not span_ok or not span_data then
+              lua_util.debugm(N, task, 'cannot extract span for compound object %s', obj_number)
+            else
+              local obj = {
+                major = obj_number,
+                minor = 0, -- Implicit
+                data = span_data,
+                ref = obj_ref(obj_number, 0)
+              }
+              parse_object_grammar(obj, task, pdf)
+
+              if obj.dict then
+                pdf.objects[#pdf.objects + 1] = obj
+              end
+            end
+          else
+            lua_util.debugm(N, task, 'invalid span_len for compound object %s:%s; offset = %s, len = %s',
+                pair[1], pair[2], offset + span_len, #uncompressed)
           end
-        else
-          lua_util.debugm(N, task, 'invalid span_len for compound object %s:%s; offset = %s, len = %s',
-              pair[1], pair[2], offset + span_len, #uncompressed)
         end
       end
     end
@@ -1361,14 +1402,28 @@ local function search_text(task, pdf, mpart)
       local text = {}
       for _, tobj in ipairs(obj.contents) do
         maybe_extract_object_stream(tobj, pdf, task)
-        local matches = pdf_text_trie:match(tobj.uncompressed or '')
-        if matches then
+        -- Defensive: ensure uncompressed data is usable
+        local uncompressed = tobj.uncompressed
+        if not uncompressed then
+          uncompressed = ''
+        end
+        -- Wrap trie match in pcall to handle unexpected input
+        local match_ok, matches = pcall(pdf_text_trie.match, pdf_text_trie, uncompressed)
+        if not match_ok then
+          lua_util.debugm(N, task, 'trie match failed for object %s:%s: %s',
+              tobj.major, tobj.minor, matches)
+          matches = nil
+        end
+        if matches and type(matches) == 'table' then
           local text_blocks = {}
           local starts = {}
           local ends = {}
 
           for npat, matched_positions in pairs(matches) do
-            if npat == 1 then
+            if type(matched_positions) ~= 'table' then
+              -- Skip malformed match results
+              lua_util.debugm(N, task, 'skipping malformed trie match result: %s', type(matched_positions))
+            elseif npat == 1 then
               for _, pos in ipairs(matched_positions) do
                 starts[#starts + 1] = pos
               end
@@ -1384,39 +1439,49 @@ local function search_text(task, pdf, mpart)
 
           offsets_to_blocks(starts, ends, text_blocks)
           for _, bl in ipairs(text_blocks) do
-            if bl.len > 2 then
+            if bl.len and bl.len > 2 then
               -- To remove \s+ET\b pattern (it can leave trailing space or not but it doesn't matter)
               bl.len = bl.len - 2
             end
-
-            bl.data = tobj.uncompressed:span(bl.start, bl.len)
-            if bl.len <= 256 then
-              lua_util.debugm(N, task, 'extracted text from object %s:%s: %s',
-                  tobj.major, tobj.minor, bl.data)
+            -- Defensive: wrap span extraction in pcall
+            local span_ok, span_data = pcall(function()
+              if type(uncompressed) == 'userdata' or type(uncompressed) == 'string' then
+                return uncompressed:span(bl.start, bl.len)
+              end
+              return nil
+            end)
+            if not span_ok or not span_data then
+              lua_util.debugm(N, task, 'cannot extract text span from object %s:%s',
+                  tobj.major, tobj.minor)
             else
-              lua_util.debugm(N, task, 'extracted text from object %s:%s (%d bytes)',
-                  tobj.major, tobj.minor, bl.len)
-            end
+              bl.data = span_data
+              if bl.len <= 256 then
+                lua_util.debugm(N, task, 'extracted text from object %s:%s: %s',
+                    tobj.major, tobj.minor, bl.data)
+              else
+                lua_util.debugm(N, task, 'extracted text from object %s:%s (%d bytes)',
+                    tobj.major, tobj.minor, bl.len)
+              end
 
-            if bl.len < config.max_processing_size then
-              local ret, obj_or_err = pcall(pdf_text_grammar.match, pdf_text_grammar,
-                  bl.data)
+              if bl.len < config.max_processing_size then
+                local ret, obj_or_err = pcall(pdf_text_grammar.match, pdf_text_grammar,
+                    bl.data)
 
-              if ret then
-                if #obj_or_err == 0 then
-                  lua_util.debugm(N, task, 'empty text match from block: %s', bl.data)
-                end
-                for _, chunk in ipairs(obj_or_err) do
-                  text[#text + 1] = chunk
+                if ret and type(obj_or_err) == 'table' then
+                  if #obj_or_err == 0 then
+                    lua_util.debugm(N, task, 'empty text match from block: %s', bl.data)
+                  end
+                  for _, chunk in ipairs(obj_or_err) do
+                    text[#text + 1] = chunk
+                  end
+                  text[#text + 1] = '\n'
+                  lua_util.debugm(N, task, 'attached %s from content object %s:%s to %s:%s',
+                      obj_or_err, tobj.major, tobj.minor, obj.major, obj.minor)
+                else
+                  lua_util.debugm(N, task, 'object %s:%s cannot be parsed: %s',
+                      obj.major, obj.minor, obj_or_err)
                 end
-                text[#text + 1] = '\n'
-                lua_util.debugm(N, task, 'attached %s from content object %s:%s to %s:%s',
-                    obj_or_err, tobj.major, tobj.minor, obj.major, obj.minor)
-              else
-                lua_util.debugm(N, task, 'object %s:%s cannot be parsed: %s',
-                    obj.major, obj.minor, obj_or_err)
               end
-
             end
           end
         end
@@ -1504,6 +1569,10 @@ local function search_urls(task, pdf, mpart)
           obj.major, obj.minor)
       return
     end
+    -- Defensive: ensure dict is actually a table we can iterate
+    if type(dict) ~= 'table' then
+      return
+    end
 
     for k, v in pairs(dict) do
       if type(v) == 'table' then
@@ -1511,13 +1580,25 @@ local function search_urls(task, pdf, mpart)
       elseif k == 'URI' then
         v = maybe_dereference_object(v, pdf, task)
         if type(v) == 'string' then
-          local url = rspamd_url.create(task:get_mempool(), v, { 'content' })
+          -- Wrap URL creation in pcall to handle malformed URLs
+          local url_ok, url = pcall(rspamd_url.create, task:get_mempool(), v, { 'content' })
 
-          if url then
+          if url_ok and url then
             lua_util.debugm(N, task, 'found url %s in object %s:%s',
                 v, obj.major, obj.minor)
             task:inject_url(url, mpart)
           end
+        elseif type(v) == 'userdata' then
+          -- Handle rspamd_text objects
+          local str_ok, str_v = pcall(tostring, v)
+          if str_ok and str_v then
+            local url_ok, url = pcall(rspamd_url.create, task:get_mempool(), str_v, { 'content' })
+            if url_ok and url then
+              lua_util.debugm(N, task, 'found url %s in object %s:%s',
+                  str_v, obj.major, obj.minor)
+              task:inject_url(url, mpart)
+            end
+          end
         end
       end
     end
@@ -1525,7 +1606,12 @@ local function search_urls(task, pdf, mpart)
 
   for _, obj in ipairs(pdf.objects) do
     if obj.dict and type(obj.dict) == 'table' then
-      recursive_object_traverse(obj, obj.dict, 0)
+      -- Wrap the traversal in pcall to handle any unexpected errors
+      local ok, err = pcall(recursive_object_traverse, obj, obj.dict, 0)
+      if not ok then
+        lua_util.debugm(N, task, 'error traversing object %s:%s for URLs: %s',
+            obj.major, obj.minor, err)
+      end
     end
   end
 end
@@ -1551,20 +1637,25 @@ local function process_pdf(input, mpart, task)
     local pdf_output = lua_util.shallowcopy(pdf_object)
     local grouped_processors = {}
     for npat, matched_positions in pairs(matches) do
-      local index = pdf_indexes[npat]
+      if type(matched_positions) ~= 'table' then
+        -- Skip malformed match results
+        lua_util.debugm(N, task, 'skipping malformed trie match result: %s', type(matched_positions))
+      else
+        local index = pdf_indexes[npat]
 
-      local proc_key, loc_npat = index[1], index[4]
+        local proc_key, loc_npat = index[1], index[4]
 
-      if not grouped_processors[proc_key] then
-        grouped_processors[proc_key] = {
-          processor_func = processors[proc_key],
-          offsets = {},
-        }
-      end
-      local proc = grouped_processors[proc_key]
-      -- Fill offsets
-      for _, pos in ipairs(matched_positions) do
-        proc.offsets[#proc.offsets + 1] = { pos, loc_npat }
+        if not grouped_processors[proc_key] then
+          grouped_processors[proc_key] = {
+            processor_func = processors[proc_key],
+            offsets = {},
+          }
+        end
+        local proc = grouped_processors[proc_key]
+        -- Fill offsets
+        for _, pos in ipairs(matched_positions) do
+          proc.offsets[#proc.offsets + 1] = { pos, loc_npat }
+        end
       end
     end
 
@@ -1575,7 +1666,11 @@ local function process_pdf(input, mpart, task)
       table.sort(processor.offsets, function(e1, e2)
         return e1[1] < e2[1]
       end)
-      processor.processor_func(input, task, processor.offsets, pdf_object, pdf_output)
+      -- Wrap processor call in pcall to handle any errors gracefully
+      local proc_ok, proc_err = pcall(processor.processor_func, input, task, processor.offsets, pdf_object, pdf_output)
+      if not proc_ok then
+        lua_util.debugm(N, task, "pdf: processor %s failed: %s", name, proc_err)
+      end
     end
 
     pdf_output.flags = {}
@@ -1586,15 +1681,26 @@ local function process_pdf(input, mpart, task)
         -- Trim
       end
 
-      -- Postprocess objects
-      postprocess_pdf_objects(task, input, pdf_object)
+      -- Postprocess objects - wrap in pcall for safety
+      local pp_ok, pp_err = pcall(postprocess_pdf_objects, task, input, pdf_object)
+      if not pp_ok then
+        lua_util.debugm(N, task, "pdf: postprocess_pdf_objects failed: %s", pp_err)
+      end
       pdf_output.objects = pdf_object.objects
       -- Skip text extraction if timeout occurred - partial results would be incorrect
       if config.text_extraction and not pdf_object.timeout_processing then
-        search_text(task, pdf_object, mpart)
+        -- Wrap in pcall for safety
+        local st_ok, st_err = pcall(search_text, task, pdf_object, mpart)
+        if not st_ok then
+          lua_util.debugm(N, task, "pdf: search_text failed: %s", st_err)
+        end
       end
       if config.url_extraction then
-        search_urls(task, pdf_object, mpart, pdf_output)
+        -- Wrap in pcall for safety
+        local su_ok, su_err = pcall(search_urls, task, pdf_object, mpart, pdf_output)
+        if not su_ok then
+          lua_util.debugm(N, task, "pdf: search_urls failed: %s", su_err)
+        end
       end
 
       if config.js_fuzzy and pdf_object.scripts then
@@ -1648,7 +1754,14 @@ end
 
 -- Processes the PDF trailer
 processors.trailer = function(input, task, positions, pdf_object, pdf_output)
+  -- Defensive checks
+  if not positions or #positions == 0 then
+    return
+  end
   local last_pos = positions[#positions]
+  if not last_pos or type(last_pos) ~= 'table' or not last_pos[1] then
+    return
+  end
 
   lua_util.debugm(N, task, 'pdf: process trailer at position %s (%s total length)',
       last_pos, #input)
@@ -1658,44 +1771,68 @@ processors.trailer = function(input, task, positions, pdf_object, pdf_output)
     return
   end
 
-  local last_span = input:span(last_pos[1])
+  -- Wrap span extraction in pcall
+  local span_ok, last_span = pcall(input.span, input, last_pos[1])
+  if not span_ok or not last_span then
+    lua_util.debugm(N, task, 'pdf: cannot extract trailer span')
+    return
+  end
+
   local lines_checked = 0
-  for line in last_span:lines(true) do
-    if line:find('/Encrypt ') then
-      lua_util.debugm(N, task, "pdf: found encrypted line in trailer: %s",
-          line)
-      pdf_output.encrypted = true
-      pdf_object.encrypted = true
-      break
-    end
-    lines_checked = lines_checked + 1
+  -- Wrap lines iteration in pcall
+  local iter_ok, iter_err = pcall(function()
+    for line in last_span:lines(true) do
+      if line:find('/Encrypt ') then
+        lua_util.debugm(N, task, "pdf: found encrypted line in trailer: %s",
+            line)
+        pdf_output.encrypted = true
+        pdf_object.encrypted = true
+        break
+      end
+      lines_checked = lines_checked + 1
 
-    if lines_checked > config.max_pdf_trailer_lines then
-      lua_util.debugm(N, task, "pdf: trailer has too many lines, stop checking")
-      pdf_output.long_trailer = #input - last_pos[1]
-      break
+      if lines_checked > config.max_pdf_trailer_lines then
+        lua_util.debugm(N, task, "pdf: trailer has too many lines, stop checking")
+        pdf_output.long_trailer = #input - last_pos[1]
+        break
+      end
     end
+  end)
+  if not iter_ok then
+    lua_util.debugm(N, task, 'pdf: error iterating trailer lines: %s', iter_err)
   end
 end
 
 processors.suspicious = function(input, task, positions, pdf_object, pdf_output)
+  -- Defensive check for positions
+  if not positions or type(positions) ~= 'table' then
+    return
+  end
+
   local suspicious_factor = 0.0
   local nexec = 0
   local nencoded = 0
   local close_encoded = 0
   local last_encoded
   for _, match in ipairs(positions) do
+    -- Defensive check for match structure
+    if type(match) ~= 'table' or not match[1] or not match[2] then
+      goto continue_suspicious
+    end
+
     if match[2] == 1 then
       -- netsh
       suspicious_factor = suspicious_factor + 0.5
     elseif match[2] == 2 then
       nexec = nexec + 1
     elseif match[2] == 3 then
-      local enc_data = input:sub(match[1] - 2, match[1] - 1)
+      -- Wrap input:sub in pcall for safety
+      local sub_ok, enc_data = pcall(input.sub, input, match[1] - 2, match[1] - 1)
       local legal_escape = false
 
-      if enc_data then
-        enc_data = enc_data:strtoul()
+      if sub_ok and enc_data then
+        local strtoul_ok, strtoul_result = pcall(enc_data.strtoul, enc_data)
+        enc_data = strtoul_ok and strtoul_result or nil
 
         if enc_data then
           -- Legit encode cases are non printable characters (e.g. spaces)
@@ -1718,6 +1855,7 @@ processors.suspicious = function(input, task, positions, pdf_object, pdf_output)
 
       end
     end
+    ::continue_suspicious::
   end
 
   if nencoded > 10 then
@@ -1743,12 +1881,22 @@ processors.suspicious = function(input, task, positions, pdf_object, pdf_output)
 end
 
 local function generic_table_inserter(positions, pdf_object, output_key)
+  -- Defensive checks
+  if not positions or type(positions) ~= 'table' then
+    return
+  end
+  if not pdf_object or type(pdf_object) ~= 'table' then
+    return
+  end
   if not pdf_object[output_key] then
     pdf_object[output_key] = {}
   end
   local shift = #pdf_object[output_key]
   for i, pos in ipairs(positions) do
-    pdf_object[output_key][i + shift] = pos[1]
+    -- Check pos is a table with valid first element
+    if type(pos) == 'table' and pos[1] then
+      pdf_object[output_key][i + shift] = pos[1]
+    end
   end
 end