]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Add skeleton for embeddings module vstakhov-llm-embeddings 5264/head
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 19 Dec 2024 16:52:27 +0000 (16:52 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 19 Dec 2024 16:52:27 +0000 (16:52 +0000)
conf/modules.d/llm_embeddings.conf [new file with mode: 0644]
src/plugins/lua/llm_embeddings.lua [new file with mode: 0644]

diff --git a/conf/modules.d/llm_embeddings.conf b/conf/modules.d/llm_embeddings.conf
new file mode 100644 (file)
index 0000000..d5c91bb
--- /dev/null
@@ -0,0 +1,33 @@
+# Please don't modify this file as your changes might be overwritten with
+# the next update.
+#
+# You can modify 'local.d/gpt.conf' to add and merge
+# parameters defined inside this section
+#
+# You can modify 'override.d/gpt.conf' to strictly override all
+# parameters defined inside this section
+#
+# See https://rspamd.com/doc/faq.html#what-are-the-locald-and-overrided-directories
+# for details
+#
+# Module documentation can be found at  https://rspamd.com/doc/modules/gpt.html
+
+llm_embeddings {
+  # Supported types: openai, ollama
+  type = "ollama";
+  model = "nomic-embed-text";
+  dimensions = 8192;
+  # Maximum tokens to generate
+  max_tokens = 5000;
+  # Timeout for requests
+  timeout = 10s;
+  # URL to check (e.g. for ollama)
+  url = "http://localhost:11434/api/embeddings";
+  # Be sure to enable module after you specify the API key
+  enabled = false;
+
+  # Include dynamic conf for the rule
+  .include(try=true,priority=5) "${DBDIR}/dynamic/llm_embeddings.conf"
+  .include(try=true,priority=1,duplicate=merge) "$LOCAL_CONFDIR/local.d/llm_embeddings.conf"
+  .include(try=true,priority=10) "$LOCAL_CONFDIR/override.d/llm_embeddings.conf"
+}
\ No newline at end of file
diff --git a/src/plugins/lua/llm_embeddings.lua b/src/plugins/lua/llm_embeddings.lua
new file mode 100644 (file)
index 0000000..d591b4d
--- /dev/null
@@ -0,0 +1,295 @@
+--[[
+Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]] --
+
+local N = "llm_embeddings"
+
+if confighelp then
+  rspamd_config:add_example(nil, N,
+      "Performs statistical analysis of messages using LLM for embeddings and NN for classification",
+      [[
+llm_embeddings {
+  # Supported types: openai, ollama
+  type = "ollama";
+  # Your key to access the API
+  api_key = "xxx";
+  # Model name
+  model = "nomic-embed-text";
+  # Check the documentation for the model for this value
+  dimensions = 8192;
+  # Maximum tokens to generate
+  max_tokens = 1000;
+  # URL for the API
+  url = "http://localhost:11434/api/embeddings";
+  # Redis parameters to save the resulting classifier
+  servers = "localhost:6379";
+  # Prefix for keys
+  prefix = "llm";
+  # How many learns are required to start classifying
+  min_learns = 100;
+  # Check messages with passthrough result
+  allow_passthrough = false;
+  # Check messages that are apparent ham (no action and negative score)
+  allow_ham = false;
+}
+  ]])
+  return
+end
+
+local lua_util = require "lua_util"
+local rspamd_http = require "rspamd_http"
+local rspamd_logger = require "rspamd_logger"
+local lua_mime = require "lua_mime"
+local ucl = require "ucl"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_tensor = require "rspamd_tensor"
+local lua_redis = require "lua_redis"
+
+local settings = {
+  type = 'ollama',
+  api_key = nil,
+  model = 'gpt-4o-mini',
+  max_tokens = 5000,
+  timeout = 10,
+  prompt = nil,
+  condition = nil,
+  autolearn = false,
+  url = 'http://localhost:11434/api/embeddings',
+  allow_passthrough = false,
+  allow_ham = false,
+  dimensions = 8192,
+  hidden_layer_mult = 0.5, -- Compress in hidden layer
+}
+
+local has_blas = rspamd_tensor.has_blas()
+local kann_model
+local model_learns = 0
+local redis_params
+
+local function extract_data(task)
+  -- Check result
+  -- 1) Skip passthrough
+  local result = task:get_metric_result()
+  if result then
+    if result.passthrough and not settings.allow_passthrough then
+      return false, 'passthrough'
+    end
+  end
+
+  -- Check if we have text at all
+  local sel_part = lua_mime.get_displayed_text_part(task)
+
+  if not sel_part then
+    return false, 'no text part found'
+  end
+
+  -- Check limits and size sanity
+  local nwords = sel_part:get_words_count()
+
+  if nwords < 5 then
+    return false, 'less than 5 words'
+  end
+
+  if nwords > settings.max_tokens then
+    -- We need to truncate words (sometimes get_words_count returns a different number comparing to `get_words`)
+    local words = sel_part:get_words('norm')
+    nwords = #words
+    if nwords > settings.max_tokens then
+      return true, table.concat(words, ' ', 1, settings.max_tokens)
+    end
+  end
+  return true, sel_part:get_content_oneline()
+end
+
+local function gen_embeddings_ollama(task, continuation_cb)
+  local condition, content = extract_data(task)
+  if not condition then
+    return
+  end
+
+  local function embeddings_cb(err, code, data)
+    if err then
+      rspamd_logger.errx(task, 'cannot get embeddings: %s', err)
+      return
+    end
+
+    if data then
+      lua_util.debugm(N, task, 'got reply from embeddings model: %s', data)
+      local parser = ucl.parser()
+      local res, err = parser:parse_string(data)
+      if not res then
+        rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+        return
+      end
+      local reply = parser:get_object()
+
+      if reply and type(reply) == 'table' and type(reply.embedding) == 'table' then
+        lua_util.debugm(N, task, 'got embeddings: %s', #reply.embedding)
+        continuation_cb(task, reply.embedding)
+      else
+        rspamd_logger.errx(task, 'cannot parse embeddings: %s', data)
+      end
+    end
+  end
+
+  local post_data = {
+    model = settings.model,
+    prompt = content,
+  }
+
+  rspamd_http.request({
+    url = settings.url,
+    task = task,
+    callback = embeddings_cb,
+    body = ucl.to_json(post_data),
+    timeout = settings.timeout,
+    headers = {
+      ['Authorization'] = settings.api_key,
+      ['Content-Type'] = 'application/json',
+    },
+  })
+end
+
+local function kann_model_create()
+  local t = rspamd_kann.layer.input(settings.dimensions)
+  t = rspamd_kann.transform.relu(t)
+  t = rspamd_kann.layer.dense(t, settings.dimensions * settings.hidden_layer_mult);
+  t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
+  kann_model = rspamd_kann.new.kann(t)
+end
+
+local function redis_prefix()
+  return settings.prefix .. '_' .. settings.model
+end
+
+local function kann_model_save(ev_base)
+  if not redis_params then
+    return
+  end
+
+  local function save_cb(err, _)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot save model: %s', err)
+    end
+  end
+
+  local packed_model = kann_model:save()
+  local key = redis_prefix() .. '_model'
+  lua_redis.redis_make_request_taskless(ev_base, rspamd_config,
+      redis_params, key, true,
+      save_cb, 'SET', { key, packed_model })
+end
+
+local function kann_model_maybe_load(ev_base)
+  if not redis_params then
+    return
+  end
+
+  local function load_cb(err, data)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot load model: %s', err)
+    else
+      if data then
+        kann_model = rspamd_kann.load(data)
+      end
+    end
+  end
+
+  local key = redis_prefix() .. '_model'
+  lua_redis.redis_make_request_taskless(ev_base, rspamd_config,
+      redis_params, key, false,
+      load_cb, 'GET', { key })
+end
+
+local function save_embeddings_vector(task, is_spam)
+  local function save_cb(err, _)
+    if err then
+      rspamd_logger.errx(task, 'cannot save embeddings: %s', err)
+    end
+  end
+
+  local function save_vector(emb)
+    local key = redis_prefix() .. (is_spam and '_spam' or '_ham')
+    local packed_vector = ucl.to_format(emb, 'msgpack')
+    lua_redis.redis_make_request(task,
+        redis_params, key, true,
+        save_cb, 'LPUSH', { key, packed_vector })
+  end
+
+  gen_embeddings_ollama(task, save_vector)
+end
+
+local function nn_learn(task, is_spam)
+  if not kann_model then
+    kann_model_create()
+  end
+
+  save_embeddings_vector(task, is_spam)
+end
+
+local function nn_learn_spam(task)
+  lua_util.debugm(N, task, 'learn spam')
+  nn_learn(task, true)
+end
+
+local function nn_learn_ham(task)
+  lua_util.debugm(N, task, 'learn ham')
+  nn_learn(task, false)
+end
+
+local function nn_classify(task)
+  -- TODO: Implement
+end
+
+local module_config = rspamd_config:get_all_opt(N)
+settings = lua_util.override_defaults(settings, module_config)
+redis_params = lua_redis.parse_redis_server(N)
+
+if not redis_params then
+  rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
+  lua_util.disable_module(N, "redis")
+  return
+end
+
+local id = rspamd_config:register_symbol {
+  name = "LLM_CLASSIFY_CHECK",
+  type = 'callback',
+  callback = nn_classify,
+}
+rspamd_config:register_symbol {
+  name = "LLM_EMBEDDINGS_SPAM",
+  type = 'virtual',
+  parent = id,
+}
+rspamd_config:register_symbol {
+  name = "LLM_EMBEDDINGS_HAM",
+  type = 'virtual',
+  parent = id,
+}
+
+-- Allow this symbol to be enabled merely explicitly when we need to learn
+rspamd_config:register_symbol {
+  name = "LLM_LEARN_SPAM",
+  type = 'callback',
+  callback = nn_learn_spam,
+  flags = 'explicit_enable',
+}
+-- Allow this symbol to be enabled merely explicitly when we need to learn
+rspamd_config:register_symbol {
+  name = "LLM_LEARN_HAM",
+  type = 'callback',
+  callback = nn_learn_ham,
+  flags = 'explicit_enable',
+}
\ No newline at end of file