]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Rework tokenizers initialisation
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 22 Jun 2025 21:00:16 +0000 (22:00 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 22 Jun 2025 21:00:16 +0000 (22:00 +0100)
lualib/rspamadm/mime.lua
src/libserver/cfg_utils.cxx
src/libstat/tokenizers/tokenizer_manager.c
src/lua/lua_config.c

index e0b23e16cfab6f2f5abbfef4659981a502a84006..a20e47e237e5226833b76a7e605115e31b66e754 100644 (file)
@@ -12,7 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-]]--
+]] --
 
 local argparse = require "argparse"
 local ansicolors = require "ansicolors"
@@ -35,94 +35,94 @@ local parser = argparse()
     :require_command(true)
 
 parser:option "-c --config"
-      :description "Path to config file"
-      :argname("<cfg>")
-      :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+    :description "Path to config file"
+    :argname("<cfg>")
+    :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
 parser:mutex(
-    parser:flag "-j --json"
-          :description "JSON output",
-    parser:flag "-U --ucl"
-          :description "UCL output",
-    parser:flag "-M --messagepack"
-          :description "MessagePack output"
+  parser:flag "-j --json"
+  :description "JSON output",
+  parser:flag "-U --ucl"
+  :description "UCL output",
+  parser:flag "-M --messagepack"
+  :description "MessagePack output"
 )
 parser:flag "-C --compact"
-      :description "Use compact format"
+    :description "Use compact format"
 parser:flag "--no-file"
-      :description "Do not print filename"
+    :description "Do not print filename"
 
 -- Extract subcommand
 local extract = parser:command "extract ex e"
-                      :description "Extracts data from MIME messages"
+    :description "Extracts data from MIME messages"
 extract:argument "file"
-       :description "File to process"
-       :argname "<file>"
-       :args "+"
+    :description "File to process"
+    :argname "<file>"
+    :args "+"
 
 extract:flag "-t --text"
-       :description "Extracts plain text data from a message"
+    :description "Extracts plain text data from a message"
 extract:flag "-H --html"
-       :description "Extracts htm data from a message"
+    :description "Extracts htm data from a message"
 extract:option "-o --output"
-       :description "Output format ('raw', 'content', 'oneline', 'decoded', 'decoded_utf')"
-       :argname("<type>")
-       :convert {
-  raw = "raw",
-  content = "content",
-  oneline = "content_oneline",
-  decoded = "raw_parsed",
-  decoded_utf = "raw_utf"
-}
-       :default "content"
+    :description "Output format ('raw', 'content', 'oneline', 'decoded', 'decoded_utf')"
+    :argname("<type>")
+    :convert {
+      raw = "raw",
+      content = "content",
+      oneline = "content_oneline",
+      decoded = "raw_parsed",
+      decoded_utf = "raw_utf"
+    }
+    :default "content"
 extract:flag "-w --words"
-       :description "Extracts words"
+    :description "Extracts words"
 extract:flag "-p --part"
-       :description "Show part info"
+    :description "Show part info"
 extract:flag "-s --structure"
-       :description "Show structure info (e.g. HTML tags)"
+    :description "Show structure info (e.g. HTML tags)"
 extract:flag "-i --invisible"
-       :description "Show invisible content for HTML parts"
+    :description "Show invisible content for HTML parts"
 extract:option "-F --words-format"
-       :description "Words format ('stem', 'norm', 'raw', 'full')"
-       :argname("<type>")
-       :convert {
-  stem = "stem",
-  norm = "norm",
-  raw = "raw",
-  full = "full",
-}
-       :default "stem"
+    :description "Words format ('stem', 'norm', 'raw', 'full')"
+    :argname("<type>")
+    :convert {
+      stem = "stem",
+      norm = "norm",
+      raw = "raw",
+      full = "full",
+    }
+    :default "stem"
 
 local stat = parser:command "stat st s"
-                   :description "Extracts statistical data from MIME messages"
+    :description "Extracts statistical data from MIME messages"
 stat:argument "file"
     :description "File to process"
     :argname "<file>"
     :args "+"
 stat:mutex(
-    stat:flag "-m --meta"
-        :description "Lua metatokens",
-    stat:flag "-b --bayes"
-        :description "Bayes tokens",
-    stat:flag "-F --fuzzy"
-        :description "Fuzzy hashes"
+  stat:flag "-m --meta"
+  :description "Lua metatokens",
+  stat:flag "-b --bayes"
+  :description "Bayes tokens",
+  stat:flag "-F --fuzzy"
+  :description "Fuzzy hashes"
 )
 stat:flag "-s --shingles"
     :description "Show shingles for fuzzy hashes"
 
 local urls = parser:command "urls url u"
-                   :description "Extracts URLs from MIME messages"
+    :description "Extracts URLs from MIME messages"
 urls:argument "file"
     :description "File to process"
     :argname "<file>"
     :args "+"
 urls:mutex(
-    urls:flag "-t --tld"
-        :description "Get TLDs only",
-    urls:flag "-H --host"
-        :description "Get hosts only",
-    urls:flag "-f --full"
-        :description "Show piecewise urls as processed by Rspamd"
+  urls:flag "-t --tld"
+  :description "Get TLDs only",
+  urls:flag "-H --host"
+  :description "Get hosts only",
+  urls:flag "-f --full"
+  :description "Show piecewise urls as processed by Rspamd"
 )
 
 urls:flag "-u --unique"
@@ -135,75 +135,75 @@ urls:flag "-r --reverse"
     :description "Reverse sort order"
 
 local modify = parser:command "modify mod m"
-                     :description "Modifies MIME message"
+    :description "Modifies MIME message"
 modify:argument "file"
-      :description "File to process"
-      :argname "<file>"
-      :args "+"
+    :description "File to process"
+    :argname "<file>"
+    :args "+"
 
 modify:option "-a --add-header"
-      :description "Adds specific header"
-      :argname "<header=value>"
-      :count "*"
+    :description "Adds specific header"
+    :argname "<header=value>"
+    :count "*"
 modify:option "-r --remove-header"
-      :description "Removes specific header (all occurrences)"
-      :argname "<header>"
-      :count "*"
+    :description "Removes specific header (all occurrences)"
+    :argname "<header>"
+    :count "*"
 modify:option "-R --rewrite-header"
-      :description "Rewrites specific header, uses Lua string.format pattern"
-      :argname "<header=pattern>"
-      :count "*"
+    :description "Rewrites specific header, uses Lua string.format pattern"
+    :argname "<header=pattern>"
+    :count "*"
 modify:option "-t --text-footer"
-      :description "Adds footer to text/plain parts from a specific file"
-      :argname "<file>"
+    :description "Adds footer to text/plain parts from a specific file"
+    :argname "<file>"
 modify:option "-H --html-footer"
-      :description "Adds footer to text/html parts from a specific file"
-      :argname "<file>"
+    :description "Adds footer to text/html parts from a specific file"
+    :argname "<file>"
 
 local strip = parser:command "strip"
-                    :description "Strip attachments from a message"
+    :description "Strip attachments from a message"
 strip:argument "file"
-     :description "File to process"
-     :argname "<file>"
-     :args "+"
+    :description "File to process"
+    :argname "<file>"
+    :args "+"
 strip:flag "-i --keep-images"
-     :description "Keep images"
+    :description "Keep images"
 strip:option "--min-text-size"
-     :description "Minimal text size to keep"
-     :argname "<size>"
-     :convert(tonumber)
-     :default(0)
+    :description "Minimal text size to keep"
+    :argname "<size>"
+    :convert(tonumber)
+    :default(0)
 strip:option "--max-text-size"
-     :description "Max text size to keep"
-     :argname "<size>"
-     :convert(tonumber)
-     :default(math.huge)
+    :description "Max text size to keep"
+    :argname "<size>"
+    :convert(tonumber)
+    :default(math.huge)
 
 local anonymize = parser:command "anonymize"
-                        :description "Try to remove sensitive information from a message"
+    :description "Try to remove sensitive information from a message"
 anonymize:argument "file"
-         :description "File to process"
-         :argname "<file>"
-         :args "+"
+    :description "File to process"
+    :argname "<file>"
+    :args "+"
 anonymize:option "--exclude-header -X"
-         :description "Exclude specific headers from anonymization"
-         :argname "<header>"
-         :count "*"
+    :description "Exclude specific headers from anonymization"
+    :argname "<header>"
+    :count "*"
 anonymize:option "--include-header -I"
-         :description "Include specific headers from anonymization"
-         :argname "<header>"
-         :count "*"
+    :description "Include specific headers from anonymization"
+    :argname "<header>"
+    :count "*"
 anonymize:flag "--gpt"
-         :description "Use LLM model for anonymization (requires GPT plugin to be configured)"
+    :description "Use LLM model for anonymization (requires GPT plugin to be configured)"
 anonymize:option "--model"
-         :description "Model to use for anonymization"
-         :argname "<model>"
+    :description "Model to use for anonymization"
+    :argname "<model>"
 anonymize:option "--prompt"
-         :description "Prompt to use for anonymization"
-         :argname "<prompt>"
+    :description "Prompt to use for anonymization"
+    :argname "<prompt>"
 
 local sign = parser:command "sign"
-                   :description "Performs DKIM signing"
+    :description "Performs DKIM signing"
 sign:argument "file"
     :description "File to process"
     :argname "<file>"
@@ -225,33 +225,33 @@ sign:option "-t --type"
     :description "ARC or DKIM signing"
     :argname("<arc|dkim>")
     :convert {
-  ['arc'] = 'arc',
-  ['dkim'] = 'dkim',
-}
+      ['arc'] = 'arc',
+      ['dkim'] = 'dkim',
+    }
     :default 'dkim'
 sign:option "-o --output"
     :description "Output format"
     :argname("<message|signature>")
     :convert {
-  ['message'] = 'message',
-  ['signature'] = 'signature',
-}
+      ['message'] = 'message',
+      ['signature'] = 'signature',
+    }
     :default 'message'
 
 local dump = parser:command "dump"
-                   :description "Dumps a raw message in different formats"
+    :description "Dumps a raw message in different formats"
 dump:argument "file"
     :description "File to process"
     :argname "<file>"
     :args "+"
 -- Duplicate format for convenience
 dump:mutex(
-    parser:flag "-j --json"
-          :description "JSON output",
-    parser:flag "-U --ucl"
-          :description "UCL output",
-    parser:flag "-M --messagepack"
-          :description "MessagePack output"
+  parser:flag "-j --json"
+  :description "JSON output",
+  parser:flag "-U --ucl"
+  :description "UCL output",
+  parser:flag "-M --messagepack"
+  :description "MessagePack output"
 )
 dump:flag "-s --split"
     :description "Split the output file contents such that no content is embedded"
@@ -260,7 +260,7 @@ dump:option "-o --outdir"
     :description "Output directory"
     :argname("<directory>")
 
-local function load_config(opts)
+local function load_config(opts, load_tokenizers)
   local _r, err = rspamd_config:load_ucl(opts['config'])
 
   if not _r then
@@ -273,6 +273,23 @@ local function load_config(opts)
     rspamd_logger.errx('cannot process %s: %s', opts['config'], err)
     os.exit(1)
   end
+
+  -- Load custom tokenizers if requested
+  if load_tokenizers then
+    local success, tokenizer_err = rspamd_config:load_custom_tokenizers()
+    if not success then
+      rspamd_logger.errx('cannot load custom tokenizers: %s', tokenizer_err or 'unknown error')
+      -- Don't exit here as custom tokenizers are optional
+      rspamd_logger.warnx('proceeding without custom tokenizers')
+    end
+  end
+end
+
+-- Helper function to ensure proper cleanup of tokenizers
+local function cleanup_tokenizers()
+  if rspamd_config then
+    rspamd_config:unload_custom_tokenizers()
+  end
 end
 
 local function load_task(_, fname)
@@ -288,13 +305,13 @@ local function load_task(_, fname)
 
   if not res then
     parser:error(string.format('cannot read message from %s: %s', fname,
-        task))
+      task))
     return nil
   end
 
   if not task:process_message() then
     parser:error(string.format('cannot read message from %s: %s', fname,
-        'failed to parse'))
+      'failed to parse'))
     return nil
   end
 
@@ -335,7 +352,6 @@ local function print_elts(elts, opts, func)
     io.write(ucl.to_format(elts, output_fmt(opts)))
   else
     fun.each(function(fname, elt)
-
       if not opts.json and not opts.ucl then
         if func then
           elt = fun.map(func, elt)
@@ -357,7 +373,7 @@ local function extract_handler(opts)
 
   if opts.words then
     -- Enable stemming and urls detection
-    load_config(opts)
+    load_config(opts, true) -- Load with custom tokenizers
     rspamd_url.init(rspamd_config:get_tld_path())
     rspamd_config:init_subsystem('langdet')
   end
@@ -372,39 +388,38 @@ local function extract_handler(opts)
 
       if not opts.json and not opts.ucl then
         table.insert(out,
-            rspamd_logger.slog('Part: %s: %s, language: %s, size: %s (%s raw), words: %s',
-                part:get_mimepart():get_digest():sub(1, 8),
-                t,
-                part:get_language(),
-                part:get_length(), part:get_raw_length(),
-                part:get_words_count()))
+          rspamd_logger.slog('Part: %s: %s, language: %s, size: %s (%s raw), words: %s',
+            part:get_mimepart():get_digest():sub(1, 8),
+            t,
+            part:get_language(),
+            part:get_length(), part:get_raw_length(),
+            part:get_words_count()))
         table.insert(out,
-            rspamd_logger.slog('Stats: %s',
-                fun.foldl(function(acc, k, v)
-                  if acc ~= '' then
-                    return string.format('%s, %s:%s', acc, k, v)
-                  else
-                    return string.format('%s:%s', k, v)
-                  end
-                end, '', part:get_stats())))
+          rspamd_logger.slog('Stats: %s',
+            fun.foldl(function(acc, k, v)
+              if acc ~= '' then
+                return string.format('%s, %s:%s', acc, k, v)
+              else
+                return string.format('%s:%s', k, v)
+              end
+            end, '', part:get_stats())))
       end
     end
   end
 
   local function maybe_print_mime_part_info(part, out)
     if opts.part then
-
       if not opts.json and not opts.ucl then
         local mtype, msubtype = part:get_type()
         local det_mtype, det_msubtype = part:get_detected_type()
         table.insert(out,
-            rspamd_logger.slog('Mime Part: %s: %s/%s (%s/%s detected), filename: %s (%s detected ext), size: %s',
-                part:get_digest():sub(1, 8),
-                mtype, msubtype,
-                det_mtype, det_msubtype,
-                part:get_filename(),
-                part:get_detected_ext(),
-                part:get_length()))
+          rspamd_logger.slog('Mime Part: %s: %s/%s (%s/%s detected), filename: %s (%s detected ext), size: %s',
+            part:get_digest():sub(1, 8),
+            mtype, msubtype,
+            det_mtype, det_msubtype,
+            part:get_filename(),
+            part:get_detected_ext(),
+            part:get_length()))
       end
     end
   end
@@ -416,17 +431,17 @@ local function extract_handler(opts)
       return table.concat(words, ' ')
     else
       return table.concat(
-          fun.totable(
-              fun.map(function(w)
-                -- [1] - stemmed word
-                -- [2] - normalised word
-                -- [3] - raw word
-                -- [4] - flags (table of strings)
-                return string.format('%s|%s|%s(%s)',
-                    w[3], w[2], w[1], table.concat(w[4], ','))
-              end, words)
-          ),
-          ' '
+        fun.totable(
+          fun.map(function(w)
+            -- [1] - stemmed word
+            -- [2] - normalised word
+            -- [3] - raw word
+            -- [4] - flags (table of strings)
+            return string.format('%s|%s|%s(%s)',
+              w[3], w[2], w[1], table.concat(w[4], ','))
+          end, words)
+        ),
+        ' '
       )
     end
   end
@@ -443,7 +458,7 @@ local function extract_handler(opts)
     if opts.words then
       local how_words = opts['words_format'] or 'stem'
       table.insert(out_elts[fname], 'meta_words: ' ..
-          print_words(task:get_meta_words(how_words), how_words == 'full'))
+        print_words(task:get_meta_words(how_words), how_words == 'full'))
     end
 
     if opts.text or opts.html then
@@ -466,7 +481,7 @@ local function extract_handler(opts)
           if opts.words then
             local how_words = opts['words_format'] or 'stem'
             table.insert(out_elts[fname], print_words(part:get_words(how_words),
-                how_words == 'full'))
+              how_words == 'full'))
           else
             table.insert(out_elts[fname], tostring(part:get_content(how)))
           end
@@ -480,7 +495,7 @@ local function extract_handler(opts)
           if opts.words then
             local how_words = opts['words_format'] or 'stem'
             table.insert(out_elts[fname], print_words(part:get_words(how_words),
-                how_words == 'full'))
+              how_words == 'full'))
           else
             if opts.structure then
               local hc = part:get_html()
@@ -489,11 +504,11 @@ local function extract_handler(opts)
                 local fun = require "fun"
                 if type(elt) == 'table' then
                   return table.concat(fun.totable(
-                      fun.map(
-                          function(t)
-                            return rspamd_logger.slog("%s", t)
-                          end,
-                          elt)), '\n')
+                    fun.map(
+                      function(t)
+                        return rspamd_logger.slog("%s", t)
+                      end,
+                      elt)), '\n')
                 else
                   return rspamd_logger.slog("%s", elt)
                 end
@@ -524,7 +539,7 @@ local function extract_handler(opts)
             if opts.invisible then
               local hc = part:get_html()
               table.insert(out_elts[fname], string.format('invisible content: %s',
-                  tostring(hc:get_invisible())))
+                tostring(hc:get_invisible())))
             end
           end
         end
@@ -544,13 +559,18 @@ local function extract_handler(opts)
   for _, task in ipairs(tasks) do
     task:destroy()
   end
+
+  -- Cleanup custom tokenizers if they were loaded
+  if opts.words then
+    cleanup_tokenizers()
+  end
 end
 
 local function stat_handler(opts)
   local fun = require "fun"
   local out_elts = {}
 
-  load_config(opts)
+  load_config(opts, true)                      -- Load with custom tokenizers for stat generation
   rspamd_url.init(rspamd_config:get_tld_path())
   rspamd_config:init_subsystem('langdet,stat') -- Needed to gen stat tokens
 
@@ -571,10 +591,10 @@ local function stat_handler(opts)
       out_elts[fname] = bt
       process_func = function(e)
         return string.format('%s (%d): "%s"+"%s", [%s]', e.data, e.win, e.t1 or "",
-            e.t2 or "", table.concat(fun.totable(
-                fun.map(function(k)
-                  return k
-                end, e.flags)), ","))
+          e.t2 or "", table.concat(fun.totable(
+            fun.map(function(k)
+              return k
+            end, e.flags)), ","))
       end
     elseif opts.fuzzy then
       local parts = task:get_parts() or {}
@@ -601,16 +621,16 @@ local function stat_handler(opts)
               digest = digest,
               shingles = shingles,
               type = string.format('%s/%s',
-                  ({ part:get_type() })[1],
-                  ({ part:get_type() })[2])
+                ({ part:get_type() })[1],
+                ({ part:get_type() })[2])
             })
           else
             table.insert(out_elts[fname], {
               digest = part:get_digest(),
               file = part:get_filename(),
               type = string.format('%s/%s',
-                  ({ part:get_type() })[1],
-                  ({ part:get_type() })[2])
+                ({ part:get_type() })[1],
+                ({ part:get_type() })[2])
             })
           end
         end
@@ -621,10 +641,13 @@ local function stat_handler(opts)
   end
 
   print_elts(out_elts, opts, process_func)
+
+  -- Cleanup custom tokenizers
+  cleanup_tokenizers()
 end
 
 local function urls_handler(opts)
-  load_config(opts)
+  load_config(opts, false) -- URLs don't need custom tokenizers
   rspamd_url.init(rspamd_config:get_tld_path())
   local out_elts = {}
 
@@ -764,7 +787,7 @@ local function newline(task)
 end
 
 local function modify_handler(opts)
-  load_config(opts)
+  load_config(opts, false) -- Modification doesn't need custom tokenizers
   rspamd_url.init(rspamd_config:get_tld_path())
 
   local function read_file(file)
@@ -804,10 +827,10 @@ local function modify_handler(opts)
         if hname == name then
           local new_value = string.format(hpattern, hdr.decoded)
           new_value = string.format('%s:%s%s',
-              name, hdr.separator,
-              rspamd_util.fold_header(name,
-                  rspamd_util.mime_header_encode(new_value),
-                  task:get_newlines_type()))
+            name, hdr.separator,
+            rspamd_util.fold_header(name,
+              rspamd_util.mime_header_encode(new_value),
+              task:get_newlines_type()))
           out[#out + 1] = new_value
           return
         end
@@ -816,12 +839,12 @@ local function modify_handler(opts)
       if rewrite.need_rewrite_ct then
         if name:lower() == 'content-type' then
           local nct = string.format('%s: %s/%s; charset=utf-8',
-              'Content-Type', rewrite.new_ct.type, rewrite.new_ct.subtype)
+            'Content-Type', rewrite.new_ct.type, rewrite.new_ct.subtype)
           out[#out + 1] = nct
           return
         elseif name:lower() == 'content-transfer-encoding' then
           out[#out + 1] = string.format('%s: %s',
-              'Content-Transfer-Encoding', rewrite.new_cte or 'quoted-printable')
+            'Content-Transfer-Encoding', rewrite.new_cte or 'quoted-printable')
           seen_cte = true
           return
         end
@@ -837,13 +860,13 @@ local function modify_handler(opts)
 
       if hname and hvalue then
         out[#out + 1] = string.format('%s: %s', hname,
-            rspamd_util.fold_header(hname, hvalue, task:get_newlines_type()))
+          rspamd_util.fold_header(hname, hvalue, task:get_newlines_type()))
       end
     end
 
     if not seen_cte and rewrite.need_rewrite_ct then
       out[#out + 1] = string.format('%s: %s',
-          'Content-Transfer-Encoding', rewrite.new_cte or 'quoted-printable')
+        'Content-Transfer-Encoding', rewrite.new_cte or 'quoted-printable')
     end
 
     -- End of headers
@@ -883,7 +906,7 @@ local function modify_handler(opts)
 end
 
 local function sign_handler(opts)
-  load_config(opts)
+  load_config(opts, false) -- Signing doesn't need custom tokenizers
   rspamd_url.init(rspamd_config:get_tld_path())
 
   local lua_dkim = require("lua_ffi").dkim
@@ -927,11 +950,11 @@ local function sign_handler(opts)
       io.flush()
     else
       local dkim_hdr = string.format('%s: %s%s',
-          'DKIM-Signature',
-          rspamd_util.fold_header('DKIM-Signature',
-              rspamd_util.mime_header_encode(sig),
-              task:get_newlines_type()),
-          newline(task))
+        'DKIM-Signature',
+        rspamd_util.fold_header('DKIM-Signature',
+          rspamd_util.mime_header_encode(sig),
+          task:get_newlines_type()),
+        newline(task))
       io.write(dkim_hdr)
       io.flush()
       task:get_content():save_in_file(1)
@@ -942,7 +965,7 @@ local function sign_handler(opts)
 end
 
 local function strip_handler(opts)
-  load_config(opts)
+  load_config(opts, false) -- Stripping doesn't need custom tokenizers
   rspamd_url.init(rspamd_config:get_tld_path())
 
   for _, fname in ipairs(opts.file) do
@@ -998,7 +1021,7 @@ local function strip_handler(opts)
 end
 
 local function anonymize_handler(opts)
-  load_config(opts)
+  load_config(opts, false) -- Anonymization doesn't need custom tokenizers
   rspamd_url.init(rspamd_config:get_tld_path())
 
   for _, fname in ipairs(opts.file) do
@@ -1103,7 +1126,7 @@ local function get_dump_content(task, opts, fname)
 end
 
 local function dump_handler(opts)
-  load_config(opts)
+  load_config(opts, false) -- Dumping doesn't need custom tokenizers
   rspamd_url.init(rspamd_config:get_tld_path())
 
   for _, fname in ipairs(opts.file) do
index 3fd322a1ea4041d6ad7a0950152b2f315e1de709..c7bb202108dd01095d9fb309621128aef65fb357 100644 (file)
@@ -826,6 +826,65 @@ rspamd_adjust_clocks_resolution(struct rspamd_config *cfg)
 #endif
 }
 
+extern "C" {
+
+gboolean
+rspamd_config_load_custom_tokenizers(struct rspamd_config *cfg, GError **err)
+{
+       /* Load custom tokenizers */
+       const ucl_object_t *custom_tokenizers = ucl_object_lookup_path(cfg->cfg_ucl_obj,
+                                                                                                                                  "options.custom_tokenizers");
+       if (custom_tokenizers != NULL) {
+               msg_info_config("loading custom tokenizers");
+
+               if (!cfg->tokenizer_manager) {
+                       cfg->tokenizer_manager = rspamd_tokenizer_manager_new(cfg->cfg_pool);
+               }
+
+               ucl_object_iter_t it = ucl_object_iterate_new(custom_tokenizers);
+               const ucl_object_t *tok_obj;
+               const char *tok_name;
+
+               while ((tok_obj = ucl_object_iterate_safe(it, true)) != NULL) {
+                       tok_name = ucl_object_key(tok_obj);
+                       GError *local_err = NULL;
+
+                       if (!rspamd_tokenizer_manager_load_tokenizer(cfg->tokenizer_manager,
+                                                                                                                tok_name, tok_obj, &local_err)) {
+                               msg_err_config("failed to load custom tokenizer '%s': %s",
+                                                          tok_name, local_err ? local_err->message : "unknown error");
+
+                               if (err && !*err) {
+                                       *err = g_error_copy(local_err);
+                               }
+
+                               if (local_err) {
+                                       g_error_free(local_err);
+                               }
+
+                               ucl_object_iterate_free(it);
+                               return FALSE;
+                       }
+               }
+               ucl_object_iterate_free(it);
+
+               msg_info_config("loaded custom tokenizers successfully");
+       }
+
+       return TRUE;
+}
+
+void rspamd_config_unload_custom_tokenizers(struct rspamd_config *cfg)
+{
+       if (cfg->tokenizer_manager) {
+               msg_info_config("unloading custom tokenizers");
+               rspamd_tokenizer_manager_destroy(cfg->tokenizer_manager);
+               cfg->tokenizer_manager = NULL;
+       }
+}
+
+}// extern "C"
+
 /*
  * Perform post load actions
  */
@@ -946,35 +1005,18 @@ rspamd_config_post_load(struct rspamd_config *cfg,
                        return FALSE;
                }
 
-               /* Load custom tokenizers */
-               const ucl_object_t *custom_tokenizers = ucl_object_lookup_path(cfg->cfg_ucl_obj,
-                                                                                                                                          "options.custom_tokenizers");
-               if (custom_tokenizers != NULL) {
-                       msg_info_config("loading custom tokenizers");
-                       cfg->tokenizer_manager = rspamd_tokenizer_manager_new(cfg->cfg_pool);
-
-                       ucl_object_iter_t it = ucl_object_iterate_new(custom_tokenizers);
-                       const ucl_object_t *tok_obj;
-                       const char *tok_name;
-
-                       while ((tok_obj = ucl_object_iterate_safe(it, true)) != NULL) {
-                               tok_name = ucl_object_key(tok_obj);
-                               GError *err = NULL;
-
-                               if (!rspamd_tokenizer_manager_load_tokenizer(cfg->tokenizer_manager,
-                                                                                                                        tok_name, tok_obj, &err)) {
-                                       msg_err_config("failed to load custom tokenizer '%s': %s",
-                                                                  tok_name, err ? err->message : "unknown error");
-                                       if (err) {
-                                               g_error_free(err);
-                                       }
+               /* Load custom tokenizers using the new function */
+               GError *tokenizer_err = NULL;
+               if (!rspamd_config_load_custom_tokenizers(cfg, &tokenizer_err)) {
+                       msg_err_config("failed to load custom tokenizers: %s",
+                                                  tokenizer_err ? tokenizer_err->message : "unknown error");
+                       if (tokenizer_err) {
+                               g_error_free(tokenizer_err);
+                       }
 
-                                       if (opts & RSPAMD_CONFIG_INIT_VALIDATE) {
-                                               ret = tl::make_unexpected(fmt::format("failed to load custom tokenizer '{}'", tok_name));
-                                       }
-                               }
+                       if (opts & RSPAMD_CONFIG_INIT_VALIDATE) {
+                               ret = tl::make_unexpected(std::string{"failed to load custom tokenizers"});
                        }
-                       ucl_object_iterate_free(it);
                }
        }
 
index b9bfe0e6f9220e36e44d63f5b2497bcafa159937..e6fb5e8d8cd5b9285c1a05381175ff8b278f51c5 100644 (file)
@@ -100,9 +100,12 @@ rspamd_tokenizer_manager_new(rspamd_mempool_t *pool)
                                                                  (rspamd_mempool_destruct_t) g_hash_table_unref,
                                                                  mgr->tokenizers);
        rspamd_mempool_add_destructor(pool,
-                                                                 (rspamd_mempool_destruct_t) g_array_free,
+                                                                 (rspamd_mempool_destruct_t) rspamd_array_free_hard,
                                                                  mgr->detection_order);
 
+       msg_info_tokenizer("created custom tokenizer manager with default confidence threshold %.3f",
+                                          mgr->default_threshold);
+
        return mgr;
 }
 
@@ -131,6 +134,8 @@ rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr,
        g_assert(name != NULL);
        g_assert(config != NULL);
 
+       msg_info_tokenizer("starting to load custom tokenizer '%s'", name);
+
        /* Check if enabled */
        elt = ucl_object_lookup(config, "enabled");
        if (elt && ucl_object_type(elt) == UCL_BOOLEAN) {
@@ -138,7 +143,7 @@ rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr,
        }
 
        if (!enabled) {
-               msg_info_tokenizer("custom tokenizer %s is disabled", name);
+               msg_info_tokenizer("custom tokenizer '%s' is disabled", name);
                return TRUE;
        }
 
@@ -150,14 +155,17 @@ rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr,
                return FALSE;
        }
        path = ucl_object_tostring(elt);
+       msg_info_tokenizer("custom tokenizer '%s' will be loaded from path: %s", name, path);
 
        /* Get priority */
        elt = ucl_object_lookup(config, "priority");
        if (elt) {
                priority = ucl_object_todouble(elt);
        }
+       msg_info_tokenizer("custom tokenizer '%s' priority set to %.1f", name, priority);
 
        /* Load the shared library */
+       msg_info_tokenizer("loading shared library for custom tokenizer '%s'", name);
        handle = dlopen(path, RTLD_NOW | RTLD_LOCAL);
        if (!handle) {
                g_set_error(err, g_quark_from_static_string("tokenizer"),
@@ -165,8 +173,10 @@ rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr,
                                        name, path, dlerror());
                return FALSE;
        }
+       msg_info_tokenizer("successfully loaded shared library for custom tokenizer '%s'", name);
 
        /* Get the API entry point */
+       msg_info_tokenizer("looking up API entry point for custom tokenizer '%s'", name);
        get_api = (rspamd_tokenizer_get_api_func) dlsym(handle, "rspamd_tokenizer_get_api");
        if (!get_api) {
                dlclose(handle);
@@ -177,6 +187,7 @@ rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr,
        }
 
        /* Get the API */
+       msg_info_tokenizer("calling API entry point for custom tokenizer '%s'", name);
        api = get_api();
        if (!api) {
                dlclose(handle);
@@ -184,8 +195,11 @@ rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr,
                                        EINVAL, "tokenizer %s returned NULL API", name);
                return FALSE;
        }
+       msg_info_tokenizer("successfully obtained API from custom tokenizer '%s'", name);
 
        /* Check API version */
+       msg_info_tokenizer("checking API version for custom tokenizer '%s' (got %u, expected %u)",
+                                          name, api->api_version, RSPAMD_CUSTOM_TOKENIZER_API_VERSION);
        if (api->api_version != RSPAMD_CUSTOM_TOKENIZER_API_VERSION) {
                dlclose(handle);
                g_set_error(err, g_quark_from_static_string("tokenizer"),
@@ -212,13 +226,18 @@ rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr,
        /* Get minimum confidence */
        if (api->get_min_confidence) {
                tok->min_confidence = api->get_min_confidence();
+               msg_info_tokenizer("custom tokenizer '%s' provides minimum confidence threshold: %.3f",
+                                                  name, tok->min_confidence);
        }
        else {
                tok->min_confidence = mgr->default_threshold;
+               msg_info_tokenizer("custom tokenizer '%s' using default confidence threshold: %.3f",
+                                                  name, tok->min_confidence);
        }
 
        /* Initialize the tokenizer */
        if (api->init) {
+               msg_info_tokenizer("initializing custom tokenizer '%s'", name);
                error_buf[0] = '\0';
                if (api->init(tok->config, error_buf, sizeof(error_buf)) != 0) {
                        g_set_error(err, g_quark_from_static_string("tokenizer"),
@@ -227,6 +246,10 @@ rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr,
                        rspamd_custom_tokenizer_dtor(tok);
                        return FALSE;
                }
+               msg_info_tokenizer("successfully initialized custom tokenizer '%s'", name);
+       }
+       else {
+               msg_info_tokenizer("custom tokenizer '%s' does not require initialization", name);
        }
 
        /* Add to manager */
@@ -235,8 +258,10 @@ rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr,
 
        /* Re-sort by priority */
        g_array_sort(mgr->detection_order, rspamd_custom_tokenizer_priority_cmp);
+       msg_info_tokenizer("custom tokenizer '%s' registered and sorted by priority (total tokenizers: %u)",
+                                          name, mgr->detection_order->len);
 
-       msg_info_tokenizer("loaded custom tokenizer %s (priority %.0f) from %s",
+       msg_info_tokenizer("successfully loaded custom tokenizer '%s' (priority %.1f) from %s",
                                           name, priority, path);
 
        return TRUE;
@@ -256,6 +281,8 @@ rspamd_tokenizer_manager_detect(struct rspamd_tokenizer_manager *mgr,
        g_assert(mgr != NULL);
        g_assert(text != NULL);
 
+       msg_debug_tokenizer("starting tokenizer detection for text of length %zu", len);
+
        if (confidence) {
                *confidence = 0.0;
        }
@@ -266,6 +293,7 @@ rspamd_tokenizer_manager_detect(struct rspamd_tokenizer_manager *mgr,
 
        /* If we have a language hint, try to find a tokenizer for that language first */
        if (lang_hint) {
+               msg_info_tokenizer("trying to find tokenizer for language hint: %s", lang_hint);
                for (i = 0; i < mgr->detection_order->len; i++) {
                        tok = g_array_index(mgr->detection_order, struct rspamd_custom_tokenizer *, i);
 
@@ -276,11 +304,16 @@ rspamd_tokenizer_manager_detect(struct rspamd_tokenizer_manager *mgr,
                        /* Check if this tokenizer handles the hinted language */
                        const char *tok_lang = tok->api->get_language_hint();
                        if (tok_lang && g_ascii_strcasecmp(tok_lang, lang_hint) == 0) {
+                               msg_info_tokenizer("found tokenizer '%s' for language hint '%s'", tok->name, lang_hint);
                                /* Found a tokenizer for this language, check if it actually detects it */
                                if (tok->api->detect_language) {
                                        conf = tok->api->detect_language(text, len);
+                                       msg_info_tokenizer("tokenizer '%s' confidence for hinted language: %.3f (threshold: %.3f)",
+                                                                          tok->name, conf, tok->min_confidence);
                                        if (conf >= tok->min_confidence) {
                                                /* Use this tokenizer */
+                                               msg_info_tokenizer("using tokenizer '%s' for language hint '%s' with confidence %.3f",
+                                                                                  tok->name, lang_hint, conf);
                                                if (confidence) {
                                                        *confidence = conf;
                                                }
@@ -292,35 +325,52 @@ rspamd_tokenizer_manager_detect(struct rspamd_tokenizer_manager *mgr,
                                }
                        }
                }
+               msg_info_tokenizer("no suitable tokenizer found for language hint '%s', falling back to general detection", lang_hint);
        }
 
        /* Try each tokenizer in priority order */
+       msg_info_tokenizer("trying %u tokenizers for general detection", mgr->detection_order->len);
        for (i = 0; i < mgr->detection_order->len; i++) {
                tok = g_array_index(mgr->detection_order, struct rspamd_custom_tokenizer *, i);
 
                if (!tok->enabled || !tok->api->detect_language) {
+                       msg_debug_tokenizer("skipping tokenizer '%s' (enabled: %s, has detect_language: %s)",
+                                                               tok->name, tok->enabled ? "yes" : "no",
+                                                               tok->api->detect_language ? "yes" : "no");
                        continue;
                }
 
                conf = tok->api->detect_language(text, len);
+               msg_info_tokenizer("tokenizer '%s' detection confidence: %.3f (threshold: %.3f, current best: %.3f)",
+                                                  tok->name, conf, tok->min_confidence, best_conf);
 
                if (conf > best_conf && conf >= tok->min_confidence) {
                        best_conf = conf;
                        best_tok = tok;
+                       msg_info_tokenizer("tokenizer '%s' is new best with confidence %.3f", tok->name, best_conf);
 
                        /* Early exit if very confident */
                        if (conf >= 0.95) {
+                               msg_info_tokenizer("very high confidence (%.3f >= 0.95), using tokenizer '%s' immediately",
+                                                                  conf, tok->name);
                                break;
                        }
                }
        }
 
-       if (confidence && best_tok) {
-               *confidence = best_conf;
-       }
+       if (best_tok) {
+               msg_info_tokenizer("selected tokenizer '%s' with confidence %.3f", best_tok->name, best_conf);
+               if (confidence) {
+                       *confidence = best_conf;
+               }
 
-       if (detected_lang_hint && best_tok && best_tok->api->get_language_hint) {
-               *detected_lang_hint = best_tok->api->get_language_hint();
+               if (detected_lang_hint && best_tok->api->get_language_hint) {
+                       *detected_lang_hint = best_tok->api->get_language_hint();
+                       msg_info_tokenizer("detected language hint: %s", *detected_lang_hint);
+               }
+       }
+       else {
+               msg_info_tokenizer("no suitable tokenizer found during detection");
        }
 
        return best_tok;
index f52eae44febdc334f8295ce7a28685350a22b4a0..7b3a156cd062c173589e3ed3969988158ac57dd6 100644 (file)
 #include "utlist.h"
 #include <math.h>
 
+/* Forward declarations for custom tokenizer functions */
+gboolean rspamd_config_load_custom_tokenizers(struct rspamd_config *cfg, GError **err);
+void rspamd_config_unload_custom_tokenizers(struct rspamd_config *cfg);
+
 /***
  * This module is used to configure rspamd and is normally available as global
  * variable named `rspamd_config`. Unlike other modules, it is not necessary to
@@ -862,6 +866,19 @@ LUA_FUNCTION_DEF(config, get_dns_max_requests);
  */
 LUA_FUNCTION_DEF(config, get_dns_timeout);
 
+/***
+ * @method rspamd_config:load_custom_tokenizers()
+ * Loads custom tokenizers from configuration
+ * @return {boolean} true if successful
+ */
+LUA_FUNCTION_DEF(config, load_custom_tokenizers);
+
+/***
+ * @method rspamd_config:unload_custom_tokenizers()
+ * Unloads custom tokenizers and frees memory
+ */
+LUA_FUNCTION_DEF(config, unload_custom_tokenizers);
+
 static const struct luaL_reg configlib_m[] = {
        LUA_INTERFACE_DEF(config, get_module_opt),
        LUA_INTERFACE_DEF(config, get_mempool),
@@ -937,6 +954,8 @@ static const struct luaL_reg configlib_m[] = {
        LUA_INTERFACE_DEF(config, get_tld_path),
        LUA_INTERFACE_DEF(config, get_dns_max_requests),
        LUA_INTERFACE_DEF(config, get_dns_timeout),
+       LUA_INTERFACE_DEF(config, load_custom_tokenizers),
+       LUA_INTERFACE_DEF(config, unload_custom_tokenizers),
        {"__tostring", rspamd_lua_class_tostring},
        {"__newindex", lua_config_newindex},
        {NULL, NULL}};
@@ -4485,11 +4504,14 @@ lua_config_init_subsystem(lua_State *L)
                nparts = g_strv_length(parts);
 
                for (i = 0; i < nparts; i++) {
-                       if (strcmp(parts[i], "filters") == 0) {
+                       const char *str = parts[i];
+
+                       /* TODO: total shit, rework some day */
+                       if (strcmp(str, "filters") == 0) {
                                rspamd_lua_post_load_config(cfg);
                                rspamd_init_filters(cfg, false, false);
                        }
-                       else if (strcmp(parts[i], "langdet") == 0) {
+                       else if (strcmp(str, "langdet") == 0) {
                                if (!cfg->lang_det) {
                                        cfg->lang_det = rspamd_language_detector_init(cfg);
                                        rspamd_mempool_add_destructor(cfg->cfg_pool,
@@ -4497,10 +4519,10 @@ lua_config_init_subsystem(lua_State *L)
                                                                                                  cfg->lang_det);
                                }
                        }
-                       else if (strcmp(parts[i], "stat") == 0) {
+                       else if (strcmp(str, "stat") == 0) {
                                rspamd_stat_init(cfg, NULL);
                        }
-                       else if (strcmp(parts[i], "dns") == 0) {
+                       else if (strcmp(str, "dns") == 0) {
                                struct ev_loop *ev_base = lua_check_ev_base(L, 3);
 
                                if (ev_base) {
@@ -4514,11 +4536,25 @@ lua_config_init_subsystem(lua_State *L)
                                        return luaL_error(L, "no event base specified");
                                }
                        }
-                       else if (strcmp(parts[i], "symcache") == 0) {
+                       else if (strcmp(str, "symcache") == 0) {
                                rspamd_symcache_init(cfg->cache);
                        }
+                       else if (strcmp(str, "tokenizers") == 0 || strcmp(str, "custom_tokenizers") == 0) {
+                               GError *err = NULL;
+                               if (!rspamd_config_load_custom_tokenizers(cfg, &err)) {
+                                       g_strfreev(parts);
+                                       if (err) {
+                                               int ret = luaL_error(L, "failed to load custom tokenizers: %s", err->message);
+                                               g_error_free(err);
+                                               return ret;
+                                       }
+                                       else {
+                                               return luaL_error(L, "failed to load custom tokenizers");
+                                       }
+                               }
+                       }
                        else {
-                               int ret = luaL_error(L, "invalid param: %s", parts[i]);
+                               int ret = luaL_error(L, "invalid param: %s", str);
                                g_strfreev(parts);
 
                                return ret;
@@ -4772,3 +4808,43 @@ void lua_call_finish_script(struct rspamd_config_cfg_lua_script *sc,
 
        lua_thread_call(thread, 1);
 }
+
+static int
+lua_config_load_custom_tokenizers(lua_State *L)
+{
+       LUA_TRACE_POINT;
+       struct rspamd_config *cfg = lua_check_config(L, 1);
+
+       if (cfg != NULL) {
+               GError *err = NULL;
+               gboolean ret = rspamd_config_load_custom_tokenizers(cfg, &err);
+
+               if (!ret && err) {
+                       lua_pushboolean(L, FALSE);
+                       lua_pushstring(L, err->message);
+                       g_error_free(err);
+                       return 2;
+               }
+
+               lua_pushboolean(L, ret);
+               return 1;
+       }
+       else {
+               return luaL_error(L, "invalid arguments");
+       }
+}
+
+static int
+lua_config_unload_custom_tokenizers(lua_State *L)
+{
+       LUA_TRACE_POINT;
+       struct rspamd_config *cfg = lua_check_config(L, 1);
+
+       if (cfg != NULL) {
+               rspamd_config_unload_custom_tokenizers(cfg);
+               return 0;
+       }
+       else {
+               return luaL_error(L, "invalid arguments");
+       }
+}