from documents.signals import document_consumption_finished
from documents.signals import document_updated
from documents.signals.handlers import add_inbox_tags
+from documents.signals.handlers import add_or_update_document_in_llm_index
from documents.signals.handlers import add_to_index
from documents.signals.handlers import run_workflows_added
from documents.signals.handlers import run_workflows_updated
document_consumption_finished.connect(set_storage_path)
document_consumption_finished.connect(add_to_index)
document_consumption_finished.connect(run_workflows_added)
+document_consumption_finished.connect(add_or_update_document_in_llm_index)
document_updated.connect(run_workflows_updated)
import documents.schema # noqa: F401
from documents.permissions import get_objects_for_user_owner_aware
from documents.permissions import set_permissions_for_object
from documents.templating.workflows import parse_w_workflow_placeholders
+from paperless.config import AIConfig
if TYPE_CHECKING:
from pathlib import Path
task_instance.save()
except Exception: # pragma: no cover
logger.exception("Updating PaperlessTask failed")
+
+
+def add_or_update_document_in_llm_index(sender, document, **kwargs):
+ """
+ Add or update a document in the LLM index when it is created or updated.
+ """
+ ai_config = AIConfig()
+ if ai_config.llm_index_enabled():
+ from documents.tasks import update_document_in_llm_index
+
+ update_document_in_llm_index.delay(document)
+
+
+@receiver(models.signals.post_delete, sender=Document)
+def delete_document_from_llm_index(sender, instance: Document, **kwargs):
+ """
+ Delete a document from the LLM index when it is deleted.
+ """
+ ai_config = AIConfig()
+ if ai_config.llm_index_enabled():
+ from documents.tasks import remove_document_from_llm_index
+
+ remove_document_from_llm_index.delay(instance)
from pathlib import Path
from tempfile import TemporaryDirectory
-import faiss
-import llama_index.core.settings as llama_settings
import tqdm
from celery import Task
from celery import shared_task
from django.db.models.signals import post_save
from django.utils import timezone
from filelock import FileLock
-from llama_index.core import Document as LlamaDocument
-from llama_index.core import StorageContext
-from llama_index.core import VectorStoreIndex
-from llama_index.core.node_parser import SimpleNodeParser
-from llama_index.core.storage.docstore import SimpleDocumentStore
-from llama_index.core.storage.index_store import SimpleIndexStore
-from llama_index.vector_stores.faiss import FaissVectorStore
from whoosh.writing import AsyncWriter
from documents import index
from documents.signals import document_updated
from documents.signals.handlers import cleanup_document_deletion
from documents.signals.handlers import run_workflows
-from paperless.ai.embedding import build_llm_index_text
-from paperless.ai.embedding import get_embedding_dim
-from paperless.ai.embedding import get_embedding_model
+from paperless.ai.indexing import llm_index_add_or_update_document
+from paperless.ai.indexing import llm_index_remove_document
+from paperless.ai.indexing import rebuild_llm_index
+from paperless.config import AIConfig
if settings.AUDIT_LOG_ENABLED:
from auditlog.models import LogEntry
for doc in documents:
index.update_document(writer, doc)
+ ai_config = AIConfig()
+ if ai_config.llm_index_enabled():
+ for doc in documents:
+            llm_index_add_or_update_document(doc)
+
@shared_task
def update_document_content_maybe_archive_file(document_id):
with index.open_index_writer() as writer:
index.update_document(writer, document)
+ ai_config = AIConfig()
+    if ai_config.llm_index_enabled():
+ llm_index_add_or_update_document(document)
+
clear_document_caches(document.pk)
except Exception:
def llm_index_rebuild(*, progress_bar_disable=False, rebuild=False):
- if rebuild:
- shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
- settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
+ rebuild_llm_index(
+ progress_bar_disable=progress_bar_disable,
+ rebuild=rebuild,
+ )
- embed_model = get_embedding_model()
- llama_settings.Settings.embed_model = embed_model
- if rebuild or not settings.LLM_INDEX_DIR.exists():
- embedding_dim = get_embedding_dim()
- faiss_index = faiss.IndexFlatL2(embedding_dim)
- vector_store = FaissVectorStore(faiss_index=faiss_index)
- else:
- vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
+@shared_task
+def update_document_in_llm_index(document):
+ llm_index_add_or_update_document(document)
- docstore = SimpleDocumentStore()
- index_store = SimpleIndexStore()
- storage_context = StorageContext.from_defaults(
- docstore=docstore,
- index_store=index_store,
- persist_dir=settings.LLM_INDEX_DIR,
- vector_store=vector_store,
- )
+@shared_task
+def remove_document_from_llm_index(document):
+ llm_index_remove_document(document)
- parser = SimpleNodeParser()
- nodes = []
-
- for document in tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable):
- if not document.content:
- continue
-
- text = build_llm_index_text(document)
- metadata = {
- "document_id": document.id,
- "title": document.title,
- "tags": [t.name for t in document.tags.all()],
- "correspondent": document.correspondent.name
- if document.correspondent
- else None,
- "document_type": document.document_type.name
- if document.document_type
- else None,
- "created": document.created.isoformat() if document.created else None,
- "added": document.added.isoformat() if document.added else None,
- }
-
- doc = LlamaDocument(text=text, metadata=metadata)
- doc_nodes = parser.get_nodes_from_documents([doc])
- nodes.extend(doc_nodes)
-
- index = VectorStoreIndex(
- nodes=nodes,
- storage_context=storage_context,
- embed_model=embed_model,
- )
- index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+# TODO: schedule to run periodically
+@shared_task
+def rebuild_llm_index_task():
+    rebuild_llm_index(rebuild=True)
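+# One way to address the TODO above is a beat entry in settings (a sketch
+# only; the entry name and cadence are assumptions, not part of this change):
+#
+#   from celery.schedules import crontab
+#
+#   CELERY_BEAT_SCHEDULE = {
+#       "rebuild-llm-index": {
+#           "task": "documents.tasks.rebuild_llm_index_task",
+#           "schedule": crontab(minute=10, hour=2),
+#       },
+#   }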
from documents.models import Document
from paperless.ai.client import AIClient
-from paperless.ai.indexing import load_index
+from paperless.ai.indexing import load_or_build_index
logger = logging.getLogger("paperless.ai.chat")
def stream_chat_with_documents(query_str: str, documents: list[Document]):
client = AIClient()
- index = load_index()
+ index = load_or_build_index()
doc_ids = [doc.pk for doc in documents]
import logging
+import shutil
+import faiss
import llama_index.core.settings as llama_settings
+import tqdm
from django.conf import settings
+from llama_index.core import Document as LlamaDocument
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
-from llama_index.core import load_index_from_storage
+from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.retrievers import VectorIndexRetriever
+from llama_index.core.schema import BaseNode
+from llama_index.core.storage.docstore import SimpleDocumentStore
+from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.vector_stores.faiss import FaissVectorStore
from documents.models import Document
+from paperless.ai.embedding import build_llm_index_text
+from paperless.ai.embedding import get_embedding_dim
from paperless.ai.embedding import get_embedding_model
logger = logging.getLogger("paperless.ai.indexing")
-def load_index() -> VectorStoreIndex:
- """Loads the persisted LlamaIndex from disk."""
- vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
- embed_model = get_embedding_model()
+def get_or_create_storage_context(*, rebuild=False):
+ """
+ Loads or creates the StorageContext (vector store, docstore, index store).
+ If rebuild=True, deletes and recreates everything.
+ """
+ if rebuild:
+ shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
+ settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
- llama_settings.Settings.embed_model = embed_model
- llama_settings.Settings.chunk_size = 512
+ if rebuild or not settings.LLM_INDEX_DIR.exists():
+ embedding_dim = get_embedding_dim()
+ faiss_index = faiss.IndexFlatL2(embedding_dim)
+ vector_store = FaissVectorStore(faiss_index=faiss_index)
+ docstore = SimpleDocumentStore()
+ index_store = SimpleIndexStore()
+ else:
+ vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
+ docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
+ index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
- storage_context = StorageContext.from_defaults(
+ return StorageContext.from_defaults(
+ docstore=docstore,
+ index_store=index_store,
vector_store=vector_store,
persist_dir=settings.LLM_INDEX_DIR,
)
- return load_index_from_storage(storage_context)
+
+
+def get_vector_store_index(storage_context, embed_model):
+ """
+ Returns a VectorStoreIndex given a storage context and embed model.
+ """
+ return VectorStoreIndex(
+ storage_context=storage_context,
+ embed_model=embed_model,
+ )
+
+
+def build_document_node(document) -> list[BaseNode]:
+ """
+ Given a Document, returns parsed Nodes ready for indexing.
+ """
+ if not document.content:
+ return []
+
+ text = build_llm_index_text(document)
+ metadata = {
+ "document_id": document.id,
+ "title": document.title,
+ "tags": [t.name for t in document.tags.all()],
+ "correspondent": document.correspondent.name
+ if document.correspondent
+ else None,
+ "document_type": document.document_type.name
+ if document.document_type
+ else None,
+ "created": document.created.isoformat() if document.created else None,
+ "added": document.added.isoformat() if document.added else None,
+ }
+ doc = LlamaDocument(text=text, metadata=metadata)
+ parser = SimpleNodeParser()
+ return parser.get_nodes_from_documents([doc])
+
+
+def load_or_build_index(storage_context=None, embed_model=None, nodes=None):
+    """
+    Load an existing VectorStoreIndex if present,
+    or build a new one using provided nodes if storage is empty.
+    When storage_context or embed_model are omitted, defaults are created,
+    so callers such as query_similar_documents can invoke this with no args.
+    """
+    if embed_model is None:
+        embed_model = get_embedding_model()
+        llama_settings.Settings.embed_model = embed_model
+    if storage_context is None:
+        storage_context = get_or_create_storage_context(rebuild=False)
+    try:
+        return VectorStoreIndex(
+            storage_context=storage_context,
+            embed_model=embed_model,
+        )
+    except ValueError as e:
+        if "One of nodes, objects, or index_struct must be provided" in str(e):
+            if not nodes:
+                return None
+            return VectorStoreIndex(
+                nodes=nodes,
+                storage_context=storage_context,
+                embed_model=embed_model,
+            )
+        raise
+
+
+def remove_existing_document_nodes(document, index):
+    """
+    Removes a document's existing nodes from the index's docstore.
+    Only the docstore entries are deleted, because the FAISS IndexFlatL2
+    vector store is append-only.
+    """
+    all_node_ids = list(index.docstore.docs.keys())
+    existing_nodes = [
+        node.node_id
+        for node in index.docstore.get_nodes(all_node_ids)
+        if node.metadata.get("document_id") == document.id
+    ]
+    for node_id in existing_nodes:
+        # Delete from the docstore; the FAISS vectors themselves cannot be removed
+        index.docstore.delete_document(node_id)
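+# Note: stale vectors therefore accumulate in the append-only FAISS store
+# until rebuild_llm_index(rebuild=True) recreates the index from scratch.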
+
+
+def rebuild_llm_index(*, progress_bar_disable=False, rebuild=False):
+    """
+    Builds the LLM index from all documents.
+    If rebuild=True, the existing index directory is deleted first.
+    """
+ embed_model = get_embedding_model()
+ llama_settings.Settings.embed_model = embed_model
+
+ storage_context = get_or_create_storage_context(rebuild=rebuild)
+
+ nodes = []
+
+ for document in tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable):
+ document_nodes = build_document_node(document)
+ nodes.extend(document_nodes)
+
+ if not nodes:
+ raise RuntimeError(
+ "No nodes to index — check that documents are available and have content.",
+ )
+
+ VectorStoreIndex(
+ nodes=nodes,
+ storage_context=storage_context,
+ embed_model=embed_model,
+ )
+ storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+
+
+def llm_index_add_or_update_document(document):
+ """
+ Adds or updates a document in the LLM index.
+ If the document already exists, it will be replaced.
+ """
+ embed_model = get_embedding_model()
+ llama_settings.Settings.embed_model = embed_model
+
+ storage_context = get_or_create_storage_context(rebuild=False)
+
+ new_nodes = build_document_node(document)
+
+ index = load_or_build_index(storage_context, embed_model, nodes=new_nodes)
+
+ if index is None:
+ # Nothing to index
+ return
+
+ # Remove old nodes
+ remove_existing_document_nodes(document, index)
+
+ index.insert_nodes(new_nodes)
+
+ storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+
+
+def llm_index_remove_document(document):
+ embed_model = get_embedding_model()
+    llama_settings.Settings.embed_model = embed_model
+
+ storage_context = get_or_create_storage_context(rebuild=False)
+
+ index = load_or_build_index(storage_context, embed_model)
+ if index is None:
+ return # Nothing to remove
+
+ # Remove old nodes
+ remove_existing_document_nodes(document, index)
+
+ storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
- """Runs a similarity query and returns top-k similar Document objects."""
- # Load the index
- index = load_index()
+ """
+ Runs a similarity query and returns top-k similar Document objects.
+ """
+ index = load_or_build_index()
retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
# Build query from the document text
query_text = (document.title or "") + "\n" + (document.content or "")
-
- # Query
results = retriever.retrieve(query_text)
# Each result.node.metadata["document_id"] should match our stored doc
self.llm_model = app_config.llm_model or settings.LLM_MODEL
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
self.llm_url = app_config.llm_url or settings.LLM_URL
+
+    def llm_index_enabled(self) -> bool:
+        return bool(self.ai_enabled and self.llm_embedding_backend)
def test_stream_chat_with_one_document_full_content(mock_document):
with (
patch("paperless.ai.chat.AIClient") as mock_client_cls,
- patch("paperless.ai.chat.load_index") as mock_load_index,
+ patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
patch(
"paperless.ai.chat.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
with (
patch("paperless.ai.chat.AIClient") as mock_client_cls,
- patch("paperless.ai.chat.load_index") as mock_load_index,
+ patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
patch(
"paperless.ai.chat.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
def test_stream_chat_no_matching_nodes():
with (
patch("paperless.ai.chat.AIClient") as mock_client_cls,
- patch("paperless.ai.chat.load_index") as mock_load_index,
+ patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
--- /dev/null
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+from django.utils import timezone
+from llama_index.core.base.embeddings.base import BaseEmbedding
+
+from documents.models import Document
+from paperless.ai import indexing
+
+
+@pytest.fixture
+def temp_llm_index_dir(tmp_path):
+ original_dir = indexing.settings.LLM_INDEX_DIR
+ indexing.settings.LLM_INDEX_DIR = tmp_path
+ yield tmp_path
+ indexing.settings.LLM_INDEX_DIR = original_dir
+
+
+@pytest.fixture
+def real_document(db):
+ return Document.objects.create(
+ title="Test Document",
+ content="This is some test content.",
+ added=timezone.now(),
+ )
+
+
+@pytest.fixture
+def mock_embed_model():
+ """Mocks the embedding model."""
+ with patch("paperless.ai.indexing.get_embedding_model") as mock:
+ mock.return_value = FakeEmbedding()
+ yield mock
+
+
+class FakeEmbedding(BaseEmbedding):
+    # TODO: maybe a better way to do this?
+    async def _aget_query_embedding(self, query: str) -> list[float]:
+        return [0.1] * self.get_query_embedding_dim()
+
+    def _get_query_embedding(self, query: str) -> list[float]:
+        return [0.1] * self.get_query_embedding_dim()
+
+    def _get_text_embedding(self, text: str) -> list[float]:
+        return [0.1] * self.get_query_embedding_dim()
+
+    def get_query_embedding_dim(self) -> int:
+        return 384  # match the dimension the FAISS index is built with
+
+
+@pytest.mark.django_db
+def test_build_document_node(real_document):
+ nodes = indexing.build_document_node(real_document)
+ assert len(nodes) > 0
+ assert nodes[0].metadata["document_id"] == real_document.id
+
+
+@pytest.mark.django_db
+def test_rebuild_llm_index(
+ temp_llm_index_dir,
+ real_document,
+ mock_embed_model,
+):
+ with patch("documents.models.Document.objects.all") as mock_all:
+ mock_all.return_value = [real_document]
+ indexing.rebuild_llm_index(rebuild=True)
+
+ assert any(temp_llm_index_dir.glob("*.json"))
+
+
+@pytest.mark.django_db
+def test_add_or_update_document_updates_existing_entry(
+ temp_llm_index_dir,
+ real_document,
+ mock_embed_model,
+):
+ indexing.rebuild_llm_index(rebuild=True)
+ indexing.llm_index_add_or_update_document(real_document)
+
+ assert any(temp_llm_index_dir.glob("*.json"))
+
+
+@pytest.mark.django_db
+def test_remove_document_deletes_node_from_docstore(
+ temp_llm_index_dir,
+ real_document,
+ mock_embed_model,
+):
+ indexing.rebuild_llm_index(rebuild=True)
+ indexing.llm_index_add_or_update_document(real_document)
+ indexing.llm_index_remove_document(real_document)
+
+ assert any(temp_llm_index_dir.glob("*.json"))
+
+
+@pytest.mark.django_db
+def test_rebuild_llm_index_no_documents(
+ temp_llm_index_dir,
+ mock_embed_model,
+):
+ with patch("documents.models.Document.objects.all") as mock_all:
+ mock_all.return_value = []
+
+ with pytest.raises(RuntimeError, match="No nodes to index"):
+ indexing.rebuild_llm_index(rebuild=True)
+
+
+def test_query_similar_documents(
+ temp_llm_index_dir,
+ real_document,
+):
+ with (
+ patch("paperless.ai.indexing.load_or_build_index") as mock_load_or_build_index,
+ patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
+ patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
+ ):
+ mock_index = MagicMock()
+ mock_load_or_build_index.return_value = mock_index
+
+ mock_retriever = MagicMock()
+ mock_retriever_cls.return_value = mock_retriever
+
+ mock_node1 = MagicMock()
+ mock_node1.metadata = {"document_id": 1}
+
+ mock_node2 = MagicMock()
+ mock_node2.metadata = {"document_id": 2}
+
+ mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
+
+ mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
+ mock_filter.return_value = mock_filtered_docs
+
+ result = indexing.query_similar_documents(real_document, top_k=3)
+
+ mock_load_or_build_index.assert_called_once()
+ mock_retriever_cls.assert_called_once_with(index=mock_index, similarity_top_k=3)
+ mock_retriever.retrieve.assert_called_once_with(
+ "Test Document\nThis is some test content.",
+ )
+ mock_filter.assert_called_once_with(pk__in=[1, 2])
+
+ assert result == mock_filtered_docs
from unittest.mock import patch
import pytest
-from llama_index.core.base.embeddings.base import BaseEmbedding
from documents.models import Document
from paperless.ai.embedding import build_llm_index_text
from paperless.ai.embedding import get_embedding_dim
from paperless.ai.embedding import get_embedding_model
-from paperless.ai.indexing import load_index
-from paperless.ai.indexing import query_similar_documents
from paperless.ai.rag import get_context_for_document
from paperless.models import LLMEmbeddingBackend
assert "Notes: Note1,Note2" in result
assert "Content:\n\nThis is the document content." in result
assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result
-
-
-# Indexing
-
-
-@pytest.fixture
-def mock_settings(settings):
- settings.LLM_INDEX_DIR = "/fake/path"
- return settings
-
-
-class FakeEmbedding(BaseEmbedding):
- # TODO: gotta be a better way to do this
- def _aget_query_embedding(self, query: str) -> list[float]:
- return [0.1, 0.2, 0.3]
-
- def _get_query_embedding(self, query: str) -> list[float]:
- return [0.1, 0.2, 0.3]
-
- def _get_text_embedding(self, text: str) -> list[float]:
- return [0.1, 0.2, 0.3]
-
-
-def test_load_index(mock_settings):
- with (
- patch("paperless.ai.indexing.FaissVectorStore.from_persist_dir") as mock_faiss,
- patch("paperless.ai.indexing.get_embedding_model") as mock_get_embed_model,
- patch(
- "paperless.ai.indexing.StorageContext.from_defaults",
- ) as mock_storage_context,
- patch("paperless.ai.indexing.load_index_from_storage") as mock_load_index,
- ):
- # Setup mocks
- mock_vector_store = MagicMock()
- mock_storage = MagicMock()
- mock_index = MagicMock()
-
- mock_faiss.return_value = mock_vector_store
- mock_storage_context.return_value = mock_storage
- mock_load_index.return_value = mock_index
- mock_get_embed_model.return_value = FakeEmbedding()
-
- # Act
- result = load_index()
-
- # Assert
- mock_faiss.assert_called_once_with("/fake/path")
- mock_get_embed_model.assert_called_once()
- mock_storage_context.assert_called_once_with(
- vector_store=mock_vector_store,
- persist_dir="/fake/path",
- )
- mock_load_index.assert_called_once_with(mock_storage)
- assert result == mock_index
-
-
-def test_query_similar_documents(mock_document):
- with (
- patch("paperless.ai.indexing.load_index") as mock_load_index_func,
- patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
- patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
- ):
- # Setup mocks
- mock_index = MagicMock()
- mock_load_index_func.return_value = mock_index
-
- mock_retriever = MagicMock()
- mock_retriever_cls.return_value = mock_retriever
-
- mock_node1 = MagicMock()
- mock_node1.metadata = {"document_id": 1}
-
- mock_node2 = MagicMock()
- mock_node2.metadata = {"document_id": 2}
-
- mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
-
- mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
- mock_filter.return_value = mock_filtered_docs
-
- result = query_similar_documents(mock_document, top_k=3)
-
- mock_load_index_func.assert_called_once()
- mock_retriever_cls.assert_called_once_with(index=mock_index, similarity_top_k=3)
- mock_retriever.retrieve.assert_called_once_with(
- "Test Title\nThis is the document content.",
- )
- mock_filter.assert_called_once_with(pk__in=[1, 2])
-
- assert result == mock_filtered_docs