From: Vsevolod Stakhov Date: Mon, 12 Jan 2026 10:58:03 +0000 (+0000) Subject: [Feature] Add expression-based autolearn for neural LLM providers X-Git-Tag: 4.0.0~179^2~30 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cdaf02ca1a81afc5ca4b138c5f5ae297ec9cbf93;p=thirdparty%2Frspamd.git [Feature] Add expression-based autolearn for neural LLM providers Add integrated autolearn system for neural networks with LLM providers: - New lua_neural_learn library with guards system and rspamd_expression support for complex conditions - Expression-based conditions: spam_condition, ham_condition using rspamd_expression syntax (e.g., "BAYES_SPAM & DMARC_POLICY_REJECT") - Score, action, and symbol-based thresholds - Pluggable guards via rspamd_plugins['neural'].autolearn hooks - Mempool-based flag passing (no double scanning) - Probabilistic sampling for training volume control Also includes contrib/neural-embedding-service with a FastEmbed-based Python service for CPU-optimized embedding inference, compatible with both Ollama and OpenAI API formats. Configuration example: autolearn { enabled = true; spam_score = 15.0; spam_condition = "BAYES_SPAM & (DMARC_POLICY_REJECT | RBL_SPAMHAUS)"; ham_condition = "BAYES_HAM & DKIM_VALID_AU & SPF_PASS"; } --- diff --git a/conf/modules.d/neural_autolearn.conf b/conf/modules.d/neural_autolearn.conf new file mode 100644 index 0000000000..846fd7d7a9 --- /dev/null +++ b/conf/modules.d/neural_autolearn.conf @@ -0,0 +1,83 @@ +# Neural Autolearn Configuration +# +# This configuration is part of the neural plugin and controls automatic +# training for neural networks with LLM providers. +# +# The autolearn section can be added to any neural rule configuration. +# It uses expression-based conditions for strong confidence learning. +# +# Documentation: doc/neural-llm-embeddings-guide.md + +# Example autolearn configuration for neural rules +# Add this inside a neural rule: +# +# neural { +# rules { +# llm_classifier { +# providers = [ +# { type = "llm"; llm_type = "ollama"; model = "nomic-embed-text"; } +# ]; +# +# # Autolearn configuration +# autolearn { +# enabled = true; +# +# # Score thresholds +# spam_score = 15.0; # Learn spam if score >= 15.0 +# ham_score = -5.0; # Learn ham if score <= -5.0 +# +# # Action requirements (optional, more restrictive) +# spam_action = "reject"; +# ham_action = "no action"; +# +# # Expression-based conditions (rspamd_expression syntax) +# # These provide fine-grained control over learning decisions +# spam_condition = "BAYES_SPAM & (DMARC_POLICY_REJECT | RBL_SPAMHAUS_SBL)"; +# ham_condition = "BAYES_HAM & DKIM_VALID_AU & SPF_PASS"; +# +# # Required symbols (all must be present) +# spam_symbols = ["BAYES_SPAM"]; +# ham_symbols = ["BAYES_HAM", "DKIM_VALID"]; +# +# # Forbidden symbols (any blocks learning) +# skip_symbols = ["WHITELIST_SENDER", "GREYLIST"]; +# +# # Symbol weight thresholds +# spam_symbol_weight = 5.0; # Sum of spam_symbols scores >= 5.0 +# ham_symbol_weight = -3.0; # Sum of ham_symbols scores <= -3.0 +# +# # Probabilistic sampling (reduce training volume) +# sampling { +# spam_prob = 0.5; # Learn 50% of qualifying spam +# ham_prob = 0.5; +# } +# +# # Skip local/authenticated messages +# check_local = true; +# check_authed = true; +# } +# } +# } +# } + +# Conservative production example: +# autolearn { +# enabled = true; +# spam_score = 20.0; +# ham_score = -8.0; +# spam_action = "reject"; +# spam_condition = "BAYES_SPAM & !WHITELIST_SENDER"; +# ham_condition = "BAYES_HAM & DKIM_VALID_AU"; +# skip_symbols = ["GREYLIST", "RATELIMITED"]; +# sampling { +# spam_prob = 0.3; +# ham_prob = 0.3; +# } +# } + +# Aggressive initial training example: +# autolearn { +# enabled = true; +# spam_score = 10.0; +# ham_score = -2.0; +# } diff --git a/contrib/neural-embedding-service/Dockerfile b/contrib/neural-embedding-service/Dockerfile new file mode 100644 index 0000000000..4083fb831e --- /dev/null +++ b/contrib/neural-embedding-service/Dockerfile @@ -0,0 +1,58 @@ +# Rspamd Neural Embedding Service +# +# CPU-optimized embedding service using FastEmbed + ONNX Runtime +# +# Build: +# docker build -t rspamd-embedding-service . +# +# Run: +# docker run -p 8080:8080 rspamd-embedding-service +# +# With custom model: +# docker run -p 8080:8080 -e EMBEDDING_MODEL="BAAI/bge-base-en-v1.5" rspamd-embedding-service + +FROM python:3.11-slim + +# Build arguments +ARG EMBEDDING_MODEL="BAAI/bge-small-en-v1.5" + +# Environment +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + EMBEDDING_MODEL=${EMBEDDING_MODEL} \ + EMBEDDING_PORT=8080 \ + EMBEDDING_HOST=0.0.0.0 + +# Install system dependencies for ONNX Runtime +RUN apt-get update && apt-get install -y --no-install-recommends \ + libgomp1 \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Install Python dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Pre-download model during build (optional, makes container larger but faster startup) +# Uncomment to include model in image: +# RUN python -c "from fastembed import TextEmbedding; TextEmbedding('${EMBEDDING_MODEL}')" + +# Copy application +COPY embedding_service.py . + +# Non-root user +RUN useradd -m -u 1000 embedding && chown -R embedding:embedding /app +USER embedding + +# Expose port +EXPOSE 8080 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \ + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8080/health')" + +# Run with uvicorn +CMD ["python", "embedding_service.py"] diff --git a/contrib/neural-embedding-service/README.md b/contrib/neural-embedding-service/README.md new file mode 100644 index 0000000000..5d1635a6d6 --- /dev/null +++ b/contrib/neural-embedding-service/README.md @@ -0,0 +1,187 @@ +# Rspamd Neural Embedding Service + +A lightweight, CPU-optimized embedding service for Rspamd's neural plugin. + +## Overview + +This service provides text embeddings for Rspamd's neural LLM provider. It uses FastEmbed with ONNX Runtime for efficient CPU inference, making it suitable for servers without GPU hardware. + +## Features + +- **CPU-optimized**: Uses ONNX Runtime with INT8 quantization support +- **Lightweight**: ~100MB memory for bge-small-en-v1.5 +- **Compatible**: Supports both Ollama and OpenAI API formats +- **Fast**: 2,500-5,000 sentences/second on modern CPUs + +## Quick Start + +### Option 1: Direct Python + +```bash +# Install dependencies +pip install -r requirements.txt + +# Run service +python embedding_service.py +``` + +### Option 2: Docker + +```bash +# Build +docker build -t rspamd-embedding-service . + +# Run +docker run -p 8080:8080 rspamd-embedding-service + +# With custom model +docker run -p 8080:8080 -e EMBEDDING_MODEL="BAAI/bge-base-en-v1.5" rspamd-embedding-service +``` + +### Option 3: Docker Compose + +See the main guide: `doc/neural-llm-embeddings-guide.md` + +## Configuration + +### Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `EMBEDDING_MODEL` | `BAAI/bge-small-en-v1.5` | FastEmbed model name | +| `EMBEDDING_PORT` | `8080` | Port to listen on | +| `EMBEDDING_HOST` | `0.0.0.0` | Host to bind to | + +### Command Line Arguments + +```bash +python embedding_service.py --help + +Options: + --model, -m Model name (default: BAAI/bge-small-en-v1.5) + --port, -p Port number (default: 8080) + --host, -H Host to bind (default: 0.0.0.0) + --workers, -w Number of workers (default: 1) + --log-level, -l Log level (default: info) +``` + +## API Endpoints + +### Ollama Format (Recommended for Rspamd) + +```bash +curl http://localhost:8080/api/embeddings -d '{ + "model": "BAAI/bge-small-en-v1.5", + "prompt": "Test message about cheap medications" +}' +``` + +Response: +```json +{ + "embedding": [0.123, -0.456, ...] +} +``` + +### OpenAI Format + +```bash +curl http://localhost:8080/v1/embeddings -d '{ + "model": "BAAI/bge-small-en-v1.5", + "input": "Test message" +}' +``` + +Response: +```json +{ + "object": "list", + "data": [{"embedding": [...], "index": 0}], + "model": "BAAI/bge-small-en-v1.5", + "usage": {"prompt_tokens": 2, "total_tokens": 2} +} +``` + +### Health Check + +```bash +curl http://localhost:8080/health +``` + +## Rspamd Configuration + +Configure Rspamd to use this service: + +```hcl +# /etc/rspamd/local.d/neural.conf +rules { + default { + providers = [ + { + type = "llm"; + llm_type = "ollama"; # or "openai" + model = "BAAI/bge-small-en-v1.5"; + url = "http://localhost:8080/api/embeddings"; # or /v1/embeddings + timeout = 2.0; + cache_ttl = 86400; + } + ]; + # ... + } +} +``` + +## Supported Models + +| Model | Size | Dims | Quality | Speed | +|-------|------|------|---------|-------| +| `BAAI/bge-small-en-v1.5` | 33MB | 384 | Good | Excellent | +| `BAAI/bge-base-en-v1.5` | 440MB | 768 | Better | Good | +| `sentence-transformers/all-MiniLM-L6-v2` | 90MB | 384 | Fair | Excellent | +| `intfloat/e5-small-v2` | 200MB | 384 | Good | Excellent | + +For the full list, see: https://qdrant.github.io/fastembed/examples/Supported_Models/ + +## Production Deployment + +### With Gunicorn + +```bash +pip install gunicorn +gunicorn -w 4 -k uvicorn.workers.UvicornWorker -b 0.0.0.0:8080 embedding_service:app +``` + +### Resource Recommendations + +| Model | Memory | CPU Cores | +|-------|--------|-----------| +| bge-small-en-v1.5 | 256MB | 1 | +| bge-base-en-v1.5 | 1GB | 2 | + +## Troubleshooting + +### Model download issues + +```bash +# Pre-download model +python -c "from fastembed import TextEmbedding; TextEmbedding('BAAI/bge-small-en-v1.5')" +``` + +### Memory issues + +Use a smaller model: +```bash +EMBEDDING_MODEL="BAAI/bge-small-en-v1.5" python embedding_service.py +``` + +### Slow inference + +- Ensure ONNX Runtime is using optimized providers +- Consider increasing workers for parallel processing +- Use batching for bulk operations + +## License + +Apache License 2.0 + +See the main Rspamd repository for license details. diff --git a/contrib/neural-embedding-service/embedding_service.py b/contrib/neural-embedding-service/embedding_service.py new file mode 100644 index 0000000000..50daceeb28 --- /dev/null +++ b/contrib/neural-embedding-service/embedding_service.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 +""" +Lightweight embedding service for Rspamd neural plugin. + +Uses FastEmbed with ONNX for CPU-optimized inference. +Provides both Ollama-compatible and OpenAI-compatible endpoints. + +Installation: + pip install fastapi uvicorn fastembed pydantic + +Usage: + python embedding_service.py [--model MODEL] [--port PORT] [--host HOST] + + # Default: bge-small-en-v1.5 on port 8080 + python embedding_service.py + + # Custom model + python embedding_service.py --model "BAAI/bge-base-en-v1.5" + + # Production with gunicorn + gunicorn -w 4 -k uvicorn.workers.UvicornWorker embedding_service:app + +Environment variables: + EMBEDDING_MODEL: Model name (default: BAAI/bge-small-en-v1.5) + EMBEDDING_PORT: Port number (default: 8080) + EMBEDDING_HOST: Host to bind (default: 0.0.0.0) +""" + +import argparse +import logging +import os +import time +from contextlib import asynccontextmanager +from typing import List, Optional, Union + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel + +# FastEmbed - CPU-optimized ONNX inference +from fastembed import TextEmbedding + +# Logging setup +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Configuration from environment +DEFAULT_MODEL = os.getenv('EMBEDDING_MODEL', 'BAAI/bge-small-en-v1.5') +DEFAULT_PORT = int(os.getenv('EMBEDDING_PORT', '8080')) +DEFAULT_HOST = os.getenv('EMBEDDING_HOST', '0.0.0.0') + +# Global model instance +model: Optional[TextEmbedding] = None +model_name: str = DEFAULT_MODEL +model_dim: int = 0 + + +# Request/Response models +class OllamaEmbeddingRequest(BaseModel): + """Ollama-compatible embedding request.""" + model: str = DEFAULT_MODEL + prompt: str + + +class OllamaEmbeddingResponse(BaseModel): + """Ollama-compatible embedding response.""" + embedding: List[float] + + +class OpenAIEmbeddingRequest(BaseModel): + """OpenAI-compatible embedding request.""" + model: str = DEFAULT_MODEL + input: Union[str, List[str]] + encoding_format: str = "float" + + +class OpenAIEmbeddingData(BaseModel): + """OpenAI embedding data object.""" + object: str = "embedding" + embedding: List[float] + index: int + + +class OpenAIUsage(BaseModel): + """OpenAI usage object.""" + prompt_tokens: int + total_tokens: int + + +class OpenAIEmbeddingResponse(BaseModel): + """OpenAI-compatible embedding response.""" + object: str = "list" + data: List[OpenAIEmbeddingData] + model: str + usage: OpenAIUsage + + +class HealthResponse(BaseModel): + """Health check response.""" + status: str + model: str + dimensions: int + uptime_seconds: float + + +# Startup time for uptime calculation +startup_time: float = 0.0 + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Load model on startup.""" + global model, model_name, model_dim, startup_time + + logger.info(f"Loading embedding model: {model_name}") + start = time.time() + + try: + model = TextEmbedding(model_name) + # Get embedding dimension from a test inference + test_embed = list(model.embed(["test"]))[0] + model_dim = len(test_embed) + elapsed = time.time() - start + logger.info(f"Model loaded in {elapsed:.2f}s, dimensions: {model_dim}") + startup_time = time.time() + except Exception as e: + logger.error(f"Failed to load model: {e}") + raise + + yield + + logger.info("Shutting down embedding service") + + +app = FastAPI( + title="Rspamd Embedding Service", + description="CPU-optimized embedding service for Rspamd neural plugin", + version="1.0.0", + lifespan=lifespan, +) + + +def get_embedding(text: str) -> List[float]: + """Generate embedding for a single text.""" + if model is None: + raise HTTPException(500, "Model not loaded") + + embeddings = list(model.embed([text])) + return embeddings[0].tolist() + + +def get_embeddings_batch(texts: List[str]) -> List[List[float]]: + """Generate embeddings for multiple texts.""" + if model is None: + raise HTTPException(500, "Model not loaded") + + embeddings = list(model.embed(texts)) + return [e.tolist() for e in embeddings] + + +def count_tokens(text: str) -> int: + """Approximate token count (words * 1.3).""" + return int(len(text.split()) * 1.3) + + +@app.post("/api/embeddings", response_model=OllamaEmbeddingResponse) +async def ollama_embeddings(request: OllamaEmbeddingRequest) -> OllamaEmbeddingResponse: + """ + Ollama-compatible embedding endpoint. + + Used by Rspamd neural LLM provider with llm_type = "ollama". + """ + if not request.prompt: + raise HTTPException(400, "Missing prompt") + + logger.debug(f"Ollama request: {len(request.prompt)} chars") + embedding = get_embedding(request.prompt) + + return OllamaEmbeddingResponse(embedding=embedding) + + +@app.post("/v1/embeddings", response_model=OpenAIEmbeddingResponse) +async def openai_embeddings(request: OpenAIEmbeddingRequest) -> OpenAIEmbeddingResponse: + """ + OpenAI-compatible embedding endpoint. + + Used by Rspamd neural LLM provider with llm_type = "openai". + """ + if not request.input: + raise HTTPException(400, "Missing input") + + # Handle single string or list of strings + if isinstance(request.input, str): + texts = [request.input] + else: + texts = request.input + + logger.debug(f"OpenAI request: {len(texts)} texts") + embeddings = get_embeddings_batch(texts) + + # Build response + data = [ + OpenAIEmbeddingData(embedding=emb, index=i) + for i, emb in enumerate(embeddings) + ] + + total_tokens = sum(count_tokens(t) for t in texts) + + return OpenAIEmbeddingResponse( + data=data, + model=request.model, + usage=OpenAIUsage(prompt_tokens=total_tokens, total_tokens=total_tokens) + ) + + +@app.get("/health", response_model=HealthResponse) +@app.get("/", response_model=HealthResponse) +async def health() -> HealthResponse: + """Health check endpoint.""" + return HealthResponse( + status="ok" if model is not None else "loading", + model=model_name, + dimensions=model_dim, + uptime_seconds=time.time() - startup_time if startup_time > 0 else 0 + ) + + +@app.get("/v1/models") +async def list_models(): + """List available models (OpenAI-compatible).""" + return { + "object": "list", + "data": [ + { + "id": model_name, + "object": "model", + "created": int(startup_time), + "owned_by": "fastembed", + "permission": [], + "root": model_name, + } + ] + } + + +def main(): + """Run the embedding service.""" + global model_name + + parser = argparse.ArgumentParser( + description="Rspamd embedding service", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--model", "-m", + default=DEFAULT_MODEL, + help="FastEmbed model name" + ) + parser.add_argument( + "--port", "-p", + type=int, + default=DEFAULT_PORT, + help="Port to listen on" + ) + parser.add_argument( + "--host", "-H", + default=DEFAULT_HOST, + help="Host to bind to" + ) + parser.add_argument( + "--workers", "-w", + type=int, + default=1, + help="Number of worker processes" + ) + parser.add_argument( + "--log-level", "-l", + default="info", + choices=["debug", "info", "warning", "error"], + help="Log level" + ) + + args = parser.parse_args() + model_name = args.model + + import uvicorn + uvicorn.run( + "embedding_service:app", + host=args.host, + port=args.port, + workers=args.workers, + log_level=args.log_level, + ) + + +if __name__ == "__main__": + main() diff --git a/contrib/neural-embedding-service/requirements.txt b/contrib/neural-embedding-service/requirements.txt new file mode 100644 index 0000000000..da81649dfe --- /dev/null +++ b/contrib/neural-embedding-service/requirements.txt @@ -0,0 +1,19 @@ +# Rspamd Neural Embedding Service Dependencies +# +# Install: pip install -r requirements.txt + +# FastAPI web framework +fastapi>=0.100.0 + +# ASGI server +uvicorn[standard]>=0.23.0 + +# FastEmbed - CPU-optimized embedding inference via ONNX +# Includes: onnxruntime, numpy, tokenizers +fastembed>=0.2.0 + +# Data validation +pydantic>=2.0.0 + +# Optional: production ASGI server +# gunicorn>=21.0.0 diff --git a/lualib/lua_neural_learn.lua b/lualib/lua_neural_learn.lua new file mode 100644 index 0000000000..c04f6c53b4 --- /dev/null +++ b/lualib/lua_neural_learn.lua @@ -0,0 +1,491 @@ +--[[ +Copyright (c) 2024, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[ +Neural network autolearn helpers. + +This module provides configurable autolearn conditions for neural networks, +particularly useful for LLM-based providers where automatic learning needs +careful control. + +Similar to lua_bayes_learn.lua, this provides: +- Guards system for pluggable checks +- Expression-based conditions (rspamd_expression) +- Score/action/symbol-based thresholds +- Hooks for custom logic via rspamd_plugins + +Usage in neural.lua: + local neural_learn = require "lua_neural_learn" + local can_learn, reason = neural_learn.can_autolearn(task, rule, 'spam') +]]-- + +local lua_util = require "lua_util" +local rspamd_expression = require "rspamd_expression" +local rspamd_logger = require "rspamd_logger" + +local N = "lua_neural_learn" + +local exports = {} + +-- Global defaults that can be overridden via configure() +local global_defaults = {} + +-- Registered guards (callbacks that can block learning) +local autolearn_guards = {} + +-- Cached compiled expressions per rule +local expression_cache = {} + +-- Default autolearn settings +local default_autolearn_settings = { + -- Master enable/disable + enabled = false, + + -- Require minimum score magnitude for learning + spam_score = nil, -- e.g., 6.0 - learn spam if score >= 6.0 + ham_score = nil, -- e.g., -2.0 - learn ham if score <= -2.0 + + -- Require specific actions + spam_action = nil, -- e.g., 'reject' - only learn spam on reject + ham_action = nil, -- e.g., 'no action' - only learn ham on no action + + -- Expression-based conditions (rspamd_expression syntax) + -- Examples: + -- "BAYES_SPAM & !WHITELIST_SENDER" + -- "DMARC_POLICY_REJECT | (RBL_SPAMHAUS_SBL & SURBL_MULTI)" + spam_condition = nil, + ham_condition = nil, + + -- Required symbols (all must be present) + spam_symbols = nil, -- e.g., {'BAYES_SPAM', 'DKIM_VALID'} + ham_symbols = nil, + + -- Forbidden symbols (any blocks learning) + skip_symbols = nil, -- e.g., {'WHITELIST_SENDER', 'GREYLIST'} + + -- Minimum symbol weight sum + spam_symbol_weight = nil, -- e.g., 5.0 - sum of spam_symbols scores >= 5.0 + ham_symbol_weight = nil, -- e.g., -3.0 - sum of ham_symbols scores <= -3.0 + + -- Probability-based check (skip if already confident) + probability_check = { + enabled = false, + variable = 'neural_prob', -- mempool variable name + spam_min = 0.95, -- skip if already 95% spam + ham_max = 0.05, -- skip if already 95% ham + }, + + -- Rate limiting + rate_limit = { + enabled = false, + max_daily = 1000, -- per class per day + redis_prefix = 'neural_autolearn', + }, + + -- Sampling (probabilistic training reduction) + sampling = { + spam_prob = 1.0, -- 1.0 = always, 0.5 = 50% chance + ham_prob = 1.0, + }, + + -- Skip conditions + check_local = false, -- skip local network messages + check_authed = false, -- skip authenticated users +} + +-- Helper: convert array to set +local function as_set(tbl) + if not tbl then + return nil + end + local res = {} + for _, v in ipairs(tbl) do + res[v] = true + end + return res +end + +-- Helper: merge options with defaults +local function merge_options(defaults, overrides) + local merged = lua_util.override_defaults(defaults, global_defaults) + if overrides then + merged = lua_util.override_defaults(merged, overrides) + end + return merged +end + +-- Guard execution +local function execute_guards(task, learn_type, ctx) + for _, guard in ipairs(autolearn_guards) do + local ok, reason = guard.cb(task, learn_type, ctx) + if not ok then + return false, reason or guard.name + end + end + return true +end + +--- Register a guard callback for autolearn decisions +-- @param name string guard name +-- @param cb function(task, learn_type, ctx) -> bool, reason +-- @param opts table optional {priority = number} +function exports.register_guard(name, cb, opts) + if type(name) == 'function' then + cb = name + name = string.format('guard_%d', #autolearn_guards + 1) + end + + if type(cb) ~= 'function' then + rspamd_logger.errx(rspamd_config, '%s: guard callback must be a function', N) + return nil + end + + local guard = { + name = name, + cb = cb, + priority = opts and opts.priority or 0, + } + + autolearn_guards[#autolearn_guards + 1] = guard + table.sort(autolearn_guards, function(a, b) + return (a.priority or 0) > (b.priority or 0) + end) + + lua_util.debugm(N, rspamd_config, 'registered autolearn guard: %s', name) + return name +end + +--- Unregister a guard by name +function exports.unregister_guard(name) + for i = #autolearn_guards, 1, -1 do + if autolearn_guards[i].name == name then + table.remove(autolearn_guards, i) + return true + end + end + return false +end + +--- Configure global defaults +-- @param opts table of default overrides +function exports.configure(opts) + if opts then + global_defaults = lua_util.override_defaults(global_defaults, opts) + lua_util.debugm(N, rspamd_config, 'configured neural autolearn defaults') + end +end + +-- Compile and cache expression +local function get_expression(rule_name, expr_str, pool) + local cache_key = rule_name .. ':' .. expr_str + if expression_cache[cache_key] then + return expression_cache[cache_key] + end + + local function parse_atom(str) + local atom = '' + for c in str:gmatch('.') do + if c:match('[%w_]') then + atom = atom .. c + else + break + end + end + return atom + end + + local function process_atom(atom, task) + if task:has_symbol(atom) then + local sym = task:get_symbol(atom) + if sym and sym[1] then + local score = math.abs(sym[1].score or 0) + return score > 0.001 and score or 0.001 + end + return 0.001 + end + return 0 + end + + local expr, err = rspamd_expression.create(expr_str, { parse_atom, process_atom }, pool) + if err then + rspamd_logger.errx(rspamd_config, '%s: cannot create expression [%s]: %s', N, expr_str, err) + return nil + end + + expression_cache[cache_key] = expr + return expr +end + +-- Check if all required symbols are present +local function check_required_symbols(task, symbols) + if not symbols or #symbols == 0 then + return true + end + for _, sym in ipairs(symbols) do + if not task:has_symbol(sym) then + return false, string.format('missing required symbol: %s', sym) + end + end + return true +end + +-- Check if any forbidden symbols are present +local function check_forbidden_symbols(task, symbols) + if not symbols then + return true + end + local skip_set = as_set(symbols) + if not skip_set then + return true + end + for sym, _ in pairs(skip_set) do + if task:has_symbol(sym) then + return false, string.format('has forbidden symbol: %s', sym) + end + end + return true +end + +-- Calculate sum of symbol scores +local function get_symbols_weight(task, symbols) + if not symbols or #symbols == 0 then + return 0 + end + local total = 0 + for _, sym in ipairs(symbols) do + local s = task:get_symbol(sym) + if s and s[1] then + total = total + (s[1].score or 0) + end + end + return total +end + +--- Main function: determine if a message should be autolearned +-- @param task rspamd_task +-- @param rule neural rule configuration +-- @param learn_type 'spam' or 'ham' +-- @param overrides optional per-call config overrides +-- @return bool can_learn, string reason +function exports.can_autolearn(task, rule, learn_type, overrides) + local autolearn_opts = rule.autolearn or {} + local opts = merge_options(default_autolearn_settings, autolearn_opts) + + if overrides then + opts = merge_options(opts, overrides) + end + + -- Master enable check + if not opts.enabled then + return false, 'autolearn disabled' + end + + local score = task:get_metric_score()[1] + local action = task:get_metric_action() + + local ctx = { + task = task, + rule = rule, + learn_type = learn_type, + score = score, + action = action, + options = opts, + } + + -- Execute registered guards first + local guard_ok, guard_reason = execute_guards(task, learn_type, ctx) + if not guard_ok then + return false, string.format('blocked by guard: %s', guard_reason) + end + + -- Skip checks + if opts.check_local and task:get_from_ip() and task:get_from_ip():is_local() then + return false, 'local network message' + end + + if opts.check_authed and task:get_user() then + return false, 'authenticated user' + end + + -- Forbidden symbols check + local skip_ok, skip_reason = check_forbidden_symbols(task, opts.skip_symbols) + if not skip_ok then + return false, skip_reason + end + + -- Learn type specific checks + if learn_type == 'spam' then + -- Score threshold + if opts.spam_score and score < opts.spam_score then + return false, string.format('score %.2f < spam_score %.2f', score, opts.spam_score) + end + + -- Action requirement + if opts.spam_action and action ~= opts.spam_action then + return false, string.format('action %s != required %s', action, opts.spam_action) + end + + -- Required symbols + local sym_ok, sym_reason = check_required_symbols(task, opts.spam_symbols) + if not sym_ok then + return false, sym_reason + end + + -- Symbol weight threshold + if opts.spam_symbol_weight then + local weight = get_symbols_weight(task, opts.spam_symbols) + if weight < opts.spam_symbol_weight then + return false, string.format('spam symbol weight %.2f < %.2f', weight, opts.spam_symbol_weight) + end + end + + -- Expression condition + if opts.spam_condition then + local expr = get_expression(rule.prefix or 'default', opts.spam_condition, rspamd_config:get_mempool()) + if expr then + local result = expr:process(task) + if result <= 0 then + return false, string.format('spam_condition not satisfied: %s', opts.spam_condition) + end + end + end + + elseif learn_type == 'ham' then + -- Score threshold + if opts.ham_score and score > opts.ham_score then + return false, string.format('score %.2f > ham_score %.2f', score, opts.ham_score) + end + + -- Action requirement + if opts.ham_action and action ~= opts.ham_action then + return false, string.format('action %s != required %s', action, opts.ham_action) + end + + -- Required symbols + local sym_ok, sym_reason = check_required_symbols(task, opts.ham_symbols) + if not sym_ok then + return false, sym_reason + end + + -- Symbol weight threshold + if opts.ham_symbol_weight then + local weight = get_symbols_weight(task, opts.ham_symbols) + if weight > opts.ham_symbol_weight then + return false, string.format('ham symbol weight %.2f > %.2f', weight, opts.ham_symbol_weight) + end + end + + -- Expression condition + if opts.ham_condition then + local expr = get_expression(rule.prefix or 'default', opts.ham_condition, rspamd_config:get_mempool()) + if expr then + local result = expr:process(task) + if result <= 0 then + return false, string.format('ham_condition not satisfied: %s', opts.ham_condition) + end + end + end + end + + -- Probability check (skip if already confident) + if opts.probability_check and opts.probability_check.enabled then + local prob_var = opts.probability_check.variable or 'neural_prob' + local prob = task:get_mempool():get_variable(prob_var, 'double') + if prob then + if learn_type == 'spam' and prob >= opts.probability_check.spam_min then + return false, string.format('already confident spam: %.2f >= %.2f', prob, opts.probability_check.spam_min) + elseif learn_type == 'ham' and prob <= opts.probability_check.ham_max then + return false, string.format('already confident ham: %.2f <= %.2f', prob, opts.probability_check.ham_max) + end + end + end + + -- Probabilistic sampling + if opts.sampling then + local sample_prob = learn_type == 'spam' and opts.sampling.spam_prob or opts.sampling.ham_prob + if sample_prob and sample_prob < 1.0 then + local coin = math.random() + if coin > sample_prob then + return false, string.format('sampled out: %.2f > %.2f', coin, sample_prob) + end + end + end + + return true, nil +end + +--- Determine learn type based on score/action/symbols +-- @param task rspamd_task +-- @param rule neural rule configuration +-- @return string learn_type ('spam', 'ham', or nil), string reason +function exports.get_learn_type(task, rule) + local autolearn_opts = rule.autolearn or {} + local opts = merge_options(default_autolearn_settings, autolearn_opts) + + if not opts.enabled then + return nil, 'autolearn disabled' + end + + -- Try spam first + local spam_ok, spam_reason = exports.can_autolearn(task, rule, 'spam') + if spam_ok then + return 'spam', 'autolearn spam' + end + + -- Try ham + local ham_ok, ham_reason = exports.can_autolearn(task, rule, 'ham') + if ham_ok then + return 'ham', 'autolearn ham' + end + + -- Neither qualifies + return nil, spam_reason or ham_reason or 'no autolearn condition matched' +end + +--- Set autolearn class in mempool (for integration with neural.lua) +-- @param task rspamd_task +-- @param learn_type 'spam' or 'ham' +function exports.set_autolearn_class(task, learn_type) + task:get_mempool():set_variable('neural_autolearn_class', learn_type) + lua_util.debugm(N, task, 'set neural autolearn class: %s', learn_type) +end + +--- Get autolearn class from mempool +-- @param task rspamd_task +-- @return string learn_type or nil +function exports.get_autolearn_class(task) + return task:get_mempool():get_variable('neural_autolearn_class') +end + +--- Clear expression cache (useful for config reload) +function exports.clear_cache() + expression_cache = {} +end + +-- Register module in rspamd_plugins for user hooks +if rspamd_plugins then + rspamd_plugins['neural_learn'] = { + register_guard = exports.register_guard, + unregister_guard = exports.unregister_guard, + configure = exports.configure, + can_autolearn = exports.can_autolearn, + get_learn_type = exports.get_learn_type, + set_autolearn_class = exports.set_autolearn_class, + get_autolearn_class = exports.get_autolearn_class, + } +end + +return exports diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 8d9859f8a9..44450c9d69 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -20,6 +20,7 @@ local lua_redis = require "lua_redis" local lua_util = require "lua_util" local lua_verdict = require "lua_verdict" local neural_common = require "plugins/neural" +local neural_learn = require "lua_neural_learn" local rspamd_kann = require "rspamd_kann" local rspamd_logger = require "rspamd_logger" local rspamd_tensor = require "rspamd_tensor" @@ -217,20 +218,55 @@ local function ann_push_task_result(rule, task, verdict, score, set) end end - -- If LLM provider is configured, suppress autotrain unless manual training requested - if not manual_train and rule.providers and #rule.providers > 0 then + -- Check for autolearn class set by mempool (integration with external learning decisions) + if not manual_train then + local autolearn_class = neural_learn.get_autolearn_class(task) + if autolearn_class then + lua_util.debugm(N, task, 'found neural autolearn class in mempool: %s', autolearn_class) + if autolearn_class == 'spam' then + learn_spam = true + manual_train = true + elseif autolearn_class == 'ham' then + learn_ham = true + manual_train = true + end + end + end + + -- If LLM provider is configured, use autolearn conditions instead of simple score thresholds + local has_llm_provider = false + if rule.providers and #rule.providers > 0 then for _, p in ipairs(rule.providers) do if p.type == 'llm' then - lua_util.debugm(N, task, 'suppress autotrain: llm provider present and no manual header') - learn_spam = false - learn_ham = false - skip_reason = 'llm provider requires manual training' + has_llm_provider = true break end end end - if not manual_train and (not train_opts.store_pool_only and train_opts.autotrain) then + if has_llm_provider and not manual_train then + -- Use expression-based autolearn conditions for LLM providers + if rule.autolearn and rule.autolearn.enabled then + local learn_type, reason = neural_learn.get_learn_type(task, rule) + if learn_type == 'spam' then + learn_spam = true + lua_util.debugm(N, task, 'autolearn spam via expression: %s', reason) + elseif learn_type == 'ham' then + learn_ham = true + lua_util.debugm(N, task, 'autolearn ham via expression: %s', reason) + else + skip_reason = reason or 'autolearn condition not met' + lua_util.debugm(N, task, 'autolearn skip: %s', skip_reason) + end + else + -- LLM provider without autolearn config - require manual training + learn_spam = false + learn_ham = false + skip_reason = 'llm provider requires autolearn config or manual training' + lua_util.debugm(N, task, 'suppress autotrain: llm provider present, no autolearn config') + end + elseif not manual_train and (not train_opts.store_pool_only and train_opts.autotrain) then + -- Traditional score/verdict based learning for non-LLM providers if train_opts.spam_score then learn_spam = score >= train_opts.spam_score @@ -261,8 +297,8 @@ local function ann_push_task_result(rule, task, verdict, score, set) verdict) end end - else - if train_opts.store_pool_only and not manual_train then + elseif not manual_train then + if train_opts.store_pool_only then local ucl = require "ucl" learn_ham = false learn_spam = false @@ -1120,3 +1156,28 @@ for _, rule in pairs(settings.rules) do end end) end + +-- Register plugin API in rspamd_plugins for user hooks +if rspamd_plugins then + rspamd_plugins['neural'] = rspamd_plugins['neural'] or {} + -- Expose autolearn hooks for user customization + rspamd_plugins['neural'].autolearn = { + -- Register a custom guard that can block learning + -- cb: function(task, learn_type, ctx) -> bool, reason + register_guard = neural_learn.register_guard, + -- Remove a registered guard + unregister_guard = neural_learn.unregister_guard, + -- Configure global autolearn defaults + configure = neural_learn.configure, + -- Check if task qualifies for autolearn + -- Returns: can_learn (bool), reason (string) + can_autolearn = neural_learn.can_autolearn, + -- Get learn type for task based on conditions + -- Returns: 'spam', 'ham', or nil + get_learn_type = neural_learn.get_learn_type, + -- Set autolearn class in mempool (triggers learning in idempotent callback) + set_autolearn_class = neural_learn.set_autolearn_class, + -- Get autolearn class from mempool + get_autolearn_class = neural_learn.get_autolearn_class, + } +end