]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Add tests for LLM provider, fix various issues with metatokens
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 28 Aug 2025 15:32:16 +0000 (16:32 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 28 Aug 2025 15:32:16 +0000 (16:32 +0100)
lualib/lua_cache.lua
lualib/lua_meta.lua
lualib/plugins/neural.lua
lualib/plugins/neural/providers/llm.lua
rules/controller/neural.lua
src/client/rspamc.cxx
src/plugins/lua/neural.lua
test/functional/cases/335_neural_llm/003_llm_train.robot [new file with mode: 0644]
test/functional/configs/neural_llm.conf [new file with mode: 0644]
test/functional/util/dummy_llm.py [new file with mode: 0644]

index c87a9dc78d89cf8266aa5282fa2332f394ad4c80..5fb1fbbe794eca3ab7e85133bb218bf20aa6b5ca 100644 (file)
@@ -12,7 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-]]--
+]] --
 
 --[[[
 -- @module lua_cache
@@ -82,10 +82,10 @@ local exports = {}
 -- Default options
 local default_opts = {
   cache_prefix = "rspamd_cache",
-  cache_ttl = 3600, -- 1 hour
-  cache_probes = 5, -- Number of times to check a pending key
-  cache_format = "json", -- Serialization format
-  cache_hash_len = 16, -- Number of hex symbols to use for hashed keys
+  cache_ttl = 3600,         -- 1 hour
+  cache_probes = 5,         -- Number of times to check a pending key
+  cache_format = "json",    -- Serialization format
+  cache_hash_len = 16,      -- Number of hex symbols to use for hashed keys
   cache_use_hashing = false -- Whether to hash keys by default
 }
 
@@ -110,8 +110,9 @@ local function get_cache_key(raw_key, cache_context, force_hashing)
   end
 
   if should_hash then
-    lua_util.debugm(N, rspamd_config, "hashing key '%s' with hash length %s",
-        raw_key, cache_context.opts.cache_hash_len)
+    local raw_len = (type(raw_key) == 'string') and #raw_key or -1
+    lua_util.debugm(N, rspamd_config, "hashing cache key (len=%s) with hash length %s",
+      raw_len, cache_context.opts.cache_hash_len)
     return hash_key(raw_key, cache_context.opts.cache_hash_len)
   else
     return raw_key
@@ -133,8 +134,8 @@ local function create_cache_context(redis_params, opts, module_name)
 
   -- Register Redis prefix
   lua_redis.register_prefix(cache_context.opts.cache_prefix,
-      "caching",
-      "Cache API prefix")
+    "caching",
+    "Cache API prefix")
 
   lua_util.debugm(N, rspamd_config, "registered redis prefix: %s", cache_context.opts.cache_prefix)
 
@@ -233,7 +234,7 @@ local function create_pending_marker(timeout, cache_context)
   }
 
   lua_util.debugm(cache_context.N, rspamd_config, "creating PENDING marker for host %s, timeout %s",
-      hostname, timeout)
+    hostname, timeout)
 
   return "PENDING:" .. encode_data(pending_data, cache_context)
 end
@@ -245,8 +246,8 @@ local function cache_get(task, key, cache_context, timeout, callback_uncached, c
     return false
   end
 
-  local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, false)
-  lua_util.debugm(cache_context.N, task, "cache lookup for key: %s (%s)", key, full_key)
+  local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, nil)
+  lua_util.debugm(cache_context.N, task, "cache lookup for key: %s", full_key)
 
   -- Function to check a pending key
   local function check_pending(pending_info)
@@ -254,13 +255,13 @@ local function cache_get(task, key, cache_context, timeout, callback_uncached, c
     local probe_interval = timeout / (cache_context.opts.cache_probes or 5)
 
     lua_util.debugm(cache_context.N, task, "setting up probes for pending key %s, interval: %s seconds",
-        full_key, probe_interval)
+      full_key, probe_interval)
 
     -- Set up a timer to probe the key
     local function probe_key()
       probe_count = probe_count + 1
       lua_util.debugm(cache_context.N, task, "probe #%s/%s for pending key %s",
-          probe_count, cache_context.opts.cache_probes, full_key)
+        probe_count, cache_context.opts.cache_probes, full_key)
 
       if probe_count >= cache_context.opts.cache_probes then
         logger.infox(task, "maximum probes reached for key %s, considering it failed", full_key)
@@ -271,102 +272,102 @@ local function cache_get(task, key, cache_context, timeout, callback_uncached, c
 
       lua_util.debugm(cache_context.N, task, "probing redis for key %s", full_key)
       lua_redis.redis_make_request(task, cache_context.redis_params, key, false,
-          function(err, data)
-            if err then
-              logger.errx(task, "redis error while probing key %s: %s", full_key, err)
-              lua_util.debugm(cache_context.N, task, "redis error during probe: %s, retrying later", err)
-              task:add_timer(probe_interval, probe_key)
-              return
-            end
+        function(err, data)
+          if err then
+            logger.errx(task, "redis error while probing key %s: %s", full_key, err)
+            lua_util.debugm(cache_context.N, task, "redis error during probe: %s, retrying later", err)
+            task:add_timer(probe_interval, probe_key)
+            return
+          end
 
-            if not data or type(data) == 'userdata' then
-              lua_util.debugm(cache_context.N, task, "pending key %s disappeared, calling uncached handler", full_key)
-              callback_uncached(task)
-              return
-            end
+          if not data or type(data) == 'userdata' then
+            lua_util.debugm(cache_context.N, task, "pending key %s disappeared, calling uncached handler", full_key)
+            callback_uncached(task)
+            return
+          end
 
-            local pending = parse_pending_value(data, cache_context)
-            if pending then
-              lua_util.debugm(cache_context.N, task, "key %s still pending (host: %s), retrying later",
-                  full_key, pending.hostname)
-              task:add_timer(probe_interval, probe_key)
-            else
-              lua_util.debugm(cache_context.N, task, "pending key %s resolved to actual data", full_key)
-              callback_data(task, nil, decode_data(data, cache_context))
-            end
-          end,
-          'GET', { full_key }
+          local pending = parse_pending_value(data, cache_context)
+          if pending then
+            lua_util.debugm(cache_context.N, task, "key %s still pending (host: %s), retrying later",
+              full_key, pending.hostname)
+            task:add_timer(probe_interval, probe_key)
+          else
+            lua_util.debugm(cache_context.N, task, "pending key %s resolved to actual data", full_key)
+            callback_data(task, nil, decode_data(data, cache_context))
+          end
+        end,
+        'GET', { full_key }
       )
     end
 
     -- Start the first probe after the initial probe interval
     lua_util.debugm(cache_context.N, task, "scheduling first probe for %s in %s seconds",
-        full_key, probe_interval)
+      full_key, probe_interval)
     task:add_timer(probe_interval, probe_key)
   end
 
   -- Initial cache lookup
   lua_util.debugm(cache_context.N, task, "making initial redis GET request for key: %s", full_key)
   lua_redis.redis_make_request(task, cache_context.redis_params, key, false,
-      function(err, data)
-        if err then
-          logger.errx(task, "redis error looking up key %s: %s", full_key, err)
-          lua_util.debugm(cache_context.N, task, "redis error: %s, calling uncached handler", err)
-          callback_uncached(task)
-          return
-        end
-
-        if not data or type(data) == 'userdata' then
-          -- Key not found, set pending and call the uncached callback
-          lua_util.debugm(cache_context.N, task, "key %s not found in cache, creating pending marker", full_key)
-          local pending_marker = create_pending_marker(timeout, cache_context)
+    function(err, data)
+      if err then
+        logger.errx(task, "redis error looking up key %s: %s", full_key, err)
+        lua_util.debugm(cache_context.N, task, "redis error: %s, calling uncached handler", err)
+        callback_uncached(task)
+        return
+      end
 
-          lua_util.debugm(cache_context.N, task, "setting pending marker for key %s with TTL %s",
-              full_key, timeout * 2)
+      if not data or type(data) == 'userdata' then
+        -- Key not found, set pending and call the uncached callback
+        lua_util.debugm(cache_context.N, task, "key %s not found in cache, creating pending marker", full_key)
+        local pending_marker = create_pending_marker(timeout, cache_context)
+
+        lua_util.debugm(cache_context.N, task, "setting pending marker for key %s with TTL %s",
+          full_key, timeout * 2)
+        lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
+          function(set_err, set_data)
+            if set_err then
+              logger.errx(task, "redis error setting pending marker for %s: %s", full_key, set_err)
+              lua_util.debugm(cache_context.N, task, "failed to set pending marker: %s", set_err)
+            else
+              lua_util.debugm(cache_context.N, task, "successfully set pending marker for %s", full_key)
+            end
+            lua_util.debugm(cache_context.N, task, "calling uncached handler for %s", full_key)
+            callback_uncached(task)
+          end,
+          'SETEX', { full_key, tostring(timeout * 2), pending_marker }
+        )
+      else
+        -- Key found, check if it's a pending marker or actual data
+        local pending = parse_pending_value(data, cache_context)
+
+        if pending then
+          -- Key is being processed by another worker
+          lua_util.debugm(cache_context.N, task, "key %s is pending on host %s, waiting for result",
+            full_key, pending.hostname)
+          check_pending(pending)
+        else
+          -- Extend TTL and return data
+          lua_util.debugm(cache_context.N, task, "found cached data for key %s, extending TTL to %s",
+            full_key, cache_context.opts.cache_ttl)
           lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
-              function(set_err, set_data)
-                if set_err then
-                  logger.errx(task, "redis error setting pending marker for %s: %s", full_key, set_err)
-                  lua_util.debugm(cache_context.N, task, "failed to set pending marker: %s", set_err)
-                else
-                  lua_util.debugm(cache_context.N, task, "successfully set pending marker for %s", full_key)
-                end
-                lua_util.debugm(cache_context.N, task, "calling uncached handler for %s", full_key)
-                callback_uncached(task)
-              end,
-              'SETEX', { full_key, tostring(timeout * 2), pending_marker }
+            function(expire_err, _)
+              if expire_err then
+                logger.errx(task, "redis error extending TTL for %s: %s", full_key, expire_err)
+                lua_util.debugm(cache_context.N, task, "failed to extend TTL: %s", expire_err)
+              else
+                lua_util.debugm(cache_context.N, task, "successfully extended TTL for %s", full_key)
+              end
+            end,
+            'EXPIRE', { full_key, tostring(cache_context.opts.cache_ttl) }
           )
-        else
-          -- Key found, check if it's a pending marker or actual data
-          local pending = parse_pending_value(data, cache_context)
 
-          if pending then
-            -- Key is being processed by another worker
-            lua_util.debugm(cache_context.N, task, "key %s is pending on host %s, waiting for result",
-                full_key, pending.hostname)
-            check_pending(pending)
-          else
-            -- Extend TTL and return data
-            lua_util.debugm(cache_context.N, task, "found cached data for key %s, extending TTL to %s",
-                full_key, cache_context.opts.cache_ttl)
-            lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
-                function(expire_err, _)
-                  if expire_err then
-                    logger.errx(task, "redis error extending TTL for %s: %s", full_key, expire_err)
-                    lua_util.debugm(cache_context.N, task, "failed to extend TTL: %s", expire_err)
-                  else
-                    lua_util.debugm(cache_context.N, task, "successfully extended TTL for %s", full_key)
-                  end
-                end,
-                'EXPIRE', { full_key, tostring(cache_context.opts.cache_ttl) }
-            )
-
-            lua_util.debugm(cache_context.N, task, "returning cached data for key %s", full_key)
-            callback_data(task, nil, decode_data(data, cache_context))
-          end
+          lua_util.debugm(cache_context.N, task, "returning cached data for key %s", full_key)
+          callback_data(task, nil, decode_data(data, cache_context))
         end
-      end,
-      'GET', { full_key }
+      end
+    end,
+    'GET', { full_key }
   )
 
   return true
@@ -379,24 +380,24 @@ local function cache_set(task, key, data, cache_context)
     return false
   end
 
-  local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, false)
-  lua_util.debugm(cache_context.N, task, "caching data for key: %s (%s) with TTL: %s",
-      full_key, key, cache_context.opts.cache_ttl)
+  local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, nil)
+  lua_util.debugm(cache_context.N, task, "caching data for key: %s with TTL: %s",
+    full_key, cache_context.opts.cache_ttl)
 
   local encoded_data = encode_data(data, cache_context)
 
   -- Store the data with expiration
   lua_util.debugm(cache_context.N, task, "making redis SETEX request for key: %s", full_key)
   return lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
-      function(err, result)
-        if err then
-          logger.errx(task, "redis error setting cached data for %s: %s", full_key, err)
-          lua_util.debugm(cache_context.N, task, "failed to cache data: %s", err)
-        else
-          lua_util.debugm(cache_context.N, task, "successfully cached data for key %s", full_key)
-        end
-      end,
-      'SETEX', { full_key, tostring(cache_context.opts.cache_ttl), encoded_data }
+    function(err, result)
+      if err then
+        logger.errx(task, "redis error setting cached data for %s: %s", full_key, err)
+        lua_util.debugm(cache_context.N, task, "failed to cache data: %s", err)
+      else
+        lua_util.debugm(cache_context.N, task, "successfully cached data for key %s", full_key)
+      end
+    end,
+    'SETEX', { full_key, tostring(cache_context.opts.cache_ttl), encoded_data }
   )
 end
 
@@ -407,21 +408,21 @@ local function cache_del(task, key, cache_context)
     return false
   end
 
-  local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, false)
+  local full_key = cache_context.opts.cache_prefix .. "_" .. get_cache_key(key, cache_context, nil)
   lua_util.debugm(cache_context.N, task, "deleting cache key: %s", full_key)
 
   return lua_redis.redis_make_request(task, cache_context.redis_params, key, true,
-      function(err, result)
-        if err then
-          logger.errx(task, "redis error deleting cache key %s: %s", full_key, err)
-          lua_util.debugm(cache_context.N, task, "failed to delete cache key: %s", err)
-        else
-          local count = tonumber(result) or 0
-          lua_util.debugm(cache_context.N, task, "successfully deleted cache key %s (%s keys removed)",
-              full_key, count)
-        end
-      end,
-      'DEL', { full_key }
+    function(err, result)
+      if err then
+        logger.errx(task, "redis error deleting cache key %s: %s", full_key, err)
+        lua_util.debugm(cache_context.N, task, "failed to delete cache key: %s", err)
+      else
+        local count = tonumber(result) or 0
+        lua_util.debugm(cache_context.N, task, "successfully deleted cache key %s (%s keys removed)",
+          full_key, count)
+      end
+    end,
+    'DEL', { full_key }
   )
 end
 
index 340d89ee81d4f10185eb428c2d8569dcd1bd956d..de006df8e774c9bee069771091d23c57fd914784 100644 (file)
@@ -12,7 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-]]--
+]] --
 
 local exports = {}
 
@@ -87,7 +87,7 @@ local function meta_images_function(task)
     nlarge = 1.0 * nlarge / ntotal
     nsmall = 1.0 * nsmall / ntotal
   end
-  return { ntotal, njpg, npng, nlarge, nsmall }
+  return { ntotal, npng, njpg, nlarge, nsmall } -- Fixed order to match names
 end
 
 local function meta_nparts_function(task)
@@ -164,29 +164,28 @@ local function meta_received_function(task)
   local fun = require "fun"
 
   if rh and #rh > 0 then
-
     local ntotal = 0.0
     local init_time = 0
 
     fun.each(function(rc)
-      ntotal = ntotal + 1.0
+        ntotal = ntotal + 1.0
 
-      if not rc.by_hostname then
-        invalid_factor = invalid_factor + 1.0
-      end
-      if init_time == 0 and rc.timestamp then
-        init_time = rc.timestamp
-      elseif rc.timestamp then
-        time_factor = time_factor + math.abs(init_time - rc.timestamp)
-        init_time = rc.timestamp
-      end
-      if rc.flags and (rc.flags['ssl'] or rc.flags['authenticated']) then
-        secure_factor = secure_factor + 1.0
-      end
-    end,
-        fun.filter(function(rc)
-          return not rc.flags or not rc.flags['artificial']
-        end, rh))
+        if not rc.by_hostname then
+          invalid_factor = invalid_factor + 1.0
+        end
+        if init_time == 0 and rc.timestamp then
+          init_time = rc.timestamp
+        elseif rc.timestamp then
+          time_factor = time_factor + math.abs(init_time - rc.timestamp)
+          init_time = rc.timestamp
+        end
+        if rc.flags and (rc.flags['ssl'] or rc.flags['authenticated']) then
+          secure_factor = secure_factor + 1.0
+        end
+      end,
+      fun.filter(function(rc)
+        return not rc.flags or not rc.flags['artificial']
+      end, rh))
 
     if ntotal > 0 then
       invalid_factor = invalid_factor / ntotal
@@ -263,8 +262,8 @@ local function meta_words_function(task)
   end
 
   local ret = {
-    short_words,
-    ret_len,
+    ret_len,     -- avg_words_len (moved to match the names array)
+    short_words, -- nshort_words
   }
 
   local divisor = 1.0
@@ -460,10 +459,10 @@ local function rspamd_gen_metatokens(task, names)
         local ct = mt.cb(task)
         for i, tok in ipairs(ct) do
           lua_util.debugm(N, task, "metatoken: %s = %s",
-              mt.names[i], tok)
+            mt.names[i], tok)
           if tok ~= tok or tok == math.huge then
             logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity',
-                mt.names[i], tok)
+              mt.names[i], tok)
             tok = 0.0
           end
           table.insert(metatokens, tok)
@@ -472,14 +471,13 @@ local function rspamd_gen_metatokens(task, names)
 
       task:cache_set('metatokens', metatokens)
     end
-
   else
     for _, n in ipairs(names) do
       if metatokens_by_name[n] then
         local tok = metatokens_by_name[n](task)
         if tok ~= tok or tok == math.huge then
           logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity',
-              n, tok)
+            n, tok)
           tok = 0.0
         end
         table.insert(metatokens, tok)
@@ -503,7 +501,7 @@ local function rspamd_gen_metatokens_table(task)
     for i, tok in ipairs(ct) do
       if tok ~= tok or tok == math.huge then
         logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity',
-            mt.names[i], tok)
+          mt.names[i], tok)
         tok = 0.0
       end
 
index 17661c1e948a9ae3e3e4492432246f2eff9c1cc8..5fcb75fcf91c2ae51b8f5c829a533dadb87d9c67 100644 (file)
@@ -139,6 +139,28 @@ register_provider('symbols', {
   end,
 })
 
+-- Metatokens-only provider for contexts where symbols are not available
+register_provider('metatokens', {
+  collect = function(task, ctx)
+    local mt = meta_functions.rspamd_gen_metatokens(task)
+    -- Convert to table of numbers
+    local vec = {}
+    for i = 1, #mt do
+      vec[i] = tonumber(mt[i]) or 0.0
+    end
+    return vec, { name = 'metatokens', type = 'metatokens', dim = #vec, weight = ctx.weight or 1.0 }
+  end,
+  collect_async = function(task, ctx, cont)
+    local mt = meta_functions.rspamd_gen_metatokens(task)
+    -- Convert to table of numbers
+    local vec = {}
+    for i = 1, #mt do
+      vec[i] = tonumber(mt[i]) or 0.0
+    end
+    cont(vec, { name = 'metatokens', type = 'metatokens', dim = #vec, weight = ctx.weight or 1.0 })
+  end,
+})
+
 local function load_scripts()
   redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len,
     redis_params)
@@ -546,6 +568,7 @@ end
 
 local function redis_ann_prefix(rule, settings_name)
   -- We also need to count metatokens:
+  -- Note: meta_functions.version represents the metatoken format version
   local n = meta_functions.version
   return string.format('%s%d_%s_%d_%s',
     settings.prefix, plugin_ver, rule.prefix, n, settings_name)
@@ -669,6 +692,7 @@ local function collect_features_async(task, rule, profile_or_set, phase, cb)
     end
     prov.collect_async(task, {
       profile = profile_or_set,
+      set = profile_or_set,
       rule = rule,
       config = pcfg,
       weight = pcfg.weight or 1.0,
@@ -682,25 +706,49 @@ local function collect_features_async(task, rule, profile_or_set, phase, cb)
     end)
   end
 
-  -- Include metatokens as an extra provider when configured
-  local include_meta = rule.fusion and rule.fusion.include_meta
+  -- Include symbols provider (which includes both symbols AND metatokens) as an extra provider
+  -- The name 'include_meta' is historical but it actually includes the full symbols provider
+  -- For backward compatibility, include symbols by default unless explicitly disabled
+  local include_meta = false
+  if not providers_cfg or #providers_cfg == 0 then
+    -- No providers, always use symbols (which includes metatokens)
+    include_meta = true
+  elseif rule.fusion then
+    -- Explicit fusion config takes precedence
+    include_meta = rule.fusion.include_meta
+    if include_meta == nil then
+      -- Default to true for backward compatibility when fusion is configured but include_meta not specified
+      include_meta = true
+    end
+  else
+    -- Providers configured but no fusion settings - default to including symbols+metatokens
+    include_meta = true
+  end
+
   local meta_weight = (rule.fusion and rule.fusion.meta_weight) or 1.0
 
   remaining = #providers_cfg + (include_meta and 1 or 0)
+
+  -- Start all configured providers
   for i, pcfg in ipairs(providers_cfg) do
     start_provider(i, pcfg)
   end
 
   if include_meta then
-    local prov = get_provider('symbols')
+    -- Always use metatokens provider for consistency
+    -- This ensures same dimensions whether called from controller or full scan
+    local prov = get_provider('metatokens')
+
     if prov and prov.collect_async then
-      prov.collect_async(task, { profile = profile_or_set, weight = meta_weight, phase = phase }, function(vec, meta)
-        if vec then
-          metas[#metas + 1] = { name = 'symbols', type = 'symbols', dim = #vec, weight = meta_weight }
-          vectors[#vectors + 1] = vec
-        end
-        maybe_finish()
-      end)
+      local meta_index = #providers_cfg + 1 -- Metatokens always come after providers
+      prov.collect_async(task, { profile = profile_or_set, set = profile_or_set, weight = meta_weight, phase = phase },
+        function(vec, meta)
+          if vec then
+            metas[meta_index] = meta
+            vectors[meta_index] = vec
+          end
+          maybe_finish()
+        end)
     else
       maybe_finish()
     end
@@ -711,8 +759,24 @@ end
 local function spawn_train(params)
   -- Check training data sanity
   -- Now we need to join inputs and create the appropriate test vectors
-  local n = #params.set.symbols +
-      meta_functions.rspamd_count_metatokens()
+  local n
+
+  -- When using providers, derive dimension from actual vectors
+  if params.rule.providers and #params.rule.providers > 0 and
+      (#params.spam_vec > 0 or #params.ham_vec > 0) then
+    -- Use dimension from stored vectors
+    if #params.spam_vec > 0 then
+      n = #params.spam_vec[1]
+    else
+      n = #params.ham_vec[1]
+    end
+    lua_util.debugm(N, rspamd_config, 'spawn_train: using vector dimension %s from stored vectors', n)
+  else
+    -- Traditional symbol-based dimension
+    n = #params.set.symbols + meta_functions.rspamd_count_metatokens()
+    lua_util.debugm(N, rspamd_config, 'spawn_train: using symbol dimension %s symbols + %s metatokens = %s',
+      #params.set.symbols, meta_functions.rspamd_count_metatokens(), n)
+  end
 
   -- Now we can train ann
   local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule)
@@ -1148,7 +1212,7 @@ result_to_vector = function(task, profile)
   if not profile.zeros then
     -- Fill zeros vector
     local zeros = {}
-    for i = 1, meta_functions.count_metatokens() do
+    for i = 1, meta_functions.rspamd_count_metatokens() do
       zeros[i] = 0.0
     end
     for _, _ in ipairs(profile.symbols) do
index 8f08fbb57b9b58d416558dd25a7fe7d1071685db..1bc1063aae8c4868f22ab89d05edbc1755a9f911 100644 (file)
@@ -74,6 +74,16 @@ neural_common.register_provider('llm', {
       return
     end
 
+    -- Do not run embeddings on infer if ANN is not loaded for this set/profile
+    if ctx.phase == 'infer' then
+      local set_or_profile = ctx.profile or ctx.set
+      if not set_or_profile or not set_or_profile.ann then
+        rspamd_logger.debugm(N, task, 'skip llm on infer: ANN not loaded for current settings')
+        cont(nil)
+        return
+      end
+    end
+
     local input_tbl = select_text(task)
     if not input_tbl then
       rspamd_logger.debugm(N, task, 'llm provider has no content to embed; skip')
index 0aace1cc1d2af2df0bb158b539edff6437e1d17c..13530beffe55da0a170e980a41b6f55a670a61cd 100644 (file)
@@ -195,22 +195,37 @@ local function handle_learn_message(task, conn)
     return
   end
 
-  -- If no providers or symbols provider configured, require full scan path
+  -- Check if this configuration requires full scan
+  -- Only symbols collection requires full scan; metatokens can be computed directly
   local has_providers = type(rule.providers) == 'table' and #rule.providers > 0
-  if not has_providers then
-    lua_util.debugm(N, task, 'controller.neural: learn_message refused: no providers (assume symbols) for rule=%s',
+
+  if not has_providers and not rule.disable_symbols_input then
+    -- No providers means full symbols will be used (not just metatokens)
+    lua_util.debugm(N, task,
+      'controller.neural: learn_message refused: no providers configured, symbols collection requires full scan for rule=%s',
       rule_name)
-    conn:send_error(400, 'rule requires full /checkv2 scan (no providers configured)')
+    conn:send_error(400, 'rule requires full /checkv2 scan (no providers configured, full symbols collection required)')
     return
   end
-  for _, p in ipairs(rule.providers) do
-    if p.type == 'symbols' then
-      lua_util.debugm(N, task, 'controller.neural: learn_message refused due to symbols provider for rule=%s', rule_name)
-      conn:send_error(400, 'rule requires full /checkv2 scan (symbols provider present)')
-      return
+
+  -- Check if any provider requires full scan (only symbols provider does)
+  if has_providers then
+    for _, p in ipairs(rule.providers) do
+      if p.type == 'symbols' then
+        lua_util.debugm(N, task,
+          'controller.neural: learn_message refused due to symbols provider requiring full scan for rule=%s',
+          rule_name)
+        conn:send_error(400, 'rule requires full /checkv2 scan (symbols provider present)')
+        return
+      end
     end
   end
 
+  -- At this point:
+  -- - We have providers that don't require full scan (e.g., LLM)
+  -- - Metatokens can be computed directly from the message
+  -- - Controller training is allowed
+
   local set = neural_common.get_rule_settings(task, rule)
   if not set then
     lua_util.debugm(N, task, 'controller.neural: no settings resolved for rule=%s; falling back to first available set',
@@ -224,6 +239,11 @@ local function handle_learn_message(task, conn)
     end
   end
 
+  if set then
+    lua_util.debugm(N, task, 'controller.neural: set found for rule=%s, symbols=%s, name=%s',
+      rule_name, set.symbols and #set.symbols or "nil", set.name)
+  end
+
   -- Derive redis base key even if ANN not yet initialized
   local redis_base
   if set and set.ann and set.ann.redis_key then
@@ -244,17 +264,55 @@ local function handle_learn_message(task, conn)
     return
   end
 
+  -- Ensure profile exists for this set
+  if not set.ann then
+    local version = 0
+    local ann_key = neural_common.new_ann_key(rule, set, version)
+
+    local profile = {
+      symbols = set.symbols,
+      redis_key = ann_key,
+      version = version,
+      digest = set.digest,
+      distance = 0,
+      providers_digest = neural_common.providers_config_digest(rule.providers),
+    }
+
+    local ucl = require "ucl"
+    local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+
+    lua_util.debugm(N, task, 'controller.neural: creating new profile for %s:%s at %s',
+      rule.prefix, set.name, ann_key)
+
+    -- Store the profile in Redis sorted set
+    lua_redis.redis_make_request(task,
+      rule.redis,
+      nil,
+      true, -- is write
+      function(err, _)
+        if err then
+          rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
+            rule.prefix, set.name, profile.redis_key, err)
+        else
+          lua_util.debugm(N, task, 'created new ANN profile for %s:%s, data stored at prefix %s',
+            rule.prefix, set.name, profile.redis_key)
+        end
+      end,
+      'ZADD', -- command
+      { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+    )
+
+    -- Update redis_base to use the new ann_key
+    redis_base = ann_key
+  end
+
   local function after_collect(vec)
     lua_util.debugm(N, task, 'controller.neural: learn_message after_collect, vector=%s', type(vec))
     if not vec then
-      if rule.providers and #rule.providers > 0 then
-        lua_util.debugm(N, task,
-          'controller.neural: no vector from providers; skip training to keep dimensions consistent')
-        conn:send_error(400, 'no vector collected from providers')
-        return
-      else
-        vec = neural_common.result_to_vector(task, set)
-      end
+      lua_util.debugm(N, task,
+        'controller.neural: no vector collected; skip training')
+      conn:send_error(400, 'no vector collected')
+      return
     end
 
     if type(vec) ~= 'table' then
index e2128f357aecb416f5bd4cf2ce73f21cdfca112f..c42d301429f59ee0b9b3683f1d8315dc795a975c 100644 (file)
@@ -212,6 +212,61 @@ static void rspamc_counters_output(FILE *out, ucl_object_t *obj);
 
 static void rspamc_stat_output(FILE *out, ucl_object_t *obj);
 
+static void
+rspamc_neural_learn_output(FILE *out, ucl_object_t *obj)
+{
+       bool is_success = true;
+       const char *filename = nullptr;
+       double scan_time = -1.0;
+       const char *redis_key = nullptr;
+       std::uintmax_t stored_bytes = 0;
+       bool have_stored = false;
+
+       if (obj != nullptr) {
+               const auto *ok = ucl_object_lookup(obj, "success");
+               if (ok) {
+                       is_success = ucl_object_toboolean(ok);
+               }
+               const auto *fn = ucl_object_lookup(obj, "filename");
+               if (fn) {
+                       filename = ucl_object_tostring(fn);
+               }
+               const auto *st = ucl_object_lookup(obj, "scan_time");
+               if (st) {
+                       scan_time = ucl_object_todouble(st);
+               }
+               const auto *rb = ucl_object_lookup(obj, "stored");
+               if (rb) {
+                       stored_bytes = (std::uintmax_t) ucl_object_toint(rb);
+                       have_stored = true;
+               }
+               const auto *rk = ucl_object_lookup(obj, "key");
+               if (rk) {
+                       redis_key = ucl_object_tostring(rk);
+               }
+       }
+
+       // First line: success
+       fprintf(out, "success = %s;\n", is_success ? "true" : "false");
+
+       // Then other fields in k = v; format
+       if (filename) {
+               fprintf(out, "filename = \"%s\";\n", filename);
+       }
+       if (scan_time >= 0) {
+               fprintf(out, "scan_time = %.6f;\n", scan_time);
+       }
+       if (!neural_train.empty()) {
+               fprintf(out, "class = \"%s\";\n", neural_train.c_str());
+       }
+       if (have_stored) {
+               fprintf(out, "stored = %ju bytes;\n", stored_bytes);
+       }
+       if (redis_key) {
+               fprintf(out, "key = \"%s\";\n", redis_key);
+       }
+}
+
 enum rspamc_command_type {
        RSPAMC_COMMAND_UNKNOWN = 0,
        RSPAMC_COMMAND_CHECK,
@@ -288,7 +343,7 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
                .is_controller = FALSE,
                .is_privileged = FALSE,
                .need_input = TRUE,
-               .command_output_func = rspamc_symbols_output},
+               .command_output_func = rspamc_neural_learn_output},
        rspamc_command{
                .cmd = RSPAMC_COMMAND_FUZZY_ADD,
                .name = "fuzzy_add",
index 1e8a135f18a544fa6da2f3e13cd056e44466d2b2..3f0a3c7aa7737c7b575fd1d0a2bc901df8601b01 100644 (file)
@@ -69,20 +69,20 @@ local function new_ann_profile(task, rule, set, version)
   local function add_cb(err, _)
     if err then
       rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
-          rule.prefix, set.name, profile.redis_key, err)
+        rule.prefix, set.name, profile.redis_key, err)
     else
       rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
-          rule.prefix, set.name, profile.redis_key)
+        rule.prefix, set.name, profile.redis_key)
     end
   end
 
   lua_redis.redis_make_request(task,
-      rule.redis,
-      nil,
-      true, -- is write
-      add_cb, --callback
-      'ZADD', -- command
-      { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+    rule.redis,
+    nil,
+    true,   -- is write
+    add_cb, --callback
+    'ZADD', -- command
+    { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
   )
 
   return profile
@@ -103,18 +103,18 @@ local function ann_scores_filter(task)
         profile = set.ann
       else
         lua_util.debugm(N, task, 'no ann loaded for %s:%s',
-            rule.prefix, set.name)
+          rule.prefix, set.name)
       end
     else
       lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
-          rule.prefix, sid)
+        rule.prefix, sid)
     end
 
     if ann then
       local function after_features(vec, meta)
         if profile.providers_digest and meta and meta.digest and profile.providers_digest ~= meta.digest then
           lua_util.debugm(N, task, 'providers digest mismatch for %s:%s, skip ANN apply',
-              rule.prefix, set.name)
+            rule.prefix, set.name)
           vec = nil
         end
 
@@ -131,7 +131,7 @@ local function ann_scores_filter(task)
         local symscore = string.format('%.3f', score)
         task:cache_set(rule.prefix .. '_neural_score', score)
         lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
-            rule.prefix, set.name, set.ann.version, symscore)
+          rule.prefix, set.name, set.ann.version, symscore)
 
         if score > 0 then
           local result = score
@@ -152,8 +152,8 @@ local function ann_scores_filter(task)
             end
           else
             lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
-                rule.prefix, set.name, set.ann.version, symscore,
-                spam_threshold)
+              rule.prefix, set.name, set.ann.version, symscore,
+              spam_threshold)
           end
         else
           local result = -(score)
@@ -174,8 +174,8 @@ local function ann_scores_filter(task)
             end
           else
             lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
-                rule.prefix, set.name, set.ann.version, result,
-                ham_threshold)
+              rule.prefix, set.name, set.ann.version, result,
+              ham_threshold)
           end
         end
       end
@@ -214,20 +214,33 @@ local function ann_push_task_result(rule, task, verdict, score, set)
     end
   end
 
+  -- If LLM provider is configured, suppress autotrain unless manual training requested
+  if not manual_train and rule.providers and #rule.providers > 0 then
+    for _, p in ipairs(rule.providers) do
+      if p.type == 'llm' then
+        lua_util.debugm(N, task, 'suppress autotrain: llm provider present and no manual header')
+        learn_spam = false
+        learn_ham = false
+        skip_reason = 'llm provider requires manual training'
+        break
+      end
+    end
+  end
+
   if not manual_train and (not train_opts.store_pool_only and train_opts.autotrain) then
     if train_opts.spam_score then
       learn_spam = score >= train_opts.spam_score
 
       if not learn_spam then
         skip_reason = string.format('score < spam_score: %f < %f',
-            score, train_opts.spam_score)
+          score, train_opts.spam_score)
       end
     else
       learn_spam = verdict == 'spam' or verdict == 'junk'
 
       if not learn_spam then
         skip_reason = string.format('verdict: %s',
-            verdict)
+          verdict)
       end
     end
 
@@ -235,14 +248,14 @@ local function ann_push_task_result(rule, task, verdict, score, set)
       learn_ham = score <= train_opts.ham_score
       if not learn_ham then
         skip_reason = string.format('score > ham_score: %f > %f',
-            score, train_opts.ham_score)
+          score, train_opts.ham_score)
       end
     else
       learn_ham = verdict == 'ham'
 
       if not learn_ham then
         skip_reason = string.format('verdict: %s',
-            verdict)
+          verdict)
       end
     end
   else
@@ -281,56 +294,63 @@ local function ann_push_task_result(rule, task, verdict, score, set)
         local nspam, nham = data[1], data[2]
 
         if manual_train or neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
-          local vec
-          if rule.providers and #rule.providers > 0 then
-            -- Note: this training path remains sync for now; vectors are pushed when computed
-            -- fall back to legacy vector; async training push will be added later
-            vec = neural_common.result_to_vector(task, set)
-          else
-            vec = neural_common.result_to_vector(task, set)
-          end
+          local function store_train_vec(vec)
+            if not vec then
+              lua_util.debugm(N, task, "no vector collected for training")
+              return
+            end
 
-          local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
-          local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
+            local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
+            local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
 
-          local function learn_vec_cb(redis_err)
-            if redis_err then
-              rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
+            local function learn_vec_cb(redis_err)
+              if redis_err then
+                rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
                   rule.prefix, set.name, redis_err)
-            else
-              lua_util.debugm(N, task,
+              else
+                lua_util.debugm(N, task,
                   "add train data for ANN rule " ..
-                      "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
+                  "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
                   rule.prefix, set.name, learn_type, #vec, target_key, #str)
+              end
             end
-          end
 
-          lua_redis.redis_make_request(task,
+            lua_redis.redis_make_request(task,
               rule.redis,
               nil,
-              true, -- is write
-              learn_vec_cb, --callback
-              'SADD', -- command
+              true,               -- is write
+              learn_vec_cb,       --callback
+              'SADD',             -- command
               { target_key, str } -- arguments
-          )
+            )
+          end
+
+          if rule.providers and #rule.providers > 0 then
+            -- Use async feature collection with providers, same as inference
+            neural_common.collect_features_async(task, rule, set, 'train', store_train_vec)
+          else
+            -- Traditional symbol-based vector
+            local vec = neural_common.result_to_vector(task, set)
+            store_train_vec(vec)
+          end
         else
           lua_util.debugm(N, task,
-              "do not add %s train data for ANN rule " ..
-                  "%s:%s",
-              learn_type, rule.prefix, set.name)
+            "do not add %s train data for ANN rule " ..
+            "%s:%s",
+            learn_type, rule.prefix, set.name)
         end
       else
         if err then
           rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
-              rule.prefix, set.name, err)
+            rule.prefix, set.name, err)
         elseif type(data) == 'string' then
           -- nil return value
           rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
-              learn_type, rule.prefix, set.name, set.ann.redis_key, data)
+            learn_type, rule.prefix, set.name, set.ann.redis_key, data)
         else
           rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
-              'please remove this key from Redis manually if you perform upgrade from the previous version',
-              rule.prefix, set.name, set.ann.redis_key, type(data))
+            'please remove this key from Redis manually if you perform upgrade from the previous version',
+            rule.prefix, set.name, set.ann.redis_key, type(data))
         end
       end
     end
@@ -341,25 +361,25 @@ local function ann_push_task_result(rule, task, verdict, score, set)
         -- Need to create or load a profile corresponding to the current configuration
         set.ann = new_ann_profile(task, rule, set, 0)
         lua_util.debugm(N, task,
-            'requested new profile for %s, set.ann is missing',
-            set.name)
+          'requested new profile for %s, set.ann is missing',
+          set.name)
       end
 
       lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
-          { task = task, is_write = false },
-          vectors_len_cb,
-          {
-            set.ann.redis_key,
-          })
+        { task = task, is_write = false },
+        vectors_len_cb,
+        {
+          set.ann.redis_key,
+        })
     else
       lua_util.debugm(N, task,
-          'do not push data: train condition not satisfied; reason: not checked existing ANNs')
+        'do not push data: train condition not satisfied; reason: not checked existing ANNs')
     end
   else
     lua_util.debugm(N, task,
-        'do not push data to key %s: train condition not satisfied; reason: %s',
-        (set.ann or {}).redis_key,
-        skip_reason)
+      'do not push data to key %s: train condition not satisfied; reason: %s',
+      (set.ann or {}).redis_key,
+      skip_reason)
   end
 end
 
@@ -380,20 +400,21 @@ end
 local function do_train_ann(worker, ev_base, rule, set, ann_key)
   local spam_elts = {}
   local ham_elts = {}
+  lua_util.debugm(N, rspamd_config, 'do_train_ann: start for %s:%s key=%s', rule.prefix, set.name, ann_key)
 
   local function redis_ham_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
-          ann_key, err)
+        ann_key, err)
       -- Unlock on error
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          true, -- is write
-          neural_common.gen_unlock_cb(rule, set, ann_key), --callback
-          'HDEL', -- command
-          { ann_key, 'lock' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        true,                                            -- is write
+        neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+        'HDEL',                                          -- command
+        { ann_key, 'lock' }
       )
     else
       -- Decompress and convert to numbers each training vector
@@ -414,29 +435,29 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   local function redis_spam_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
-          ann_key, err)
+        ann_key, err)
       -- Unlock ANN on error
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          true, -- is write
-          neural_common.gen_unlock_cb(rule, set, ann_key), --callback
-          'HDEL', -- command
-          { ann_key, 'lock' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        true,                                            -- is write
+        neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+        'HDEL',                                          -- command
+        { ann_key, 'lock' }
       )
     else
       -- Decompress and convert to numbers each training vector
       spam_elts = process_training_vectors(data)
       -- Now get ham vectors...
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          false, -- is write
-          redis_ham_cb, --callback
-          'SMEMBERS', -- command
-          { ann_key .. '_ham_set' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        false,        -- is write
+        redis_ham_cb, --callback
+        'SMEMBERS',   -- command
+        { ann_key .. '_ham_set' }
       )
     end
   end
@@ -444,33 +465,33 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   local function redis_lock_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
-          ann_key, err)
+        ann_key, err)
     elseif type(data) == 'number' and data == 1 then
       -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          false, -- is write
-          redis_spam_cb, --callback
-          'SMEMBERS', -- command
-          { ann_key .. '_spam_set' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        false,         -- is write
+        redis_spam_cb, --callback
+        'SMEMBERS',    -- command
+        { ann_key .. '_spam_set' }
       )
 
       rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
-          rule.prefix, set.name, ann_key)
+        rule.prefix, set.name, ann_key)
     else
       local lock_tm = tonumber(data[1])
       rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
-          'locked by another host %s at %s', rule.prefix, set.name, ann_key,
-          data[2], os.date('%c', lock_tm))
+        'locked by another host %s at %s', rule.prefix, set.name, ann_key,
+        data[2], os.date('%c', lock_tm))
     end
   end
 
   -- Check if we are already learning this network
   if set.learning_spawned then
     rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
-        ann_key)
+      ann_key)
     return
   end
 
@@ -478,14 +499,14 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
   -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
   -- ANN is locked by another host (or a process, meh)
   lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
-      { ev_base = ev_base, is_write = true },
-      redis_lock_cb,
-      {
-        ann_key,
-        tostring(os.time()),
-        tostring(math.max(10.0, rule.watch_interval * 2)),
-        rspamd_util.get_hostname()
-      })
+    { ev_base = ev_base, is_write = true },
+    redis_lock_cb,
+    {
+      ann_key,
+      tostring(os.time()),
+      tostring(math.max(10.0, rule.watch_interval * 2)),
+      rspamd_util.get_hostname()
+    })
 end
 
 -- This function loads new ann from Redis
@@ -500,7 +521,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
   local function data_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
-          ann_key, err)
+        ann_key, err)
     else
       if type(data) == 'table' then
         if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then
@@ -509,7 +530,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
 
           if _err or not ann_data then
             rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
-                rule.prefix .. ':' .. set.name, ann_key, _err)
+              rule.prefix .. ':' .. set.name, ann_key, _err)
             return
           else
             ann = rspamd_kann.load(ann_data)
@@ -533,26 +554,26 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
               end
               -- Also update rank for the loaded ANN to avoid removal
               lua_redis.redis_make_request_taskless(ev_base,
-                  rspamd_config,
-                  rule.redis,
-                  nil,
-                  true, -- is write
-                  rank_cb, --callback
-                  'ZADD', -- command
-                  { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+                rspamd_config,
+                rule.redis,
+                nil,
+                true,    -- is write
+                rank_cb, --callback
+                'ZADD',  -- command
+                { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
               )
               rspamd_logger.infox(rspamd_config,
-                  'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
-                  rule.prefix, set.name, ann_key, #data[1], profile.version)
+                'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
+                rule.prefix, set.name, ann_key, #data[1], profile.version)
             else
               rspamd_logger.errx(rspamd_config,
-                  'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
-                  rule.prefix, set.name, ann_key)
+                'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
+                rule.prefix, set.name, ann_key)
             end
           end
         else
           lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
-              rule.prefix, set.name, ann_key)
+            rule.prefix, set.name, ann_key)
         end
 
         if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
@@ -564,8 +585,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
             local roc_thresholds = parser:get_object()
             set.ann.roc_thresholds = roc_thresholds
             rspamd_logger.infox(rspamd_config,
-                'loaded ROC thresholds for %s:%s; version=%s',
-                rule.prefix, set.name, profile.version)
+              'loaded ROC thresholds for %s:%s; version=%s',
+              rule.prefix, set.name, profile.version)
             rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
           end
         end
@@ -578,19 +599,19 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
               -- We can use PCA
               set.ann.pca = rspamd_tensor.load(pca_data)
               rspamd_logger.infox(rspamd_config,
-                  'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
-                  rule.prefix, set.name, ann_key, #data[3], profile.version)
+                'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
+                rule.prefix, set.name, ann_key, #data[3], profile.version)
             else
               -- no need in pca, why is it there?
               rspamd_logger.warnx(rspamd_config,
-                  'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
-                  rule.prefix, set.name, ann_key)
+                'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
+                rule.prefix, set.name, ann_key)
             end
           else
             -- pca can be missing merely if we have no max_inputs
             if rule.max_inputs then
               rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s',
-                  rule.prefix, set.name, ann_key, _err)
+                rule.prefix, set.name, ann_key, _err)
               set.ann.ann = nil
             else
               -- It is okay
@@ -619,19 +640,19 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
         end
       else
         lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
-            rule.prefix, set.name, ann_key)
+          rule.prefix, set.name, ann_key)
       end
     end
   end
   lua_redis.redis_make_request_taskless(ev_base,
-      rspamd_config,
-      rule.redis,
-      nil,
-      false, -- is write
-      data_cb, --callback
-      'HMGET', -- command
-      { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments
-      { opaque_data = true }
+    rspamd_config,
+    rule.redis,
+    nil,
+    false,                                                                       -- is write
+    data_cb,                                                                     --callback
+    'HMGET',                                                                     -- command
+    { ann_key, 'ann', 'roc_thresholds', 'pca', 'providers_meta', 'norm_stats' }, -- arguments
+    { opaque_data = true }
   )
 end
 
@@ -644,6 +665,8 @@ local function process_existing_ann(_, ev_base, rule, set, profiles)
   local my_symbols = set.symbols
   local min_diff = math.huge
   local sel_elt
+  lua_util.debugm(N, rspamd_config, 'process_existing_ann: have %s profiles for %s:%s',
+    type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name)
 
   for _, elt in fun.iter(profiles) do
     if elt and elt.symbols then
@@ -667,34 +690,34 @@ local function process_existing_ann(_, ev_base, rule, set, profiles)
         if set.ann.version < sel_elt.version then
           -- Load new ann
           rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
-              'our version = %s, remote version = %s',
-              rule.prefix .. ':' .. set.name,
-              set.ann.version,
-              sel_elt.version)
+            'our version = %s, remote version = %s',
+            rule.prefix .. ':' .. set.name,
+            set.ann.version,
+            sel_elt.version)
           load_new_ann(rule, ev_base, set, sel_elt, min_diff)
         else
           lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' ..
-              'our version = %s, remote version = %s',
-              rule.prefix .. ':' .. set.name,
-              set.ann.version,
-              sel_elt.version)
+            'our version = %s, remote version = %s',
+            rule.prefix .. ':' .. set.name,
+            set.ann.version,
+            sel_elt.version)
         end
       else
         -- We have some different ANN, so we need to compare distance
         if set.ann.distance > min_diff then
           -- Load more specific ANN
           rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
-              'our distance = %s, remote distance = %s',
-              rule.prefix .. ':' .. set.name,
-              set.ann.distance,
-              min_diff)
+            'our distance = %s, remote distance = %s',
+            rule.prefix .. ':' .. set.name,
+            set.ann.distance,
+            min_diff)
           load_new_ann(rule, ev_base, set, sel_elt, min_diff)
         else
           lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' ..
-              'our distance = %s, remote distance = %s',
-              rule.prefix .. ':' .. set.name,
-              set.ann.distance,
-              min_diff)
+            'our distance = %s, remote distance = %s',
+            rule.prefix .. ':' .. set.name,
+            set.ann.distance,
+            min_diff)
         end
       end
     else
@@ -702,6 +725,12 @@ local function process_existing_ann(_, ev_base, rule, set, profiles)
       load_new_ann(rule, ev_base, set, sel_elt, min_diff)
     end
   end
+  if sel_elt then
+    lua_util.debugm(N, rspamd_config, 'process_existing_ann: selected profile version=%s key=%s', sel_elt.version,
+      sel_elt.redis_key)
+  else
+    lua_util.debugm(N, rspamd_config, 'process_existing_ann: no suitable profile found')
+  end
 end
 
 
@@ -715,6 +744,8 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
     spam = 0,
     ham = 0,
   }
+  lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: %s profiles for %s:%s',
+    type(profiles) == 'table' and #profiles or -1, rule.prefix, set.name)
 
   for _, elt in fun.iter(profiles) do
     if elt and elt.symbols then
@@ -732,14 +763,14 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
     local ann_key = sel_elt.redis_key
 
     lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
-        ann_key)
+      ann_key)
 
     -- Create continuation closure
     local redis_len_cb_gen = function(cont_cb, what, is_final)
       return function(err, data)
         if err then
           rspamd_logger.errx(rspamd_config,
-              'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
+            'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
         elseif data and type(data) == 'number' or type(data) == 'string' then
           local ntrains = tonumber(data) or 0
           lens[what] = ntrains
@@ -760,31 +791,31 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
               end
               if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
                 lua_util.debugm(N, rspamd_config,
-                    'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                    ann_key, lens, rule.train.max_trains, what)
+                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+                  ann_key, lens, rule.train.max_trains, what)
                 cont_cb()
               else
                 lua_util.debugm(N, rspamd_config,
-                    'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                    ann_key, what, lens, rule.train.max_trains)
+                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+                  ann_key, what, lens, rule.train.max_trains)
               end
             else
               -- Probabilistic mode, just ensure that at least one vector is okay
               if min_len > 0 and max_len >= rule.train.max_trains then
                 lua_util.debugm(N, rspamd_config,
-                    'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
-                    ann_key, lens, rule.train.max_trains, what)
+                  'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+                  ann_key, lens, rule.train.max_trains, what)
                 cont_cb()
               else
                 lua_util.debugm(N, rspamd_config,
-                    'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
-                    ann_key, what, lens, rule.train.max_trains)
+                  'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+                  ann_key, what, lens, rule.train.max_trains)
               end
             end
           else
             lua_util.debugm(N, rspamd_config,
-                'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
-                what, ann_key, ntrains, rule.train.max_trains)
+              'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
+              what, ann_key, ntrains, rule.train.max_trains)
             cont_cb()
           end
         end
@@ -793,32 +824,34 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
 
     local function initiate_train()
       rspamd_logger.infox(rspamd_config,
-          'need to learn ANN %s after %s required learn vectors',
-          ann_key, lens)
+        'need to learn ANN %s after %s required learn vectors',
+        ann_key, lens)
+      lua_util.debugm(N, rspamd_config, 'maybe_train_existing_ann: initiating train for key=%s spam=%s ham=%s', ann_key,
+        lens.spam or -1, lens.ham or -1)
       do_train_ann(worker, ev_base, rule, set, ann_key)
     end
 
     -- Spam vector is OK, check ham vector length
     local function check_ham_len()
       lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          false, -- is write
-          redis_len_cb_gen(initiate_train, 'ham', true), --callback
-          'SCARD', -- command
-          { ann_key .. '_ham_set' }
+        rspamd_config,
+        rule.redis,
+        nil,
+        false,                                         -- is write
+        redis_len_cb_gen(initiate_train, 'ham', true), --callback
+        'SCARD',                                       -- command
+        { ann_key .. '_ham_set' }
       )
     end
 
     lua_redis.redis_make_request_taskless(ev_base,
-        rspamd_config,
-        rule.redis,
-        nil,
-        false, -- is write
-        redis_len_cb_gen(check_ham_len, 'spam', false), --callback
-        'SCARD', -- command
-        { ann_key .. '_spam_set' }
+      rspamd_config,
+      rule.redis,
+      nil,
+      false,                                          -- is write
+      redis_len_cb_gen(check_ham_len, 'spam', false), --callback
+      'SCARD',                                        -- command
+      { ann_key .. '_spam_set' }
     )
   end
 end
@@ -831,7 +864,7 @@ local function load_ann_profile(element)
   local res, ucl_err = parser:parse_string(element)
   if not res then
     rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
-        ucl_err)
+      ucl_err)
     return nil
   else
     local profile = parser:get_object()
@@ -851,13 +884,16 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
     local function members_cb(err, data)
       if err then
         rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
-            err)
+          err)
         set.can_store_vectors = true
       elseif type(data) == 'table' then
-        lua_util.debugm(N, cfg, '%s: process element %s:%s',
-            what, rule.prefix, set.name)
+        lua_util.debugm(N, cfg, '%s: process element %s:%s (profiles=%s)',
+          what, rule.prefix, set.name, #data)
         process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
         set.can_store_vectors = true
+      else
+        lua_util.debugm(N, cfg, '%s: no profiles for %s:%s', what, rule.prefix, set.name)
+        set.can_store_vectors = true
       end
     end
 
@@ -867,13 +903,13 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
       -- Select the most appropriate to our profile but it should not differ by more
       -- than 30% of symbols
       lua_redis.redis_make_request_taskless(ev_base,
-          cfg,
-          rule.redis,
-          nil,
-          false, -- is write
-          members_cb, --callback
-          'ZREVRANGE', -- command
-          { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
+        cfg,
+        rule.redis,
+        nil,
+        false,                                               -- is write
+        members_cb,                                          --callback
+        'ZREVRANGE',                                         -- command
+        { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
       )
     end
   end -- Cycle over all settings
@@ -887,23 +923,23 @@ local function cleanup_anns(rule, cfg, ev_base)
     local function invalidate_cb(err, data)
       if err then
         rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
-            err)
+          err)
       elseif type(data) == 'table' then
         for _, expired in ipairs(data) do
           local profile = load_ann_profile(expired)
           rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
-              rule.prefix .. ':' .. set.name,
-              profile.redis_key,
-              profile.version)
+            rule.prefix .. ':' .. set.name,
+            profile.redis_key,
+            profile.version)
         end
       end
     end
 
     if type(set) == 'table' then
       lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
-          { ev_base = ev_base, is_write = true },
-          invalidate_cb,
-          { set.prefix, tostring(settings.max_profiles) })
+        { ev_base = ev_base, is_write = true },
+        invalidate_cb,
+        { set.prefix, tostring(settings.max_profiles) })
     end
   end
 end
@@ -922,14 +958,14 @@ local function ann_push_vector(task)
 
   if verdict == 'passthrough' then
     lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
-        verdict, score)
+      verdict, score)
 
     return
   end
 
   if score ~= score then
     lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
-        verdict)
+      verdict)
 
     return
   end
@@ -999,7 +1035,7 @@ for k, r in pairs(rules) do
 
   if rule_elt.max_inputs and not has_blas then
     rspamd_logger.errx('cannot set max inputs to %s as BLAS is not compiled in',
-        rule_elt.name, rule_elt.max_inputs)
+      rule_elt.name, rule_elt.max_inputs)
     rule_elt.max_inputs = nil
   end
 
@@ -1011,7 +1047,7 @@ for k, r in pairs(rules) do
       end
       if (pcfg.type == 'llm' or pcfg.name == 'llm') and not (pcfg.model or (rspamd_config:get_all_opt('gpt') or {}).model) then
         rspamd_logger.errx(rspamd_config,
-            'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k)
+          'llm provider in rule %s requires model; please set providers[i].model or gpt.model', k)
       end
     end
   end
@@ -1062,21 +1098,21 @@ for _, rule in pairs(settings.rules) do
   rspamd_config:add_on_load(function(cfg, ev_base, worker)
     if worker:is_scanner() then
       rspamd_config:add_periodic(ev_base, 0.0,
-          function(_, _)
-            return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
-                'try_load_ann')
-          end)
+        function(_, _)
+          return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
+            'try_load_ann')
+        end)
     end
 
     if worker:is_primary_controller() then
       -- We also want to train neural nets when they have enough data
       rspamd_config:add_periodic(ev_base, 0.0,
-          function(_, _)
-            -- Clean old ANNs
-            cleanup_anns(rule, cfg, ev_base)
-            return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
-                'try_train_ann')
-          end)
+        function(_, _)
+          -- Clean old ANNs
+          cleanup_anns(rule, cfg, ev_base)
+          return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
+            'try_train_ann')
+        end)
     end
   end)
 end
diff --git a/test/functional/cases/335_neural_llm/003_llm_train.robot b/test/functional/cases/335_neural_llm/003_llm_train.robot
new file mode 100644 (file)
index 0000000..aa76a15
--- /dev/null
@@ -0,0 +1,39 @@
+*** Settings ***
+Suite Setup      Rspamd Redis Setup
+Suite Teardown   Rspamd Redis Teardown
+Library         Process
+Library         ${RSPAMD_TESTDIR}/lib/rspamd.py
+Resource        ${RSPAMD_TESTDIR}/lib/rspamd.robot
+Variables       ${RSPAMD_TESTDIR}/lib/vars.py
+
+*** Variables ***
+${CONFIG}          ${RSPAMD_TESTDIR}/configs/neural_llm.conf
+${SPAM_MSG}        ${RSPAMD_TESTDIR}/messages/spam_message.eml
+${HAM_MSG}         ${RSPAMD_TESTDIR}/messages/ham.eml
+${REDIS_SCOPE}     Suite
+${RSPAMD_SCOPE}    Suite
+${RSPAMD_URL_TLD}  ${RSPAMD_TESTDIR}/../lua/unit/test_tld.dat
+
+*** Test Cases ***
+Train LLM-backed neural and verify
+  Run Dummy Llm
+
+  # Learn spam
+  ${result} =  Run Rspamc  -P  secret  -h  ${RSPAMD_LOCAL_ADDR}:${RSPAMD_PORT_CONTROLLER}  neural_learn:spam  ${SPAM_MSG}
+  Check Rspamc  ${result}
+
+  # Learn ham
+  ${result} =  Run Rspamc  -P  secret  -h  ${RSPAMD_LOCAL_ADDR}:${RSPAMD_PORT_CONTROLLER}  neural_learn:ham  ${HAM_MSG}
+  Check Rspamc  ${result}
+
+  Sleep  5s
+
+  # Check spam inference (dummy_llm returns ones vector for "spam" content)
+  Scan File  ${SPAM_MSG}  Settings={groups_enabled=["neural"]}
+  Expect Symbol  NEURAL_SPAM
+
+  # Check ham inference (zeros vector)
+  Scan File  ${HAM_MSG}  Settings={groups_enabled=["neural"]}
+  Expect Symbol  NEURAL_HAM
+
+  Dummy Llm Teardown
diff --git a/test/functional/configs/neural_llm.conf b/test/functional/configs/neural_llm.conf
new file mode 100644 (file)
index 0000000..b6745ad
--- /dev/null
@@ -0,0 +1,68 @@
+options = {
+  url_tld = "{= env.URL_TLD =}"
+  pidfile = "{= env.TMPDIR =}/rspamd.pid"
+  lua_path = "{= env.INSTALLROOT =}/share/rspamd/lib/?.lua"
+  filters = [];
+  explicit_modules = ["settings"];
+}
+
+logging = {
+  type = "file",
+  level = "debug"
+  filename = "{= env.TMPDIR =}/rspamd.log"
+  log_usec = true;
+}
+metric = {
+  name = "default",
+  actions = {
+    reject = 100500,
+    add_header = 50500,
+  }
+  unknown_weight = 1
+}
+worker {
+  type = normal
+  bind_socket = "{= env.LOCAL_ADDR =}:{= env.PORT_NORMAL =}"
+  count = 1
+  task_timeout = 10s;
+}
+worker {
+  type = controller
+  bind_socket = "{= env.LOCAL_ADDR =}:{= env.PORT_CONTROLLER =}"
+  count = 1
+  secure_ip = ["127.0.0.1", "::1"];
+  stats_path = "{= env.TMPDIR =}/stats.ucl"
+}
+
+modules {
+  path = "{= env.TESTDIR =}/../../src/plugins/lua/"
+}
+
+lua = "{= env.TESTDIR =}/lua/test_coverage.lua";
+
+neural {
+  rules {
+    default {
+      train {
+        learning_rate = 0.001;
+        max_trains = 1;
+        max_iterations = 250;
+      }
+      symbol_spam = "NEURAL_SPAM";
+      symbol_ham = "NEURAL_HAM";
+      ann_expire = 86400;
+      watch_interval = 0.5;
+      providers = [{ type = "llm"; model = "dummy-embed"; url = "http://127.0.0.1:18080"; weight = 1.0; }];
+      fusion { normalization = "none"; }
+      roc_enabled = false;
+    }
+  }
+  allow_local = true;
+}
+
+redis {
+  servers = "{= env.REDIS_ADDR =}:{= env.REDIS_PORT =}";
+  expand_keys = true;
+}
+
+lua = "{= env.TESTDIR =}/lua/neural.lua";
diff --git a/test/functional/util/dummy_llm.py b/test/functional/util/dummy_llm.py
new file mode 100644 (file)
index 0000000..9ee0f17
--- /dev/null
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import json
+import sys
+from http.server import BaseHTTPRequestHandler, HTTPServer
+
+import dummy_killer
+
+PID = "/tmp/dummy_llm.pid"
+
+
+def make_embedding(text: str, dim: int = 32):
+    # Deterministic: if text contains 'SPAM' (case-insensitive) -> ones; else zeros
+    if 'spam' in text.lower():
+        return [1.0] * dim
+    return [0.0] * dim
+
+
+class EmbeddingHandler(BaseHTTPRequestHandler):
+    # OpenAI-like embeddings API
+    def do_POST(self):
+        length = int(self.headers.get('Content-Length', '0'))
+        raw = self.rfile.read(length) if length > 0 else b''
+        try:
+            data = json.loads(raw.decode('utf-8') or '{}')
+        except Exception:
+            data = {}
+
+        # Support both OpenAI ({input, model}) and Ollama ({prompt, model}) shapes
+        text = data.get('input') or data.get('prompt') or ''
+        # Optional dimension override for tests
+        dim = int(data.get('dim') or 32)
+        emb = make_embedding(text, dim)
+
+        if 'openai' in (self.headers.get('User-Agent') or '').lower() or True:
+            # Always reply in OpenAI-like format expected by neural provider
+            body = {
+                "data": [
+                    {"embedding": emb}
+                ]
+            }
+        else:
+            body = {"embedding": emb}
+
+        reply = json.dumps(body).encode('utf-8')
+        self.send_response(200)
+        self.send_header('Content-Type', 'application/json')
+        self.send_header('Content-Length', str(len(reply)))
+        self.end_headers()
+        self.wfile.write(reply)
+
+    def log_message(self, fmt, *args):
+        # Keep test output quiet
+        return
+
+
+if __name__ == "__main__":
+    alen = len(sys.argv)
+    if alen > 1:
+        port = int(sys.argv[1])
+    else:
+        port = 18080
+    server = HTTPServer(("127.0.0.1", port), EmbeddingHandler)
+    dummy_killer.write_pid(PID)
+    try:
+        server.serve_forever()
+    except KeyboardInterrupt:
+        pass
+    finally:
+        server.server_close()