]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add expression-based autolearn for neural LLM providers
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 12 Jan 2026 10:58:03 +0000 (10:58 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 12 Jan 2026 10:58:03 +0000 (10:58 +0000)
Add integrated autolearn system for neural networks with LLM providers:

- New lua_neural_learn library with guards system and rspamd_expression
  support for complex conditions
- Expression-based conditions: spam_condition, ham_condition using
  rspamd_expression syntax (e.g., "BAYES_SPAM & DMARC_POLICY_REJECT")
- Score, action, and symbol-based thresholds
- Pluggable guards via rspamd_plugins['neural'].autolearn hooks
- Mempool-based flag passing (no double scanning)
- Probabilistic sampling for training volume control

Also includes contrib/neural-embedding-service with a FastEmbed-based
Python service for CPU-optimized embedding inference, compatible with
both Ollama and OpenAI API formats.

Configuration example:
  autolearn {
    enabled = true;
    spam_score = 15.0;
    spam_condition = "BAYES_SPAM & (DMARC_POLICY_REJECT | RBL_SPAMHAUS)";
    ham_condition = "BAYES_HAM & DKIM_VALID_AU & SPF_PASS";
  }

conf/modules.d/neural_autolearn.conf [new file with mode: 0644]
contrib/neural-embedding-service/Dockerfile [new file with mode: 0644]
contrib/neural-embedding-service/README.md [new file with mode: 0644]
contrib/neural-embedding-service/embedding_service.py [new file with mode: 0644]
contrib/neural-embedding-service/requirements.txt [new file with mode: 0644]
lualib/lua_neural_learn.lua [new file with mode: 0644]
src/plugins/lua/neural.lua

diff --git a/conf/modules.d/neural_autolearn.conf b/conf/modules.d/neural_autolearn.conf
new file mode 100644 (file)
index 0000000..846fd7d
--- /dev/null
@@ -0,0 +1,83 @@
+# Neural Autolearn Configuration
+#
+# This configuration is part of the neural plugin and controls automatic
+# training for neural networks with LLM providers.
+#
+# The autolearn section can be added to any neural rule configuration.
+# It uses expression-based conditions for strong confidence learning.
+#
+# Documentation: doc/neural-llm-embeddings-guide.md
+
+# Example autolearn configuration for neural rules
+# Add this inside a neural rule:
+#
+# neural {
+#   rules {
+#     llm_classifier {
+#       providers = [
+#         { type = "llm"; llm_type = "ollama"; model = "nomic-embed-text"; }
+#       ];
+#
+#       # Autolearn configuration
+#       autolearn {
+#         enabled = true;
+#
+#         # Score thresholds
+#         spam_score = 15.0;     # Learn spam if score >= 15.0
+#         ham_score = -5.0;      # Learn ham if score <= -5.0
+#
+#         # Action requirements (optional, more restrictive)
+#         spam_action = "reject";
+#         ham_action = "no action";
+#
+#         # Expression-based conditions (rspamd_expression syntax)
+#         # These provide fine-grained control over learning decisions
+#         spam_condition = "BAYES_SPAM & (DMARC_POLICY_REJECT | RBL_SPAMHAUS_SBL)";
+#         ham_condition = "BAYES_HAM & DKIM_VALID_AU & SPF_PASS";
+#
+#         # Required symbols (all must be present)
+#         spam_symbols = ["BAYES_SPAM"];
+#         ham_symbols = ["BAYES_HAM", "DKIM_VALID"];
+#
+#         # Forbidden symbols (any blocks learning)
+#         skip_symbols = ["WHITELIST_SENDER", "GREYLIST"];
+#
+#         # Symbol weight thresholds
+#         spam_symbol_weight = 5.0;   # Sum of spam_symbols scores >= 5.0
+#         ham_symbol_weight = -3.0;   # Sum of ham_symbols scores <= -3.0
+#
+#         # Probabilistic sampling (reduce training volume)
+#         sampling {
+#           spam_prob = 0.5;   # Learn 50% of qualifying spam
+#           ham_prob = 0.5;
+#         }
+#
+#         # Skip local/authenticated messages
+#         check_local = true;
+#         check_authed = true;
+#       }
+#     }
+#   }
+# }
+
+# Conservative production example:
+# autolearn {
+#   enabled = true;
+#   spam_score = 20.0;
+#   ham_score = -8.0;
+#   spam_action = "reject";
+#   spam_condition = "BAYES_SPAM & !WHITELIST_SENDER";
+#   ham_condition = "BAYES_HAM & DKIM_VALID_AU";
+#   skip_symbols = ["GREYLIST", "RATELIMITED"];
+#   sampling {
+#     spam_prob = 0.3;
+#     ham_prob = 0.3;
+#   }
+# }
+
+# Aggressive initial training example:
+# autolearn {
+#   enabled = true;
+#   spam_score = 10.0;
+#   ham_score = -2.0;
+# }
diff --git a/contrib/neural-embedding-service/Dockerfile b/contrib/neural-embedding-service/Dockerfile
new file mode 100644 (file)
index 0000000..4083fb8
--- /dev/null
@@ -0,0 +1,58 @@
+# Rspamd Neural Embedding Service
+#
+# CPU-optimized embedding service using FastEmbed + ONNX Runtime
+#
+# Build:
+#   docker build -t rspamd-embedding-service .
+#
+# Run:
+#   docker run -p 8080:8080 rspamd-embedding-service
+#
+# With custom model:
+#   docker run -p 8080:8080 -e EMBEDDING_MODEL="BAAI/bge-base-en-v1.5" rspamd-embedding-service
+
+FROM python:3.11-slim
+
+# Build arguments
+ARG EMBEDDING_MODEL="BAAI/bge-small-en-v1.5"
+
+# Environment
+ENV PYTHONUNBUFFERED=1 \
+    PYTHONDONTWRITEBYTECODE=1 \
+    PIP_NO_CACHE_DIR=1 \
+    PIP_DISABLE_PIP_VERSION_CHECK=1 \
+    EMBEDDING_MODEL=${EMBEDDING_MODEL} \
+    EMBEDDING_PORT=8080 \
+    EMBEDDING_HOST=0.0.0.0
+
+# Install system dependencies for ONNX Runtime
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    libgomp1 \
+    && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /app
+
+# Install Python dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Pre-download model during build (optional, makes container larger but faster startup)
+# Uncomment to include model in image:
+# RUN python -c "from fastembed import TextEmbedding; TextEmbedding('${EMBEDDING_MODEL}')"
+
+# Copy application
+COPY embedding_service.py .
+
+# Non-root user
+RUN useradd -m -u 1000 embedding && chown -R embedding:embedding /app
+USER embedding
+
+# Expose port
+EXPOSE 8080
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
+    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8080/health')"
+
+# Run with uvicorn
+CMD ["python", "embedding_service.py"]
diff --git a/contrib/neural-embedding-service/README.md b/contrib/neural-embedding-service/README.md
new file mode 100644 (file)
index 0000000..5d1635a
--- /dev/null
@@ -0,0 +1,187 @@
+# Rspamd Neural Embedding Service
+
+A lightweight, CPU-optimized embedding service for Rspamd's neural plugin.
+
+## Overview
+
+This service provides text embeddings for Rspamd's neural LLM provider. It uses FastEmbed with ONNX Runtime for efficient CPU inference, making it suitable for servers without GPU hardware.
+
+## Features
+
+- **CPU-optimized**: Uses ONNX Runtime with INT8 quantization support
+- **Lightweight**: ~100MB memory for bge-small-en-v1.5
+- **Compatible**: Supports both Ollama and OpenAI API formats
+- **Fast**: 2,500-5,000 sentences/second on modern CPUs
+
+## Quick Start
+
+### Option 1: Direct Python
+
+```bash
+# Install dependencies
+pip install -r requirements.txt
+
+# Run service
+python embedding_service.py
+```
+
+### Option 2: Docker
+
+```bash
+# Build
+docker build -t rspamd-embedding-service .
+
+# Run
+docker run -p 8080:8080 rspamd-embedding-service
+
+# With custom model
+docker run -p 8080:8080 -e EMBEDDING_MODEL="BAAI/bge-base-en-v1.5" rspamd-embedding-service
+```
+
+### Option 3: Docker Compose
+
+See the main guide: `doc/neural-llm-embeddings-guide.md`
+
+## Configuration
+
+### Environment Variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `EMBEDDING_MODEL` | `BAAI/bge-small-en-v1.5` | FastEmbed model name |
+| `EMBEDDING_PORT` | `8080` | Port to listen on |
+| `EMBEDDING_HOST` | `0.0.0.0` | Host to bind to |
+
+### Command Line Arguments
+
+```bash
+python embedding_service.py --help
+
+Options:
+  --model, -m     Model name (default: BAAI/bge-small-en-v1.5)
+  --port, -p      Port number (default: 8080)
+  --host, -H      Host to bind (default: 0.0.0.0)
+  --workers, -w   Number of workers (default: 1)
+  --log-level, -l Log level (default: info)
+```
+
+## API Endpoints
+
+### Ollama Format (Recommended for Rspamd)
+
+```bash
+curl http://localhost:8080/api/embeddings -d '{
+  "model": "BAAI/bge-small-en-v1.5",
+  "prompt": "Test message about cheap medications"
+}'
+```
+
+Response:
+```json
+{
+  "embedding": [0.123, -0.456, ...]
+}
+```
+
+### OpenAI Format
+
+```bash
+curl http://localhost:8080/v1/embeddings -d '{
+  "model": "BAAI/bge-small-en-v1.5",
+  "input": "Test message"
+}'
+```
+
+Response:
+```json
+{
+  "object": "list",
+  "data": [{"embedding": [...], "index": 0}],
+  "model": "BAAI/bge-small-en-v1.5",
+  "usage": {"prompt_tokens": 2, "total_tokens": 2}
+}
+```
+
+### Health Check
+
+```bash
+curl http://localhost:8080/health
+```
+
+## Rspamd Configuration
+
+Configure Rspamd to use this service:
+
+```hcl
+# /etc/rspamd/local.d/neural.conf
+rules {
+  default {
+    providers = [
+      {
+        type = "llm";
+        llm_type = "ollama";  # or "openai"
+        model = "BAAI/bge-small-en-v1.5";
+        url = "http://localhost:8080/api/embeddings";  # or /v1/embeddings
+        timeout = 2.0;
+        cache_ttl = 86400;
+      }
+    ];
+    # ...
+  }
+}
+```
+
+## Supported Models
+
+| Model | Size | Dims | Quality | Speed |
+|-------|------|------|---------|-------|
+| `BAAI/bge-small-en-v1.5` | 33MB | 384 | Good | Excellent |
+| `BAAI/bge-base-en-v1.5` | 440MB | 768 | Better | Good |
+| `sentence-transformers/all-MiniLM-L6-v2` | 90MB | 384 | Fair | Excellent |
+| `intfloat/e5-small-v2` | 200MB | 384 | Good | Excellent |
+
+For the full list, see: https://qdrant.github.io/fastembed/examples/Supported_Models/
+
+## Production Deployment
+
+### With Gunicorn
+
+```bash
+pip install gunicorn
+gunicorn -w 4 -k uvicorn.workers.UvicornWorker -b 0.0.0.0:8080 embedding_service:app
+```
+
+### Resource Recommendations
+
+| Model | Memory | CPU Cores |
+|-------|--------|-----------|
+| bge-small-en-v1.5 | 256MB | 1 |
+| bge-base-en-v1.5 | 1GB | 2 |
+
+## Troubleshooting
+
+### Model download issues
+
+```bash
+# Pre-download model
+python -c "from fastembed import TextEmbedding; TextEmbedding('BAAI/bge-small-en-v1.5')"
+```
+
+### Memory issues
+
+Use a smaller model:
+```bash
+EMBEDDING_MODEL="BAAI/bge-small-en-v1.5" python embedding_service.py
+```
+
+### Slow inference
+
+- Ensure ONNX Runtime is using optimized providers
+- Consider increasing workers for parallel processing
+- Use batching for bulk operations
+
+## License
+
+Apache License 2.0
+
+See the main Rspamd repository for license details.
diff --git a/contrib/neural-embedding-service/embedding_service.py b/contrib/neural-embedding-service/embedding_service.py
new file mode 100644 (file)
index 0000000..50dacee
--- /dev/null
@@ -0,0 +1,299 @@
+#!/usr/bin/env python3
+"""
+Lightweight embedding service for Rspamd neural plugin.
+
+Uses FastEmbed with ONNX for CPU-optimized inference.
+Provides both Ollama-compatible and OpenAI-compatible endpoints.
+
+Installation:
+    pip install fastapi uvicorn fastembed pydantic
+
+Usage:
+    python embedding_service.py [--model MODEL] [--port PORT] [--host HOST]
+
+    # Default: bge-small-en-v1.5 on port 8080
+    python embedding_service.py
+
+    # Custom model
+    python embedding_service.py --model "BAAI/bge-base-en-v1.5"
+
+    # Production with gunicorn
+    gunicorn -w 4 -k uvicorn.workers.UvicornWorker embedding_service:app
+
+Environment variables:
+    EMBEDDING_MODEL: Model name (default: BAAI/bge-small-en-v1.5)
+    EMBEDDING_PORT: Port number (default: 8080)
+    EMBEDDING_HOST: Host to bind (default: 0.0.0.0)
+"""
+
+import argparse
+import logging
+import os
+import time
+from contextlib import asynccontextmanager
+from typing import List, Optional, Union
+
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+# FastEmbed - CPU-optimized ONNX inference
+from fastembed import TextEmbedding
+
+# Logging setup
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# Configuration from environment
+DEFAULT_MODEL = os.getenv('EMBEDDING_MODEL', 'BAAI/bge-small-en-v1.5')
+DEFAULT_PORT = int(os.getenv('EMBEDDING_PORT', '8080'))
+DEFAULT_HOST = os.getenv('EMBEDDING_HOST', '0.0.0.0')
+
+# Global model instance
+model: Optional[TextEmbedding] = None
+model_name: str = DEFAULT_MODEL
+model_dim: int = 0
+
+
+# Request/Response models
+class OllamaEmbeddingRequest(BaseModel):
+    """Ollama-compatible embedding request."""
+    model: str = DEFAULT_MODEL
+    prompt: str
+
+
+class OllamaEmbeddingResponse(BaseModel):
+    """Ollama-compatible embedding response."""
+    embedding: List[float]
+
+
+class OpenAIEmbeddingRequest(BaseModel):
+    """OpenAI-compatible embedding request."""
+    model: str = DEFAULT_MODEL
+    input: Union[str, List[str]]
+    encoding_format: str = "float"
+
+
+class OpenAIEmbeddingData(BaseModel):
+    """OpenAI embedding data object."""
+    object: str = "embedding"
+    embedding: List[float]
+    index: int
+
+
+class OpenAIUsage(BaseModel):
+    """OpenAI usage object."""
+    prompt_tokens: int
+    total_tokens: int
+
+
+class OpenAIEmbeddingResponse(BaseModel):
+    """OpenAI-compatible embedding response."""
+    object: str = "list"
+    data: List[OpenAIEmbeddingData]
+    model: str
+    usage: OpenAIUsage
+
+
+class HealthResponse(BaseModel):
+    """Health check response."""
+    status: str
+    model: str
+    dimensions: int
+    uptime_seconds: float
+
+
+# Startup time for uptime calculation
+startup_time: float = 0.0
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Load model on startup."""
+    global model, model_name, model_dim, startup_time
+
+    logger.info(f"Loading embedding model: {model_name}")
+    start = time.time()
+
+    try:
+        model = TextEmbedding(model_name)
+        # Get embedding dimension from a test inference
+        test_embed = list(model.embed(["test"]))[0]
+        model_dim = len(test_embed)
+        elapsed = time.time() - start
+        logger.info(f"Model loaded in {elapsed:.2f}s, dimensions: {model_dim}")
+        startup_time = time.time()
+    except Exception as e:
+        logger.error(f"Failed to load model: {e}")
+        raise
+
+    yield
+
+    logger.info("Shutting down embedding service")
+
+
+app = FastAPI(
+    title="Rspamd Embedding Service",
+    description="CPU-optimized embedding service for Rspamd neural plugin",
+    version="1.0.0",
+    lifespan=lifespan,
+)
+
+
+def get_embedding(text: str) -> List[float]:
+    """Generate embedding for a single text."""
+    if model is None:
+        raise HTTPException(500, "Model not loaded")
+
+    embeddings = list(model.embed([text]))
+    return embeddings[0].tolist()
+
+
+def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
+    """Generate embeddings for multiple texts."""
+    if model is None:
+        raise HTTPException(500, "Model not loaded")
+
+    embeddings = list(model.embed(texts))
+    return [e.tolist() for e in embeddings]
+
+
+def count_tokens(text: str) -> int:
+    """Approximate token count (words * 1.3)."""
+    return int(len(text.split()) * 1.3)
+
+
+@app.post("/api/embeddings", response_model=OllamaEmbeddingResponse)
+async def ollama_embeddings(request: OllamaEmbeddingRequest) -> OllamaEmbeddingResponse:
+    """
+    Ollama-compatible embedding endpoint.
+
+    Used by Rspamd neural LLM provider with llm_type = "ollama".
+    """
+    if not request.prompt:
+        raise HTTPException(400, "Missing prompt")
+
+    logger.debug(f"Ollama request: {len(request.prompt)} chars")
+    embedding = get_embedding(request.prompt)
+
+    return OllamaEmbeddingResponse(embedding=embedding)
+
+
+@app.post("/v1/embeddings", response_model=OpenAIEmbeddingResponse)
+async def openai_embeddings(request: OpenAIEmbeddingRequest) -> OpenAIEmbeddingResponse:
+    """
+    OpenAI-compatible embedding endpoint.
+
+    Used by Rspamd neural LLM provider with llm_type = "openai".
+    """
+    if not request.input:
+        raise HTTPException(400, "Missing input")
+
+    # Handle single string or list of strings
+    if isinstance(request.input, str):
+        texts = [request.input]
+    else:
+        texts = request.input
+
+    logger.debug(f"OpenAI request: {len(texts)} texts")
+    embeddings = get_embeddings_batch(texts)
+
+    # Build response
+    data = [
+        OpenAIEmbeddingData(embedding=emb, index=i)
+        for i, emb in enumerate(embeddings)
+    ]
+
+    total_tokens = sum(count_tokens(t) for t in texts)
+
+    return OpenAIEmbeddingResponse(
+        data=data,
+        model=request.model,
+        usage=OpenAIUsage(prompt_tokens=total_tokens, total_tokens=total_tokens)
+    )
+
+
+@app.get("/health", response_model=HealthResponse)
+@app.get("/", response_model=HealthResponse)
+async def health() -> HealthResponse:
+    """Health check endpoint."""
+    return HealthResponse(
+        status="ok" if model is not None else "loading",
+        model=model_name,
+        dimensions=model_dim,
+        uptime_seconds=time.time() - startup_time if startup_time > 0 else 0
+    )
+
+
+@app.get("/v1/models")
+async def list_models():
+    """List available models (OpenAI-compatible)."""
+    return {
+        "object": "list",
+        "data": [
+            {
+                "id": model_name,
+                "object": "model",
+                "created": int(startup_time),
+                "owned_by": "fastembed",
+                "permission": [],
+                "root": model_name,
+            }
+        ]
+    }
+
+
+def main():
+    """Run the embedding service."""
+    global model_name
+
+    parser = argparse.ArgumentParser(
+        description="Rspamd embedding service",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--model", "-m",
+        default=DEFAULT_MODEL,
+        help="FastEmbed model name"
+    )
+    parser.add_argument(
+        "--port", "-p",
+        type=int,
+        default=DEFAULT_PORT,
+        help="Port to listen on"
+    )
+    parser.add_argument(
+        "--host", "-H",
+        default=DEFAULT_HOST,
+        help="Host to bind to"
+    )
+    parser.add_argument(
+        "--workers", "-w",
+        type=int,
+        default=1,
+        help="Number of worker processes"
+    )
+    parser.add_argument(
+        "--log-level", "-l",
+        default="info",
+        choices=["debug", "info", "warning", "error"],
+        help="Log level"
+    )
+
+    args = parser.parse_args()
+    model_name = args.model
+
+    import uvicorn
+    uvicorn.run(
+        "embedding_service:app",
+        host=args.host,
+        port=args.port,
+        workers=args.workers,
+        log_level=args.log_level,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/contrib/neural-embedding-service/requirements.txt b/contrib/neural-embedding-service/requirements.txt
new file mode 100644 (file)
index 0000000..da81649
--- /dev/null
@@ -0,0 +1,19 @@
+# Rspamd Neural Embedding Service Dependencies
+#
+# Install: pip install -r requirements.txt
+
+# FastAPI web framework
+fastapi>=0.100.0
+
+# ASGI server
+uvicorn[standard]>=0.23.0
+
+# FastEmbed - CPU-optimized embedding inference via ONNX
+# Includes: onnxruntime, numpy, tokenizers
+fastembed>=0.2.0
+
+# Data validation
+pydantic>=2.0.0
+
+# Optional: production ASGI server
+# gunicorn>=21.0.0
diff --git a/lualib/lua_neural_learn.lua b/lualib/lua_neural_learn.lua
new file mode 100644 (file)
index 0000000..c04f6c5
--- /dev/null
@@ -0,0 +1,491 @@
+--[[
+Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[
+Neural network autolearn helpers.
+
+This module provides configurable autolearn conditions for neural networks,
+particularly useful for LLM-based providers where automatic learning needs
+careful control.
+
+Similar to lua_bayes_learn.lua, this provides:
+- Guards system for pluggable checks
+- Expression-based conditions (rspamd_expression)
+- Score/action/symbol-based thresholds
+- Hooks for custom logic via rspamd_plugins
+
+Usage in neural.lua:
+  local neural_learn = require "lua_neural_learn"
+  local can_learn, reason = neural_learn.can_autolearn(task, rule, 'spam')
+]]--
+
+local lua_util = require "lua_util"
+local rspamd_expression = require "rspamd_expression"
+local rspamd_logger = require "rspamd_logger"
+
+local N = "lua_neural_learn"
+
+local exports = {}
+
+-- Global defaults that can be overridden via configure()
+local global_defaults = {}
+
+-- Registered guards (callbacks that can block learning)
+local autolearn_guards = {}
+
+-- Cached compiled expressions per rule
+local expression_cache = {}
+
+-- Default autolearn settings
+local default_autolearn_settings = {
+  -- Master enable/disable
+  enabled = false,
+
+  -- Require minimum score magnitude for learning
+  spam_score = nil,        -- e.g., 6.0 - learn spam if score >= 6.0
+  ham_score = nil,         -- e.g., -2.0 - learn ham if score <= -2.0
+
+  -- Require specific actions
+  spam_action = nil,       -- e.g., 'reject' - only learn spam on reject
+  ham_action = nil,        -- e.g., 'no action' - only learn ham on no action
+
+  -- Expression-based conditions (rspamd_expression syntax)
+  -- Examples:
+  --   "BAYES_SPAM & !WHITELIST_SENDER"
+  --   "DMARC_POLICY_REJECT | (RBL_SPAMHAUS_SBL & SURBL_MULTI)"
+  spam_condition = nil,
+  ham_condition = nil,
+
+  -- Required symbols (all must be present)
+  spam_symbols = nil,      -- e.g., {'BAYES_SPAM', 'DKIM_VALID'}
+  ham_symbols = nil,
+
+  -- Forbidden symbols (any blocks learning)
+  skip_symbols = nil,      -- e.g., {'WHITELIST_SENDER', 'GREYLIST'}
+
+  -- Minimum symbol weight sum
+  spam_symbol_weight = nil, -- e.g., 5.0 - sum of spam_symbols scores >= 5.0
+  ham_symbol_weight = nil,  -- e.g., -3.0 - sum of ham_symbols scores <= -3.0
+
+  -- Probability-based check (skip if already confident)
+  probability_check = {
+    enabled = false,
+    variable = 'neural_prob',  -- mempool variable name
+    spam_min = 0.95,           -- skip if already 95% spam
+    ham_max = 0.05,            -- skip if already 95% ham
+  },
+
+  -- Rate limiting
+  rate_limit = {
+    enabled = false,
+    max_daily = 1000,          -- per class per day
+    redis_prefix = 'neural_autolearn',
+  },
+
+  -- Sampling (probabilistic training reduction)
+  sampling = {
+    spam_prob = 1.0,           -- 1.0 = always, 0.5 = 50% chance
+    ham_prob = 1.0,
+  },
+
+  -- Skip conditions
+  check_local = false,         -- skip local network messages
+  check_authed = false,        -- skip authenticated users
+}
+
+-- Helper: convert array to set
+local function as_set(tbl)
+  if not tbl then
+    return nil
+  end
+  local res = {}
+  for _, v in ipairs(tbl) do
+    res[v] = true
+  end
+  return res
+end
+
+-- Helper: merge options with defaults
+local function merge_options(defaults, overrides)
+  local merged = lua_util.override_defaults(defaults, global_defaults)
+  if overrides then
+    merged = lua_util.override_defaults(merged, overrides)
+  end
+  return merged
+end
+
+-- Guard execution
+local function execute_guards(task, learn_type, ctx)
+  for _, guard in ipairs(autolearn_guards) do
+    local ok, reason = guard.cb(task, learn_type, ctx)
+    if not ok then
+      return false, reason or guard.name
+    end
+  end
+  return true
+end
+
+--- Register a guard callback for autolearn decisions
+-- @param name string guard name
+-- @param cb function(task, learn_type, ctx) -> bool, reason
+-- @param opts table optional {priority = number}
+function exports.register_guard(name, cb, opts)
+  if type(name) == 'function' then
+    cb = name
+    name = string.format('guard_%d', #autolearn_guards + 1)
+  end
+
+  if type(cb) ~= 'function' then
+    rspamd_logger.errx(rspamd_config, '%s: guard callback must be a function', N)
+    return nil
+  end
+
+  local guard = {
+    name = name,
+    cb = cb,
+    priority = opts and opts.priority or 0,
+  }
+
+  autolearn_guards[#autolearn_guards + 1] = guard
+  table.sort(autolearn_guards, function(a, b)
+    return (a.priority or 0) > (b.priority or 0)
+  end)
+
+  lua_util.debugm(N, rspamd_config, 'registered autolearn guard: %s', name)
+  return name
+end
+
+--- Unregister a guard by name
+function exports.unregister_guard(name)
+  for i = #autolearn_guards, 1, -1 do
+    if autolearn_guards[i].name == name then
+      table.remove(autolearn_guards, i)
+      return true
+    end
+  end
+  return false
+end
+
+--- Configure global defaults
+-- @param opts table of default overrides
+function exports.configure(opts)
+  if opts then
+    global_defaults = lua_util.override_defaults(global_defaults, opts)
+    lua_util.debugm(N, rspamd_config, 'configured neural autolearn defaults')
+  end
+end
+
+-- Compile and cache expression
+local function get_expression(rule_name, expr_str, pool)
+  local cache_key = rule_name .. ':' .. expr_str
+  if expression_cache[cache_key] then
+    return expression_cache[cache_key]
+  end
+
+  local function parse_atom(str)
+    local atom = ''
+    for c in str:gmatch('.') do
+      if c:match('[%w_]') then
+        atom = atom .. c
+      else
+        break
+      end
+    end
+    return atom
+  end
+
+  local function process_atom(atom, task)
+    if task:has_symbol(atom) then
+      local sym = task:get_symbol(atom)
+      if sym and sym[1] then
+        local score = math.abs(sym[1].score or 0)
+        return score > 0.001 and score or 0.001
+      end
+      return 0.001
+    end
+    return 0
+  end
+
+  local expr, err = rspamd_expression.create(expr_str, { parse_atom, process_atom }, pool)
+  if err then
+    rspamd_logger.errx(rspamd_config, '%s: cannot create expression [%s]: %s', N, expr_str, err)
+    return nil
+  end
+
+  expression_cache[cache_key] = expr
+  return expr
+end
+
+-- Check if all required symbols are present
+local function check_required_symbols(task, symbols)
+  if not symbols or #symbols == 0 then
+    return true
+  end
+  for _, sym in ipairs(symbols) do
+    if not task:has_symbol(sym) then
+      return false, string.format('missing required symbol: %s', sym)
+    end
+  end
+  return true
+end
+
+-- Check if any forbidden symbols are present
+local function check_forbidden_symbols(task, symbols)
+  if not symbols then
+    return true
+  end
+  local skip_set = as_set(symbols)
+  if not skip_set then
+    return true
+  end
+  for sym, _ in pairs(skip_set) do
+    if task:has_symbol(sym) then
+      return false, string.format('has forbidden symbol: %s', sym)
+    end
+  end
+  return true
+end
+
+-- Calculate sum of symbol scores
+local function get_symbols_weight(task, symbols)
+  if not symbols or #symbols == 0 then
+    return 0
+  end
+  local total = 0
+  for _, sym in ipairs(symbols) do
+    local s = task:get_symbol(sym)
+    if s and s[1] then
+      total = total + (s[1].score or 0)
+    end
+  end
+  return total
+end
+
+--- Main function: determine if a message should be autolearned
+-- @param task rspamd_task
+-- @param rule neural rule configuration
+-- @param learn_type 'spam' or 'ham'
+-- @param overrides optional per-call config overrides
+-- @return bool can_learn, string reason
+function exports.can_autolearn(task, rule, learn_type, overrides)
+  local autolearn_opts = rule.autolearn or {}
+  local opts = merge_options(default_autolearn_settings, autolearn_opts)
+
+  if overrides then
+    opts = merge_options(opts, overrides)
+  end
+
+  -- Master enable check
+  if not opts.enabled then
+    return false, 'autolearn disabled'
+  end
+
+  local score = task:get_metric_score()[1]
+  local action = task:get_metric_action()
+
+  local ctx = {
+    task = task,
+    rule = rule,
+    learn_type = learn_type,
+    score = score,
+    action = action,
+    options = opts,
+  }
+
+  -- Execute registered guards first
+  local guard_ok, guard_reason = execute_guards(task, learn_type, ctx)
+  if not guard_ok then
+    return false, string.format('blocked by guard: %s', guard_reason)
+  end
+
+  -- Skip checks
+  if opts.check_local and task:get_from_ip() and task:get_from_ip():is_local() then
+    return false, 'local network message'
+  end
+
+  if opts.check_authed and task:get_user() then
+    return false, 'authenticated user'
+  end
+
+  -- Forbidden symbols check
+  local skip_ok, skip_reason = check_forbidden_symbols(task, opts.skip_symbols)
+  if not skip_ok then
+    return false, skip_reason
+  end
+
+  -- Learn type specific checks
+  if learn_type == 'spam' then
+    -- Score threshold
+    if opts.spam_score and score < opts.spam_score then
+      return false, string.format('score %.2f < spam_score %.2f', score, opts.spam_score)
+    end
+
+    -- Action requirement
+    if opts.spam_action and action ~= opts.spam_action then
+      return false, string.format('action %s != required %s', action, opts.spam_action)
+    end
+
+    -- Required symbols
+    local sym_ok, sym_reason = check_required_symbols(task, opts.spam_symbols)
+    if not sym_ok then
+      return false, sym_reason
+    end
+
+    -- Symbol weight threshold
+    if opts.spam_symbol_weight then
+      local weight = get_symbols_weight(task, opts.spam_symbols)
+      if weight < opts.spam_symbol_weight then
+        return false, string.format('spam symbol weight %.2f < %.2f', weight, opts.spam_symbol_weight)
+      end
+    end
+
+    -- Expression condition
+    if opts.spam_condition then
+      local expr = get_expression(rule.prefix or 'default', opts.spam_condition, rspamd_config:get_mempool())
+      if expr then
+        local result = expr:process(task)
+        if result <= 0 then
+          return false, string.format('spam_condition not satisfied: %s', opts.spam_condition)
+        end
+      end
+    end
+
+  elseif learn_type == 'ham' then
+    -- Score threshold
+    if opts.ham_score and score > opts.ham_score then
+      return false, string.format('score %.2f > ham_score %.2f', score, opts.ham_score)
+    end
+
+    -- Action requirement
+    if opts.ham_action and action ~= opts.ham_action then
+      return false, string.format('action %s != required %s', action, opts.ham_action)
+    end
+
+    -- Required symbols
+    local sym_ok, sym_reason = check_required_symbols(task, opts.ham_symbols)
+    if not sym_ok then
+      return false, sym_reason
+    end
+
+    -- Symbol weight threshold
+    if opts.ham_symbol_weight then
+      local weight = get_symbols_weight(task, opts.ham_symbols)
+      if weight > opts.ham_symbol_weight then
+        return false, string.format('ham symbol weight %.2f > %.2f', weight, opts.ham_symbol_weight)
+      end
+    end
+
+    -- Expression condition
+    if opts.ham_condition then
+      local expr = get_expression(rule.prefix or 'default', opts.ham_condition, rspamd_config:get_mempool())
+      if expr then
+        local result = expr:process(task)
+        if result <= 0 then
+          return false, string.format('ham_condition not satisfied: %s', opts.ham_condition)
+        end
+      end
+    end
+  end
+
+  -- Probability check (skip if already confident)
+  if opts.probability_check and opts.probability_check.enabled then
+    local prob_var = opts.probability_check.variable or 'neural_prob'
+    local prob = task:get_mempool():get_variable(prob_var, 'double')
+    if prob then
+      if learn_type == 'spam' and prob >= opts.probability_check.spam_min then
+        return false, string.format('already confident spam: %.2f >= %.2f', prob, opts.probability_check.spam_min)
+      elseif learn_type == 'ham' and prob <= opts.probability_check.ham_max then
+        return false, string.format('already confident ham: %.2f <= %.2f', prob, opts.probability_check.ham_max)
+      end
+    end
+  end
+
+  -- Probabilistic sampling
+  if opts.sampling then
+    local sample_prob = learn_type == 'spam' and opts.sampling.spam_prob or opts.sampling.ham_prob
+    if sample_prob and sample_prob < 1.0 then
+      local coin = math.random()
+      if coin > sample_prob then
+        return false, string.format('sampled out: %.2f > %.2f', coin, sample_prob)
+      end
+    end
+  end
+
+  return true, nil
+end
+
+--- Determine learn type based on score/action/symbols
+-- @param task rspamd_task
+-- @param rule neural rule configuration
+-- @return string learn_type ('spam', 'ham', or nil), string reason
+function exports.get_learn_type(task, rule)
+  local autolearn_opts = rule.autolearn or {}
+  local opts = merge_options(default_autolearn_settings, autolearn_opts)
+
+  if not opts.enabled then
+    return nil, 'autolearn disabled'
+  end
+
+  -- Try spam first
+  local spam_ok, spam_reason = exports.can_autolearn(task, rule, 'spam')
+  if spam_ok then
+    return 'spam', 'autolearn spam'
+  end
+
+  -- Try ham
+  local ham_ok, ham_reason = exports.can_autolearn(task, rule, 'ham')
+  if ham_ok then
+    return 'ham', 'autolearn ham'
+  end
+
+  -- Neither qualifies
+  return nil, spam_reason or ham_reason or 'no autolearn condition matched'
+end
+
+--- Set autolearn class in mempool (for integration with neural.lua)
+-- @param task rspamd_task
+-- @param learn_type 'spam' or 'ham'
+function exports.set_autolearn_class(task, learn_type)
+  task:get_mempool():set_variable('neural_autolearn_class', learn_type)
+  lua_util.debugm(N, task, 'set neural autolearn class: %s', learn_type)
+end
+
+--- Get autolearn class from mempool
+-- @param task rspamd_task
+-- @return string learn_type or nil
+function exports.get_autolearn_class(task)
+  return task:get_mempool():get_variable('neural_autolearn_class')
+end
+
+--- Clear expression cache (useful for config reload)
+function exports.clear_cache()
+  expression_cache = {}
+end
+
+-- Register module in rspamd_plugins for user hooks
+if rspamd_plugins then
+  rspamd_plugins['neural_learn'] = {
+    register_guard = exports.register_guard,
+    unregister_guard = exports.unregister_guard,
+    configure = exports.configure,
+    can_autolearn = exports.can_autolearn,
+    get_learn_type = exports.get_learn_type,
+    set_autolearn_class = exports.set_autolearn_class,
+    get_autolearn_class = exports.get_autolearn_class,
+  }
+end
+
+return exports
index 8d9859f8a9a0ac0e882221b7bd8d3dc4c7f20eef..44450c9d690a0e97adeb61ceb1f3aaf698b4fcaf 100644 (file)
@@ -20,6 +20,7 @@ local lua_redis = require "lua_redis"
 local lua_util = require "lua_util"
 local lua_verdict = require "lua_verdict"
 local neural_common = require "plugins/neural"
+local neural_learn = require "lua_neural_learn"
 local rspamd_kann = require "rspamd_kann"
 local rspamd_logger = require "rspamd_logger"
 local rspamd_tensor = require "rspamd_tensor"
@@ -217,20 +218,55 @@ local function ann_push_task_result(rule, task, verdict, score, set)
     end
   end
 
-  -- If LLM provider is configured, suppress autotrain unless manual training requested
-  if not manual_train and rule.providers and #rule.providers > 0 then
+  -- Check for autolearn class set by mempool (integration with external learning decisions)
+  if not manual_train then
+    local autolearn_class = neural_learn.get_autolearn_class(task)
+    if autolearn_class then
+      lua_util.debugm(N, task, 'found neural autolearn class in mempool: %s', autolearn_class)
+      if autolearn_class == 'spam' then
+        learn_spam = true
+        manual_train = true
+      elseif autolearn_class == 'ham' then
+        learn_ham = true
+        manual_train = true
+      end
+    end
+  end
+
+  -- If LLM provider is configured, use autolearn conditions instead of simple score thresholds
+  local has_llm_provider = false
+  if rule.providers and #rule.providers > 0 then
     for _, p in ipairs(rule.providers) do
       if p.type == 'llm' then
-        lua_util.debugm(N, task, 'suppress autotrain: llm provider present and no manual header')
-        learn_spam = false
-        learn_ham = false
-        skip_reason = 'llm provider requires manual training'
+        has_llm_provider = true
         break
       end
     end
   end
 
-  if not manual_train and (not train_opts.store_pool_only and train_opts.autotrain) then
+  if has_llm_provider and not manual_train then
+    -- Use expression-based autolearn conditions for LLM providers
+    if rule.autolearn and rule.autolearn.enabled then
+      local learn_type, reason = neural_learn.get_learn_type(task, rule)
+      if learn_type == 'spam' then
+        learn_spam = true
+        lua_util.debugm(N, task, 'autolearn spam via expression: %s', reason)
+      elseif learn_type == 'ham' then
+        learn_ham = true
+        lua_util.debugm(N, task, 'autolearn ham via expression: %s', reason)
+      else
+        skip_reason = reason or 'autolearn condition not met'
+        lua_util.debugm(N, task, 'autolearn skip: %s', skip_reason)
+      end
+    else
+      -- LLM provider without autolearn config - require manual training
+      learn_spam = false
+      learn_ham = false
+      skip_reason = 'llm provider requires autolearn config or manual training'
+      lua_util.debugm(N, task, 'suppress autotrain: llm provider present, no autolearn config')
+    end
+  elseif not manual_train and (not train_opts.store_pool_only and train_opts.autotrain) then
+    -- Traditional score/verdict based learning for non-LLM providers
     if train_opts.spam_score then
       learn_spam = score >= train_opts.spam_score
 
@@ -261,8 +297,8 @@ local function ann_push_task_result(rule, task, verdict, score, set)
           verdict)
       end
     end
-  else
-    if train_opts.store_pool_only and not manual_train then
+  elseif not manual_train then
+    if train_opts.store_pool_only then
       local ucl = require "ucl"
       learn_ham = false
       learn_spam = false
@@ -1120,3 +1156,28 @@ for _, rule in pairs(settings.rules) do
     end
   end)
 end
+
+-- Register plugin API in rspamd_plugins for user hooks
+if rspamd_plugins then
+  rspamd_plugins['neural'] = rspamd_plugins['neural'] or {}
+  -- Expose autolearn hooks for user customization
+  rspamd_plugins['neural'].autolearn = {
+    -- Register a custom guard that can block learning
+    -- cb: function(task, learn_type, ctx) -> bool, reason
+    register_guard = neural_learn.register_guard,
+    -- Remove a registered guard
+    unregister_guard = neural_learn.unregister_guard,
+    -- Configure global autolearn defaults
+    configure = neural_learn.configure,
+    -- Check if task qualifies for autolearn
+    -- Returns: can_learn (bool), reason (string)
+    can_autolearn = neural_learn.can_autolearn,
+    -- Get learn type for task based on conditions
+    -- Returns: 'spam', 'ham', or nil
+    get_learn_type = neural_learn.get_learn_type,
+    -- Set autolearn class in mempool (triggers learning in idempotent callback)
+    set_autolearn_class = neural_learn.set_autolearn_class,
+    -- Get autolearn class from mempool
+    get_autolearn_class = neural_learn.get_autolearn_class,
+  }
+end