--- /dev/null
+# Neural Autolearn Configuration
+#
+# This configuration is part of the neural plugin and controls automatic
+# training for neural networks with LLM providers.
+#
+# The autolearn section can be added to any neural rule configuration.
+# It supports score, action, symbol and expression-based conditions so that
+# only messages classified with high confidence are used for training.
+#
+# Documentation: doc/neural-llm-embeddings-guide.md
+
+# Example autolearn configuration for neural rules
+# Add this inside a neural rule:
+#
+# neural {
+# rules {
+# llm_classifier {
+# providers = [
+# { type = "llm"; llm_type = "ollama"; model = "nomic-embed-text"; }
+# ];
+#
+# # Autolearn configuration
+# autolearn {
+# enabled = true;
+#
+# # Score thresholds
+# spam_score = 15.0; # Learn spam if score >= 15.0
+# ham_score = -5.0; # Learn ham if score <= -5.0
+#
+# # Action requirements (optional, more restrictive)
+# spam_action = "reject";
+# ham_action = "no action";
+#
+# # Expression-based conditions (rspamd_expression syntax)
+# # These provide fine-grained control over learning decisions
+# spam_condition = "BAYES_SPAM & (DMARC_POLICY_REJECT | RBL_SPAMHAUS_SBL)";
+# ham_condition = "BAYES_HAM & DKIM_VALID_AU & SPF_PASS";
+#
+# # Required symbols (all must be present)
+# spam_symbols = ["BAYES_SPAM"];
+# ham_symbols = ["BAYES_HAM", "DKIM_VALID"];
+#
+# # Forbidden symbols (any blocks learning)
+# skip_symbols = ["WHITELIST_SENDER", "GREYLIST"];
+#
+# # Symbol weight thresholds
+# spam_symbol_weight = 5.0; # Sum of spam_symbols scores >= 5.0
+# ham_symbol_weight = -3.0; # Sum of ham_symbols scores <= -3.0
+#
+# # Probabilistic sampling (reduce training volume)
+# sampling {
+# spam_prob = 0.5; # Learn 50% of qualifying spam
+# ham_prob = 0.5;
+# }
+#
+# # Skip local/authenticated messages
+# check_local = true;
+# check_authed = true;
+# }
+# }
+# }
+# }
+
+# Conservative production example:
+# autolearn {
+# enabled = true;
+# spam_score = 20.0;
+# ham_score = -8.0;
+# spam_action = "reject";
+# spam_condition = "BAYES_SPAM & !WHITELIST_SENDER";
+# ham_condition = "BAYES_HAM & DKIM_VALID_AU";
+# skip_symbols = ["GREYLIST", "RATELIMITED"];
+# sampling {
+# spam_prob = 0.3;
+# ham_prob = 0.3;
+# }
+# }
+
+# Aggressive initial training example:
+# autolearn {
+# enabled = true;
+# spam_score = 10.0;
+# ham_score = -2.0;
+# }
--- /dev/null
+# Rspamd Neural Embedding Service
+#
+# CPU-optimized embedding service using FastEmbed + ONNX Runtime
+#
+# Build:
+# docker build -t rspamd-embedding-service .
+#
+# Run:
+# docker run -p 8080:8080 rspamd-embedding-service
+#
+# With custom model:
+# docker run -p 8080:8080 -e EMBEDDING_MODEL="BAAI/bge-base-en-v1.5" rspamd-embedding-service
+
+FROM python:3.11-slim
+
+# Build arguments
+ARG EMBEDDING_MODEL="BAAI/bge-small-en-v1.5"
+
+# Environment
+ENV PYTHONUNBUFFERED=1 \
+ PYTHONDONTWRITEBYTECODE=1 \
+ PIP_NO_CACHE_DIR=1 \
+ PIP_DISABLE_PIP_VERSION_CHECK=1 \
+ EMBEDDING_MODEL=${EMBEDDING_MODEL} \
+ EMBEDDING_PORT=8080 \
+ EMBEDDING_HOST=0.0.0.0
+
+# Install system dependencies for ONNX Runtime
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ libgomp1 \
+ && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /app
+
+# Install Python dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Pre-download model during build (optional, makes container larger but faster startup)
+# Uncomment to include model in image:
+# RUN python -c "from fastembed import TextEmbedding; TextEmbedding('${EMBEDDING_MODEL}')"
+
+# Copy application
+COPY embedding_service.py .
+
+# Non-root user
+RUN useradd -m -u 1000 embedding && chown -R embedding:embedding /app
+USER embedding
+
+# Expose port
+EXPOSE 8080
+
+# Health check: probe the port the service is configured to listen on
+# (EMBEDDING_PORT) instead of a hard-coded 8080, use 127.0.0.1 so the probe
+# does not depend on how "localhost" resolves, and bound the request itself.
+HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
+    CMD python -c "import os, urllib.request; urllib.request.urlopen('http://127.0.0.1:' + os.environ.get('EMBEDDING_PORT', '8080') + '/health', timeout=5)"
+
+# Start the service (embedding_service.py launches uvicorn itself)
+CMD ["python", "embedding_service.py"]
--- /dev/null
+# Rspamd Neural Embedding Service
+
+A lightweight, CPU-optimized embedding service for Rspamd's neural plugin.
+
+## Overview
+
+This service provides text embeddings for Rspamd's neural LLM provider. It uses FastEmbed with ONNX Runtime for efficient CPU inference, making it suitable for servers without GPU hardware.
+
+## Features
+
+- **CPU-optimized**: Uses ONNX Runtime with INT8 quantization support
+- **Lightweight**: ~100MB memory for bge-small-en-v1.5
+- **Compatible**: Supports both Ollama and OpenAI API formats
+- **Fast**: 2,500-5,000 sentences/second on modern CPUs
+
+## Quick Start
+
+### Option 1: Direct Python
+
+```bash
+# Install dependencies
+pip install -r requirements.txt
+
+# Run service
+python embedding_service.py
+```
+
+### Option 2: Docker
+
+```bash
+# Build
+docker build -t rspamd-embedding-service .
+
+# Run
+docker run -p 8080:8080 rspamd-embedding-service
+
+# With custom model
+docker run -p 8080:8080 -e EMBEDDING_MODEL="BAAI/bge-base-en-v1.5" rspamd-embedding-service
+```
+
+### Option 3: Docker Compose
+
+See the main guide: `doc/neural-llm-embeddings-guide.md`
+
+## Configuration
+
+### Environment Variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `EMBEDDING_MODEL` | `BAAI/bge-small-en-v1.5` | FastEmbed model name |
+| `EMBEDDING_PORT` | `8080` | Port to listen on |
+| `EMBEDDING_HOST` | `0.0.0.0` | Host to bind to |
+
+### Command Line Arguments
+
+```bash
+python embedding_service.py --help
+
+Options:
+ --model, -m Model name (default: BAAI/bge-small-en-v1.5)
+ --port, -p Port number (default: 8080)
+ --host, -H Host to bind (default: 0.0.0.0)
+ --workers, -w Number of workers (default: 1)
+ --log-level, -l Log level (default: info)
+```
+
+## API Endpoints
+
+### Ollama Format (Recommended for Rspamd)
+
+```bash
+curl http://localhost:8080/api/embeddings -H 'Content-Type: application/json' -d '{
+ "model": "BAAI/bge-small-en-v1.5",
+ "prompt": "Test message about cheap medications"
+}'
+```
+
+Response:
+```json
+{
+ "embedding": [0.123, -0.456, ...]
+}
+```
+
+### OpenAI Format
+
+```bash
+curl http://localhost:8080/v1/embeddings -H 'Content-Type: application/json' -d '{
+ "model": "BAAI/bge-small-en-v1.5",
+ "input": "Test message"
+}'
+```
+
+Response:
+```json
+{
+ "object": "list",
+ "data": [{"embedding": [...], "index": 0}],
+ "model": "BAAI/bge-small-en-v1.5",
+ "usage": {"prompt_tokens": 2, "total_tokens": 2}
+}
+```
+
+### Health Check
+
+```bash
+curl http://localhost:8080/health
+```
+
+## Rspamd Configuration
+
+Configure Rspamd to use this service:
+
+```hcl
+# /etc/rspamd/local.d/neural.conf
+rules {
+ default {
+ providers = [
+ {
+ type = "llm";
+ llm_type = "ollama"; # or "openai"
+ model = "BAAI/bge-small-en-v1.5";
+ url = "http://localhost:8080/api/embeddings"; # or /v1/embeddings
+ timeout = 2.0;
+ cache_ttl = 86400;
+ }
+ ];
+ # ...
+ }
+}
+```
+
+## Supported Models
+
+| Model | Size | Dims | Quality | Speed |
+|-------|------|------|---------|-------|
+| `BAAI/bge-small-en-v1.5` | 33MB | 384 | Good | Excellent |
+| `BAAI/bge-base-en-v1.5` | 440MB | 768 | Better | Good |
+| `sentence-transformers/all-MiniLM-L6-v2` | 90MB | 384 | Fair | Excellent |
+| `intfloat/e5-small-v2` | 200MB | 384 | Good | Excellent |
+
+For the full list, see: https://qdrant.github.io/fastembed/examples/Supported_Models/
+
+## Production Deployment
+
+### With Gunicorn
+
+```bash
+pip install gunicorn
+gunicorn -w 4 -k uvicorn.workers.UvicornWorker -b 0.0.0.0:8080 embedding_service:app
+```
+
+### Resource Recommendations
+
+| Model | Memory | CPU Cores |
+|-------|--------|-----------|
+| bge-small-en-v1.5 | 256MB | 1 |
+| bge-base-en-v1.5 | 1GB | 2 |
+
+## Troubleshooting
+
+### Model download issues
+
+```bash
+# Pre-download model
+python -c "from fastembed import TextEmbedding; TextEmbedding('BAAI/bge-small-en-v1.5')"
+```
+
+### Memory issues
+
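+Use the smallest supported model (`BAAI/bge-small-en-v1.5`, which is also the default) and keep the number of workers low, since every worker process loads its own copy of the model: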
+```bash
+EMBEDDING_MODEL="BAAI/bge-small-en-v1.5" python embedding_service.py
+```
+
+### Slow inference
+
+- Ensure ONNX Runtime is using optimized providers
+- Consider increasing workers for parallel processing
+- Use batching for bulk operations
+
+## License
+
+Apache License 2.0
+
+See the main Rspamd repository for license details.
--- /dev/null
+#!/usr/bin/env python3
+"""
+Lightweight embedding service for Rspamd neural plugin.
+
+Uses FastEmbed with ONNX for CPU-optimized inference.
+Provides both Ollama-compatible and OpenAI-compatible endpoints.
+
+Installation:
+ pip install fastapi uvicorn fastembed pydantic
+
+Usage:
+ python embedding_service.py [--model MODEL] [--port PORT] [--host HOST]
+
+ # Default: bge-small-en-v1.5 on port 8080
+ python embedding_service.py
+
+ # Custom model
+ python embedding_service.py --model "BAAI/bge-base-en-v1.5"
+
+ # Production with gunicorn
+ gunicorn -w 4 -k uvicorn.workers.UvicornWorker embedding_service:app
+
+Environment variables:
+ EMBEDDING_MODEL: Model name (default: BAAI/bge-small-en-v1.5)
+ EMBEDDING_PORT: Port number (default: 8080)
+ EMBEDDING_HOST: Host to bind (default: 0.0.0.0)
+"""
+
+import argparse
+import logging
+import os
+import time
+from contextlib import asynccontextmanager
+from typing import List, Optional, Union
+
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+# FastEmbed - CPU-optimized ONNX inference
+from fastembed import TextEmbedding
+
+# Logging setup
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# Configuration from environment
+DEFAULT_MODEL = os.getenv('EMBEDDING_MODEL', 'BAAI/bge-small-en-v1.5')
+DEFAULT_PORT = int(os.getenv('EMBEDDING_PORT', '8080'))
+DEFAULT_HOST = os.getenv('EMBEDDING_HOST', '0.0.0.0')
+
+# Global model instance
+model: Optional[TextEmbedding] = None
+model_name: str = DEFAULT_MODEL
+model_dim: int = 0
+
+
+# Request/Response models
+class OllamaEmbeddingRequest(BaseModel):
+ """Request body for POST /api/embeddings (Ollama format); `model` is accepted but the handler always embeds with the model loaded at startup."""
+ model: str = DEFAULT_MODEL
+ prompt: str
+
+
+class OllamaEmbeddingResponse(BaseModel):
+ """Response body for POST /api/embeddings: a single embedding vector."""
+ embedding: List[float]
+
+
+class OpenAIEmbeddingRequest(BaseModel):
+ """Request body for POST /v1/embeddings (OpenAI format); `input` is one string or a list of strings, `encoding_format` is accepted but not read by the handler."""
+ model: str = DEFAULT_MODEL
+ input: Union[str, List[str]]
+ encoding_format: str = "float"
+
+
+class OpenAIEmbeddingData(BaseModel):
+ """One entry of the OpenAI response `data` list: the vector plus the position of its input text."""
+ object: str = "embedding"
+ embedding: List[float]
+ index: int
+
+
+class OpenAIUsage(BaseModel):
+ """OpenAI usage object; the service fills both fields with the same approximate token count."""
+ prompt_tokens: int
+ total_tokens: int
+
+
+class OpenAIEmbeddingResponse(BaseModel):
+ """Response body for POST /v1/embeddings (OpenAI list format)."""
+ object: str = "list"
+ data: List[OpenAIEmbeddingData]
+ model: str
+ usage: OpenAIUsage
+
+
+class HealthResponse(BaseModel):
+ """Response body for GET /health and GET /: status, loaded model name, embedding size and uptime."""
+ status: str
+ model: str
+ dimensions: int
+ uptime_seconds: float
+
+
+# Startup time for uptime calculation
+startup_time: float = 0.0
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ """Load model on startup."""
+ global model, model_name, model_dim, startup_time
+
+ logger.info(f"Loading embedding model: {model_name}")
+ start = time.time()
+
+ try:
+ model = TextEmbedding(model_name)
+ # Get embedding dimension from a test inference
+ test_embed = list(model.embed(["test"]))[0]
+ model_dim = len(test_embed)
+ elapsed = time.time() - start
+ logger.info(f"Model loaded in {elapsed:.2f}s, dimensions: {model_dim}")
+ startup_time = time.time()
+ except Exception as e:
+ logger.error(f"Failed to load model: {e}")
+ raise
+
+ yield
+
+ logger.info("Shutting down embedding service")
+
+
+app = FastAPI(
+ title="Rspamd Embedding Service",
+ description="CPU-optimized embedding service for Rspamd neural plugin",
+ version="1.0.0",
+ lifespan=lifespan,
+)
+
+
+def get_embedding(text: str) -> List[float]:
+ """Generate embedding for a single text."""
+ if model is None:
+ raise HTTPException(500, "Model not loaded")
+
+ embeddings = list(model.embed([text]))
+ return embeddings[0].tolist()
+
+
+def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
+ """Generate embeddings for multiple texts."""
+ if model is None:
+ raise HTTPException(500, "Model not loaded")
+
+ embeddings = list(model.embed(texts))
+ return [e.tolist() for e in embeddings]
+
+
+def count_tokens(text: str) -> int:
+ """Approximate token count (words * 1.3)."""
+ return int(len(text.split()) * 1.3)
+
+
+@app.post("/api/embeddings", response_model=OllamaEmbeddingResponse)
+async def ollama_embeddings(request: OllamaEmbeddingRequest) -> OllamaEmbeddingResponse:
+ """
+ Ollama-compatible embedding endpoint.
+
+ Used by Rspamd neural LLM provider with llm_type = "ollama".
+ """
+ if not request.prompt:
+ raise HTTPException(400, "Missing prompt")
+
+ logger.debug(f"Ollama request: {len(request.prompt)} chars")
+ embedding = get_embedding(request.prompt)
+
+ return OllamaEmbeddingResponse(embedding=embedding)
+
+
+@app.post("/v1/embeddings", response_model=OpenAIEmbeddingResponse)
+async def openai_embeddings(request: OpenAIEmbeddingRequest) -> OpenAIEmbeddingResponse:
+ """
+ OpenAI-compatible embedding endpoint.
+
+ Used by Rspamd neural LLM provider with llm_type = "openai".
+ """
+ if not request.input:
+ raise HTTPException(400, "Missing input")
+
+ # Handle single string or list of strings
+ if isinstance(request.input, str):
+ texts = [request.input]
+ else:
+ texts = request.input
+
+ logger.debug(f"OpenAI request: {len(texts)} texts")
+ embeddings = get_embeddings_batch(texts)
+
+ # Build response
+ data = [
+ OpenAIEmbeddingData(embedding=emb, index=i)
+ for i, emb in enumerate(embeddings)
+ ]
+
+ total_tokens = sum(count_tokens(t) for t in texts)
+
+ return OpenAIEmbeddingResponse(
+ data=data,
+ model=request.model,
+ usage=OpenAIUsage(prompt_tokens=total_tokens, total_tokens=total_tokens)
+ )
+
+
+@app.get("/health", response_model=HealthResponse)
+@app.get("/", response_model=HealthResponse)
+async def health() -> HealthResponse:
+ """Health check endpoint."""
+ return HealthResponse(
+ status="ok" if model is not None else "loading",
+ model=model_name,
+ dimensions=model_dim,
+ uptime_seconds=time.time() - startup_time if startup_time > 0 else 0
+ )
+
+
+@app.get("/v1/models")
+async def list_models():
+ """List available models (OpenAI-compatible)."""
+ return {
+ "object": "list",
+ "data": [
+ {
+ "id": model_name,
+ "object": "model",
+ "created": int(startup_time),
+ "owned_by": "fastembed",
+ "permission": [],
+ "root": model_name,
+ }
+ ]
+ }
+
+
+def main():
+ """Run the embedding service."""
+ global model_name
+
+ parser = argparse.ArgumentParser(
+ description="Rspamd embedding service",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument(
+ "--model", "-m",
+ default=DEFAULT_MODEL,
+ help="FastEmbed model name"
+ )
+ parser.add_argument(
+ "--port", "-p",
+ type=int,
+ default=DEFAULT_PORT,
+ help="Port to listen on"
+ )
+ parser.add_argument(
+ "--host", "-H",
+ default=DEFAULT_HOST,
+ help="Host to bind to"
+ )
+ parser.add_argument(
+ "--workers", "-w",
+ type=int,
+ default=1,
+ help="Number of worker processes"
+ )
+ parser.add_argument(
+ "--log-level", "-l",
+ default="info",
+ choices=["debug", "info", "warning", "error"],
+ help="Log level"
+ )
+
+ args = parser.parse_args()
+ model_name = args.model
+
+ import uvicorn
+ uvicorn.run(
+ "embedding_service:app",
+ host=args.host,
+ port=args.port,
+ workers=args.workers,
+ log_level=args.log_level,
+ )
+
+
+if __name__ == "__main__":
+ main()
--- /dev/null
+# Rspamd Neural Embedding Service Dependencies
+#
+# Install: pip install -r requirements.txt
+
+# FastAPI web framework
+fastapi>=0.100.0
+
+# ASGI server
+uvicorn[standard]>=0.23.0
+
+# FastEmbed - CPU-optimized embedding inference via ONNX
+# Includes: onnxruntime, numpy, tokenizers
+fastembed>=0.2.0
+
+# Data validation
+pydantic>=2.0.0
+
+# Optional: production ASGI server
+# gunicorn>=21.0.0
--- /dev/null
+--[[
+Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[
+Neural network autolearn helpers.
+
+This module provides configurable autolearn conditions for neural networks,
+particularly useful for LLM-based providers where automatic learning needs
+careful control.
+
+Similar to lua_bayes_learn.lua, this provides:
+- Guards system for pluggable checks
+- Expression-based conditions (rspamd_expression)
+- Score/action/symbol-based thresholds
+- Hooks for custom logic via rspamd_plugins
+
+Usage in neural.lua:
+ local neural_learn = require "lua_neural_learn"
+ local can_learn, reason = neural_learn.can_autolearn(task, rule, 'spam')
+]]--
+
+local lua_util = require "lua_util"
+local rspamd_expression = require "rspamd_expression"
+local rspamd_logger = require "rspamd_logger"
+
+local N = "lua_neural_learn"
+
+local exports = {}
+
+-- Global defaults that can be overridden via configure()
+local global_defaults = {}
+
+-- Registered guards (callbacks that can block learning)
+local autolearn_guards = {}
+
+-- Cached compiled expressions per rule
+local expression_cache = {}
+
+-- Default autolearn settings (lowest-priority layer; global and per-rule options are merged over it)
+local default_autolearn_settings = {
+ -- Master enable/disable
+ enabled = false,
+
+ -- Score thresholds on the task's metric score; a nil value disables the check
+ spam_score = nil, -- e.g., 6.0 - learn spam if score >= 6.0
+ ham_score = nil, -- e.g., -2.0 - learn ham if score <= -2.0
+
+ -- Require specific actions (compared for equality with the task's metric action)
+ spam_action = nil, -- e.g., 'reject' - only learn spam on reject
+ ham_action = nil, -- e.g., 'no action' - only learn ham on no action
+
+ -- Expression-based conditions (rspamd_expression syntax)
+ -- Examples:
+ -- "BAYES_SPAM & !WHITELIST_SENDER"
+ -- "DMARC_POLICY_REJECT | (RBL_SPAMHAUS_SBL & SURBL_MULTI)"
+ spam_condition = nil,
+ ham_condition = nil,
+
+ -- Required symbols (all must be present)
+ spam_symbols = nil, -- e.g., {'BAYES_SPAM', 'DKIM_VALID'}
+ ham_symbols = nil,
+
+ -- Forbidden symbols (any blocks learning)
+ skip_symbols = nil, -- e.g., {'WHITELIST_SENDER', 'GREYLIST'}
+
+ -- Threshold on the summed scores of spam_symbols / ham_symbols
+ spam_symbol_weight = nil, -- e.g., 5.0 - sum of spam_symbols scores >= 5.0
+ ham_symbol_weight = nil, -- e.g., -3.0 - sum of ham_symbols scores <= -3.0
+
+ -- Probability-based check (skip if already confident)
+ probability_check = {
+ enabled = false,
+ variable = 'neural_prob', -- mempool variable name (read as a double)
+ spam_min = 0.95, -- skip spam learning if the stored probability is >= 0.95
+ ham_max = 0.05, -- skip ham learning if the stored probability is <= 0.05
+ },
+
+ -- Rate limiting (declared only: no check in this module reads these fields)
+ rate_limit = {
+ enabled = false,
+ max_daily = 1000, -- per class per day
+ redis_prefix = 'neural_autolearn',
+ },
+
+ -- Sampling (probabilistic training reduction)
+ sampling = {
+ spam_prob = 1.0, -- 1.0 = always, 0.5 = 50% chance
+ ham_prob = 1.0,
+ },
+
+ -- Skip conditions
+ check_local = false, -- when true, skip messages whose sender IP is local
+ check_authed = false, -- when true, skip messages from authenticated users
+}
+
+-- Helper: turn a list of values into a lookup table (value -> true).
+-- Returns nil when no list is given.
+local function as_set(tbl)
+  local set = nil
+  if tbl then
+    set = {}
+    for _, item in ipairs(tbl) do
+      set[item] = true
+    end
+  end
+  return set
+end
+
+-- Helper: layer the module-wide defaults, then the optional overrides,
+-- on top of the given defaults table
+local function merge_options(defaults, overrides)
+  local with_globals = lua_util.override_defaults(defaults, global_defaults)
+  if not overrides then
+    return with_globals
+  end
+  return lua_util.override_defaults(with_globals, overrides)
+end
+
+-- Run every registered guard in list order; the first guard returning a
+-- falsy value vetoes learning. The veto reason is the guard's own reason,
+-- or the guard name when it gave none.
+local function execute_guards(task, learn_type, ctx)
+  for _, entry in ipairs(autolearn_guards) do
+    local allowed, why = entry.cb(task, learn_type, ctx)
+    if not allowed then
+      if why then
+        return false, why
+      end
+      return false, entry.name
+    end
+  end
+  return true
+end
+
+--- Register a guard callback for autolearn decisions
+-- Guards run before all other checks in can_autolearn; guards with a higher
+-- priority run first. Two call forms are accepted:
+--   register_guard(name, cb [, opts])
+--   register_guard(cb [, opts])   -- a name of the form guard_<n> is generated
+-- @param name string guard name (or the callback itself in the short form)
+-- @param cb function(task, learn_type, ctx) -> bool, reason
+-- @param opts table optional {priority = number}
+-- @return string the guard name, or nil if the callback is not a function
+function exports.register_guard(name, cb, opts)
+  if type(name) == 'function' then
+    -- Short form: the options, if any, arrived in the `cb` slot. Pick them up
+    -- before `cb` is overwritten, otherwise the priority is silently lost.
+    if opts == nil and type(cb) == 'table' then
+      opts = cb
+    end
+    cb = name
+    name = string.format('guard_%d', #autolearn_guards + 1)
+  end
+
+  if type(cb) ~= 'function' then
+    rspamd_logger.errx(rspamd_config, '%s: guard callback must be a function', N)
+    return nil
+  end
+
+  local guard = {
+    name = name,
+    cb = cb,
+    priority = type(opts) == 'table' and opts.priority or 0,
+  }
+
+  autolearn_guards[#autolearn_guards + 1] = guard
+  -- Keep the list ordered by descending priority
+  table.sort(autolearn_guards, function(a, b)
+    return (a.priority or 0) > (b.priority or 0)
+  end)
+
+  lua_util.debugm(N, rspamd_config, 'registered autolearn guard: %s', name)
+  return name
+end
+
+--- Remove one guard with the given name
+-- The list is scanned from the end, so the last matching entry is removed.
+-- @param name string guard name
+-- @return bool true if a guard was removed, false if none matched
+function exports.unregister_guard(name)
+  local idx = #autolearn_guards
+  while idx >= 1 do
+    if autolearn_guards[idx].name == name then
+      table.remove(autolearn_guards, idx)
+      return true
+    end
+    idx = idx - 1
+  end
+  return false
+end
+
+--- Merge new values into the module-wide autolearn defaults
+-- A nil argument leaves the current defaults untouched.
+-- @param opts table of default overrides
+function exports.configure(opts)
+  if not opts then
+    return
+  end
+
+  global_defaults = lua_util.override_defaults(global_defaults, opts)
+  lua_util.debugm(N, rspamd_config, 'configured neural autolearn defaults')
+end
+
+-- Compile an rspamd_expression for a rule and cache the outcome.
+-- Atoms are symbol names ([%w_]+). An atom evaluates to 0 when the symbol is
+-- absent, otherwise to the absolute score of its first result, with a floor
+-- of 0.001 so that a present zero-score symbol still counts as true.
+-- Failed compilations are cached too (as `false`), so a broken expression is
+-- parsed and logged once rather than for every task.
+-- @param rule_name string used to namespace the cache key
+-- @param expr_str string expression source
+-- @param pool memory pool handed to rspamd_expression.create
+-- @return compiled expression, or nil if it cannot be compiled
+local function get_expression(rule_name, expr_str, pool)
+  local cache_key = rule_name .. ':' .. expr_str
+  local cached = expression_cache[cache_key]
+  if cached ~= nil then
+    -- `false` marks a known-bad expression
+    return cached or nil
+  end
+
+  local function parse_atom(str)
+    -- Leading run of symbol-name characters; '' when there is none
+    return str:match('^[%w_]*')
+  end
+
+  local function process_atom(atom, task)
+    if task:has_symbol(atom) then
+      local sym = task:get_symbol(atom)
+      if sym and sym[1] then
+        local score = math.abs(sym[1].score or 0)
+        return score > 0.001 and score or 0.001
+      end
+      return 0.001
+    end
+    return 0
+  end
+
+  local expr, err = rspamd_expression.create(expr_str, { parse_atom, process_atom }, pool)
+  if not expr then
+    rspamd_logger.errx(rspamd_config, '%s: cannot create expression [%s]: %s', N, expr_str,
+        err or 'unknown error')
+    expression_cache[cache_key] = false
+    return nil
+  end
+
+  expression_cache[cache_key] = expr
+  return expr
+end
+
+-- Verify that the task carries every symbol in the list.
+-- A nil or empty list always passes. On failure the second return value
+-- names the first missing symbol.
+local function check_required_symbols(task, symbols)
+  if symbols and #symbols > 0 then
+    for _, name in ipairs(symbols) do
+      if not task:has_symbol(name) then
+        return false, string.format('missing required symbol: %s', name)
+      end
+    end
+  end
+  return true
+end
+
+-- Verify that the task carries none of the listed symbols.
+-- A nil list always passes. On failure the second return value names a
+-- forbidden symbol that was found.
+local function check_forbidden_symbols(task, symbols)
+  local forbidden = as_set(symbols)
+  if forbidden then
+    for name in pairs(forbidden) do
+      if task:has_symbol(name) then
+        return false, string.format('has forbidden symbol: %s', name)
+      end
+    end
+  end
+  return true
+end
+
+-- Add up the score of the first result of each listed symbol.
+-- Symbols the task does not have contribute nothing; a nil or empty list
+-- yields 0.
+local function get_symbols_weight(task, symbols)
+  local sum = 0
+  if symbols and #symbols > 0 then
+    for _, name in ipairs(symbols) do
+      local hits = task:get_symbol(name)
+      local first = hits and hits[1]
+      if first then
+        sum = sum + (first.score or 0)
+      end
+    end
+  end
+  return sum
+end
+
+-- Run the checks that differ between the two classes: score threshold,
+-- required action, required symbols, symbol weight and expression condition.
+-- Option names are derived from learn_type ('spam_score' / 'ham_score', ...).
+-- Spam values must be at or above their limit, ham values at or below it.
+-- Returns true, or false plus a reason.
+local function check_class_conditions(task, rule, opts, learn_type, score, action)
+  local is_spam = learn_type == 'spam'
+  local cmp = is_spam and '<' or '>'
+
+  -- True when `value` is on the wrong side of `limit` for this class
+  local function misses(value, limit)
+    if is_spam then
+      return value < limit
+    end
+    return value > limit
+  end
+
+  -- Score threshold
+  local score_limit = opts[learn_type .. '_score']
+  if score_limit and misses(score, score_limit) then
+    return false, string.format('score %.2f %s %s_score %.2f', score, cmp, learn_type, score_limit)
+  end
+
+  -- Action requirement
+  local required_action = opts[learn_type .. '_action']
+  if required_action and action ~= required_action then
+    return false, string.format('action %s != required %s', tostring(action), tostring(required_action))
+  end
+
+  -- Required symbols
+  local required_symbols = opts[learn_type .. '_symbols']
+  local sym_ok, sym_reason = check_required_symbols(task, required_symbols)
+  if not sym_ok then
+    return false, sym_reason
+  end
+
+  -- Symbol weight threshold (sum over the required symbols)
+  local weight_limit = opts[learn_type .. '_symbol_weight']
+  if weight_limit then
+    local weight = get_symbols_weight(task, required_symbols)
+    if misses(weight, weight_limit) then
+      return false, string.format('%s symbol weight %.2f %s %.2f', learn_type, weight, cmp, weight_limit)
+    end
+  end
+
+  -- Expression condition
+  local condition = opts[learn_type .. '_condition']
+  if condition then
+    local expr = get_expression(rule.prefix or 'default', condition, rspamd_config:get_mempool())
+    if not expr then
+      -- Fail closed: a condition that cannot be compiled must not be treated
+      -- as satisfied, otherwise a typo would silently widen what gets learned
+      return false, string.format('%s_condition cannot be compiled: %s', learn_type, condition)
+    end
+    if expr:process(task) <= 0 then
+      return false, string.format('%s_condition not satisfied: %s', learn_type, condition)
+    end
+  end
+
+  return true
+end
+
+--- Main function: determine if a message should be autolearned
+-- Checks run in this order and the first failure decides: master switch,
+-- registered guards, local/authenticated skips, forbidden symbols,
+-- class-specific conditions, probability check, sampling.
+-- Options are layered as: built-in defaults < configure() defaults <
+-- rule.autolearn < overrides.
+-- @param task rspamd_task
+-- @param rule neural rule configuration
+-- @param learn_type 'spam' or 'ham' (anything else is rejected)
+-- @param overrides optional per-call config overrides
+-- @return bool can_learn, string reason (nil when learning is allowed)
+function exports.can_autolearn(task, rule, learn_type, overrides)
+  if learn_type ~= 'spam' and learn_type ~= 'ham' then
+    return false, string.format('unknown learn type: %s', tostring(learn_type))
+  end
+
+  local opts = merge_options(default_autolearn_settings, rule.autolearn or {})
+
+  if overrides then
+    -- Apply the overrides directly: going through merge_options again would
+    -- re-apply the global defaults on top of the rule's own settings
+    opts = lua_util.override_defaults(opts, overrides)
+  end
+
+  -- Master enable check
+  if not opts.enabled then
+    return false, 'autolearn disabled'
+  end
+
+  local score = task:get_metric_score()[1]
+  local action = task:get_metric_action()
+
+  -- Context handed to guards so they can inspect the decision inputs
+  local ctx = {
+    task = task,
+    rule = rule,
+    learn_type = learn_type,
+    score = score,
+    action = action,
+    options = opts,
+  }
+
+  -- Execute registered guards first
+  local guard_ok, guard_reason = execute_guards(task, learn_type, ctx)
+  if not guard_ok then
+    return false, string.format('blocked by guard: %s', guard_reason)
+  end
+
+  -- Skip checks
+  if opts.check_local then
+    local from_ip = task:get_from_ip()
+    if from_ip and from_ip:is_local() then
+      return false, 'local network message'
+    end
+  end
+
+  if opts.check_authed and task:get_user() then
+    return false, 'authenticated user'
+  end
+
+  -- Forbidden symbols check
+  local skip_ok, skip_reason = check_forbidden_symbols(task, opts.skip_symbols)
+  if not skip_ok then
+    return false, skip_reason
+  end
+
+  -- Learn type specific checks
+  local class_ok, class_reason = check_class_conditions(task, rule, opts, learn_type, score, action)
+  if not class_ok then
+    return false, class_reason
+  end
+
+  -- Probability check (skip if already confident)
+  local prob_check = opts.probability_check
+  if prob_check and prob_check.enabled then
+    local prob_var = prob_check.variable or 'neural_prob'
+    local prob = task:get_mempool():get_variable(prob_var, 'double')
+    if prob then
+      if learn_type == 'spam' and prob >= prob_check.spam_min then
+        return false, string.format('already confident spam: %.2f >= %.2f', prob, prob_check.spam_min)
+      elseif learn_type == 'ham' and prob <= prob_check.ham_max then
+        return false, string.format('already confident ham: %.2f <= %.2f', prob, prob_check.ham_max)
+      end
+    end
+  end
+
+  -- Probabilistic sampling
+  if opts.sampling then
+    local sample_prob
+    if learn_type == 'spam' then
+      sample_prob = opts.sampling.spam_prob
+    else
+      sample_prob = opts.sampling.ham_prob
+    end
+    if sample_prob and sample_prob < 1.0 then
+      local coin = math.random()
+      if coin > sample_prob then
+        return false, string.format('sampled out: %.2f > %.2f', coin, sample_prob)
+      end
+    end
+  end
+
+  return true, nil
+end
+
+--- Determine learn type based on score/action/symbols
+-- Spam is evaluated before ham and the first class accepted by
+-- can_autolearn wins. Because of that order, a rule that is enabled but
+-- defines no spam-specific criteria accepts messages as spam.
+-- When neither class qualifies, the reason lists why each one was rejected.
+-- @param task rspamd_task
+-- @param rule neural rule configuration
+-- @return string learn_type ('spam', 'ham', or nil), string reason
+function exports.get_learn_type(task, rule)
+  local opts = merge_options(default_autolearn_settings, rule.autolearn or {})
+
+  if not opts.enabled then
+    return nil, 'autolearn disabled'
+  end
+
+  local rejections = {}
+  for _, learn_type in ipairs({ 'spam', 'ham' }) do
+    local ok, reason = exports.can_autolearn(task, rule, learn_type)
+    if ok then
+      return learn_type, 'autolearn ' .. learn_type
+    end
+    -- Keep both reasons: reporting only the spam one hides why ham failed
+    rejections[#rejections + 1] = string.format('%s: %s', learn_type,
+        reason or 'no autolearn condition matched')
+  end
+
+  -- Neither qualifies
+  return nil, table.concat(rejections, '; ')
+end
+
+--- Store the chosen autolearn class in the task mempool
+-- The value is written to the 'neural_autolearn_class' variable, where
+-- get_autolearn_class reads it back.
+-- @param task rspamd_task
+-- @param learn_type 'spam' or 'ham'
+function exports.set_autolearn_class(task, learn_type)
+  local pool = task:get_mempool()
+  pool:set_variable('neural_autolearn_class', learn_type)
+  lua_util.debugm(N, task, 'set neural autolearn class: %s', learn_type)
+end
+
+--- Read the autolearn class previously stored in the task mempool
+-- @param task rspamd_task
+-- @return string learn_type or nil
+function exports.get_autolearn_class(task)
+  local pool = task:get_mempool()
+  return pool:get_variable('neural_autolearn_class')
+end
+
+--- Drop every cached compiled expression so conditions are recompiled on next use (useful for config reload)
+function exports.clear_cache()
+ expression_cache = {}
+end
+
+-- Publish the hook API as rspamd_plugins['neural_learn'] so user code can
+-- reach it without requiring this module
+if rspamd_plugins then
+  local published = {
+    'register_guard',
+    'unregister_guard',
+    'configure',
+    'can_autolearn',
+    'get_learn_type',
+    'set_autolearn_class',
+    'get_autolearn_class',
+  }
+  local api = {}
+  for _, fname in ipairs(published) do
+    api[fname] = exports[fname]
+  end
+  rspamd_plugins['neural_learn'] = api
+end
+
+return exports
local lua_util = require "lua_util"
local lua_verdict = require "lua_verdict"
local neural_common = require "plugins/neural"
+local neural_learn = require "lua_neural_learn"
local rspamd_kann = require "rspamd_kann"
local rspamd_logger = require "rspamd_logger"
local rspamd_tensor = require "rspamd_tensor"
end
end
- -- If LLM provider is configured, suppress autotrain unless manual training requested
- if not manual_train and rule.providers and #rule.providers > 0 then
+ -- Check for autolearn class set by mempool (integration with external learning decisions)
+ if not manual_train then
+ local autolearn_class = neural_learn.get_autolearn_class(task)
+ if autolearn_class then
+ lua_util.debugm(N, task, 'found neural autolearn class in mempool: %s', autolearn_class)
+ if autolearn_class == 'spam' then
+ learn_spam = true
+ manual_train = true
+ elseif autolearn_class == 'ham' then
+ learn_ham = true
+ manual_train = true
+ end
+ end
+ end
+
+ -- If LLM provider is configured, use autolearn conditions instead of simple score thresholds
+ local has_llm_provider = false
+ if rule.providers and #rule.providers > 0 then
for _, p in ipairs(rule.providers) do
if p.type == 'llm' then
- lua_util.debugm(N, task, 'suppress autotrain: llm provider present and no manual header')
- learn_spam = false
- learn_ham = false
- skip_reason = 'llm provider requires manual training'
+ has_llm_provider = true
break
end
end
end
- if not manual_train and (not train_opts.store_pool_only and train_opts.autotrain) then
+ if has_llm_provider and not manual_train then
+ -- Use expression-based autolearn conditions for LLM providers
+ if rule.autolearn and rule.autolearn.enabled then
+ local learn_type, reason = neural_learn.get_learn_type(task, rule)
+ if learn_type == 'spam' then
+ learn_spam = true
+ lua_util.debugm(N, task, 'autolearn spam via expression: %s', reason)
+ elseif learn_type == 'ham' then
+ learn_ham = true
+ lua_util.debugm(N, task, 'autolearn ham via expression: %s', reason)
+ else
+ skip_reason = reason or 'autolearn condition not met'
+ lua_util.debugm(N, task, 'autolearn skip: %s', skip_reason)
+ end
+ else
+ -- LLM provider without autolearn config - require manual training
+ learn_spam = false
+ learn_ham = false
+ skip_reason = 'llm provider requires autolearn config or manual training'
+ lua_util.debugm(N, task, 'suppress autotrain: llm provider present, no autolearn config')
+ end
+ elseif not manual_train and (not train_opts.store_pool_only and train_opts.autotrain) then
+ -- Traditional score/verdict based learning for non-LLM providers
if train_opts.spam_score then
learn_spam = score >= train_opts.spam_score
verdict)
end
end
- else
- if train_opts.store_pool_only and not manual_train then
+ elseif not manual_train then
+ if train_opts.store_pool_only then
local ucl = require "ucl"
learn_ham = false
learn_spam = false
end
end)
end
+
+-- Publish the autolearn hooks as rspamd_plugins['neural'].autolearn so users
+-- can customise learning decisions
+if rspamd_plugins then
+  local neural_api = rspamd_plugins['neural'] or {}
+  rspamd_plugins['neural'] = neural_api
+
+  local autolearn_api = {}
+  -- Add a custom guard that can block learning
+  -- cb: function(task, learn_type, ctx) -> bool, reason
+  autolearn_api.register_guard = neural_learn.register_guard
+  -- Remove a previously registered guard
+  autolearn_api.unregister_guard = neural_learn.unregister_guard
+  -- Set global autolearn defaults
+  autolearn_api.configure = neural_learn.configure
+  -- Check whether a task qualifies for autolearn
+  -- Returns: can_learn (bool), reason (string)
+  autolearn_api.can_autolearn = neural_learn.can_autolearn
+  -- Pick the learn type for a task from the configured conditions
+  -- Returns: 'spam', 'ham', or nil
+  autolearn_api.get_learn_type = neural_learn.get_learn_type
+  -- Store the autolearn class in the task mempool
+  autolearn_api.set_autolearn_class = neural_learn.set_autolearn_class
+  -- Read the autolearn class back from the task mempool
+  autolearn_api.get_autolearn_class = neural_learn.get_autolearn_class
+
+  neural_api.autolearn = autolearn_api
+end