]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Handle doc updates, refactor
authorshamoon <4887959+shamoon@users.noreply.github.com>
Sun, 27 Apr 2025 08:24:00 +0000 (01:24 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Wed, 2 Jul 2025 18:03:58 +0000 (11:03 -0700)
src/documents/apps.py
src/documents/signals/handlers.py
src/documents/tasks.py
src/paperless/ai/chat.py
src/paperless/ai/indexing.py
src/paperless/config.py
src/paperless/tests/test_ai_chat.py
src/paperless/tests/test_ai_indexing.py [new file with mode: 0644]
src/paperless/tests/test_ai_rag.py

index f3b798c0b5bd21689010418c229f4a19a545b037..32e49b160dd17fd406a802ee46fb9666e94baa78 100644 (file)
@@ -11,6 +11,7 @@ class DocumentsConfig(AppConfig):
         from documents.signals import document_consumption_finished
         from documents.signals import document_updated
         from documents.signals.handlers import add_inbox_tags
+        from documents.signals.handlers import add_or_update_document_in_llm_index
         from documents.signals.handlers import add_to_index
         from documents.signals.handlers import run_workflows_added
         from documents.signals.handlers import run_workflows_updated
@@ -26,6 +27,7 @@ class DocumentsConfig(AppConfig):
         document_consumption_finished.connect(set_storage_path)
         document_consumption_finished.connect(add_to_index)
         document_consumption_finished.connect(run_workflows_added)
+        document_consumption_finished.connect(add_or_update_document_in_llm_index)
         document_updated.connect(run_workflows_updated)
 
         import documents.schema  # noqa: F401
index dd0f34f9277ae3f0f443996a0dd2f22c96a734f0..c6fb6ee1ecb930e08d26ae1913b9a4cb55fb9211 100644 (file)
@@ -48,6 +48,7 @@ from documents.models import WorkflowTrigger
 from documents.permissions import get_objects_for_user_owner_aware
 from documents.permissions import set_permissions_for_object
 from documents.templating.workflows import parse_w_workflow_placeholders
+from paperless.config import AIConfig
 
 if TYPE_CHECKING:
     from pathlib import Path
@@ -1449,3 +1450,26 @@ def task_failure_handler(
             task_instance.save()
     except Exception:  # pragma: no cover
         logger.exception("Updating PaperlessTask failed")
+
+
+def add_or_update_document_in_llm_index(sender, document, **kwargs):
+    """
+    Add or update a document in the LLM index when it is created or updated.
+    """
+    ai_config = AIConfig()
+    if ai_config.llm_index_enabled():
+        from documents.tasks import update_document_in_llm_index
+
+        update_document_in_llm_index.delay(document)
+
+
+@receiver(models.signals.post_delete, sender=Document)
+def delete_document_from_llm_index(sender, instance: Document, **kwargs):
+    """
+    Delete a document from the LLM index when it is deleted.
+    """
+    ai_config = AIConfig()
+    if ai_config.llm_index_enabled():
+        from documents.tasks import remove_document_from_llm_index
+
+        remove_document_from_llm_index.delay(instance)
index c2427929a1616074d7ea191961b607001db64f44..3edde40cb65e841fda94df9132bd3507061b20b4 100644 (file)
@@ -6,8 +6,6 @@ import uuid
 from pathlib import Path
 from tempfile import TemporaryDirectory
 
-import faiss
-import llama_index.core.settings as llama_settings
 import tqdm
 from celery import Task
 from celery import shared_task
@@ -19,13 +17,6 @@ from django.db import transaction
 from django.db.models.signals import post_save
 from django.utils import timezone
 from filelock import FileLock
-from llama_index.core import Document as LlamaDocument
-from llama_index.core import StorageContext
-from llama_index.core import VectorStoreIndex
-from llama_index.core.node_parser import SimpleNodeParser
-from llama_index.core.storage.docstore import SimpleDocumentStore
-from llama_index.core.storage.index_store import SimpleIndexStore
-from llama_index.vector_stores.faiss import FaissVectorStore
 from whoosh.writing import AsyncWriter
 
 from documents import index
@@ -63,9 +54,10 @@ from documents.sanity_checker import SanityCheckFailedException
 from documents.signals import document_updated
 from documents.signals.handlers import cleanup_document_deletion
 from documents.signals.handlers import run_workflows
-from paperless.ai.embedding import build_llm_index_text
-from paperless.ai.embedding import get_embedding_dim
-from paperless.ai.embedding import get_embedding_model
+from paperless.ai.indexing import llm_index_add_or_update_document
+from paperless.ai.indexing import llm_index_remove_document
+from paperless.ai.indexing import rebuild_llm_index
+from paperless.config import AIConfig
 
 if settings.AUDIT_LOG_ENABLED:
     from auditlog.models import LogEntry
@@ -254,6 +246,11 @@ def bulk_update_documents(document_ids):
         for doc in documents:
             index.update_document(writer, doc)
 
+    ai_config = AIConfig()
+    if ai_config.llm_index_enabled():
+        for doc in documents:
+            llm_index_add_or_update_document()
+
 
 @shared_task
 def update_document_content_maybe_archive_file(document_id):
@@ -353,6 +350,10 @@ def update_document_content_maybe_archive_file(document_id):
         with index.open_index_writer() as writer:
             index.update_document(writer, document)
 
+        ai_config = AIConfig()
+        if ai_config.llm_index_enabled:
+            llm_index_add_or_update_document(document)
+
         clear_document_caches(document.pk)
 
     except Exception:
@@ -532,60 +533,25 @@ def check_scheduled_workflows():
 
 
 def llm_index_rebuild(*, progress_bar_disable=False, rebuild=False):
-    if rebuild:
-        shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
-        settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
+    rebuild_llm_index(
+        progress_bar_disable=progress_bar_disable,
+        rebuild=rebuild,
+    )
 
-    embed_model = get_embedding_model()
-    llama_settings.Settings.embed_model = embed_model
 
-    if rebuild or not settings.LLM_INDEX_DIR.exists():
-        embedding_dim = get_embedding_dim()
-        faiss_index = faiss.IndexFlatL2(embedding_dim)
-        vector_store = FaissVectorStore(faiss_index=faiss_index)
-    else:
-        vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
+@shared_task
+def update_document_in_llm_index(document):
+    llm_index_add_or_update_document(document)
 
-    docstore = SimpleDocumentStore()
-    index_store = SimpleIndexStore()
 
-    storage_context = StorageContext.from_defaults(
-        docstore=docstore,
-        index_store=index_store,
-        persist_dir=settings.LLM_INDEX_DIR,
-        vector_store=vector_store,
-    )
+@shared_task
+def remove_document_from_llm_index(document):
+    llm_index_remove_document(document)
 
-    parser = SimpleNodeParser()
-    nodes = []
-
-    for document in tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable):
-        if not document.content:
-            continue
-
-        text = build_llm_index_text(document)
-        metadata = {
-            "document_id": document.id,
-            "title": document.title,
-            "tags": [t.name for t in document.tags.all()],
-            "correspondent": document.correspondent.name
-            if document.correspondent
-            else None,
-            "document_type": document.document_type.name
-            if document.document_type
-            else None,
-            "created": document.created.isoformat() if document.created else None,
-            "added": document.added.isoformat() if document.added else None,
-        }
-
-        doc = LlamaDocument(text=text, metadata=metadata)
-        doc_nodes = parser.get_nodes_from_documents([doc])
-        nodes.extend(doc_nodes)
-
-    index = VectorStoreIndex(
-        nodes=nodes,
-        storage_context=storage_context,
-        embed_model=embed_model,
-    )
 
-    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+# TODO: schedule to run periodically
+@shared_task
+def rebuild_llm_index_task():
+    from paperless.ai.indexing import rebuild_llm_index
+
+    rebuild_llm_index(rebuild=True)
index 7141177d7d2e9890a25cd41a070749e7d35bbd71..45d44db8cdf486dc397325f443b00c73c5e2dd1c 100644 (file)
@@ -7,7 +7,7 @@ from llama_index.core.query_engine import RetrieverQueryEngine
 
 from documents.models import Document
 from paperless.ai.client import AIClient
-from paperless.ai.indexing import load_index
+from paperless.ai.indexing import load_or_build_index
 
 logger = logging.getLogger("paperless.ai.chat")
 
@@ -24,7 +24,7 @@ CHAT_PROMPT_TMPL = PromptTemplate(
 
 def stream_chat_with_documents(query_str: str, documents: list[Document]):
     client = AIClient()
-    index = load_index()
+    index = load_or_build_index()
 
     doc_ids = [doc.pk for doc in documents]
 
index 6d9a59e792be64cb2317072f11036fe2fce82213..4742ca0ab26ceb0e82a9720a25f216408e0675a4 100644 (file)
 import logging
+import shutil
 
+import faiss
 import llama_index.core.settings as llama_settings
+import tqdm
 from django.conf import settings
+from llama_index.core import Document as LlamaDocument
 from llama_index.core import StorageContext
 from llama_index.core import VectorStoreIndex
-from llama_index.core import load_index_from_storage
+from llama_index.core.node_parser import SimpleNodeParser
 from llama_index.core.retrievers import VectorIndexRetriever
+from llama_index.core.schema import BaseNode
+from llama_index.core.storage.docstore import SimpleDocumentStore
+from llama_index.core.storage.index_store import SimpleIndexStore
 from llama_index.vector_stores.faiss import FaissVectorStore
 
 from documents.models import Document
+from paperless.ai.embedding import build_llm_index_text
+from paperless.ai.embedding import get_embedding_dim
 from paperless.ai.embedding import get_embedding_model
 
 logger = logging.getLogger("paperless.ai.indexing")
 
 
-def load_index() -> VectorStoreIndex:
-    """Loads the persisted LlamaIndex from disk."""
-    vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
-    embed_model = get_embedding_model()
+def get_or_create_storage_context(*, rebuild=False):
+    """
+    Loads or creates the StorageContext (vector store, docstore, index store).
+    If rebuild=True, deletes and recreates everything.
+    """
+    if rebuild:
+        shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
+        settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
 
-    llama_settings.Settings.embed_model = embed_model
-    llama_settings.Settings.chunk_size = 512
+    if rebuild or not settings.LLM_INDEX_DIR.exists():
+        embedding_dim = get_embedding_dim()
+        faiss_index = faiss.IndexFlatL2(embedding_dim)
+        vector_store = FaissVectorStore(faiss_index=faiss_index)
+        docstore = SimpleDocumentStore()
+        index_store = SimpleIndexStore()
+    else:
+        vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
+        docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
+        index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
 
-    storage_context = StorageContext.from_defaults(
+    return StorageContext.from_defaults(
+        docstore=docstore,
+        index_store=index_store,
         vector_store=vector_store,
         persist_dir=settings.LLM_INDEX_DIR,
     )
-    return load_index_from_storage(storage_context)
+
+
+def get_vector_store_index(storage_context, embed_model):
+    """
+    Returns a VectorStoreIndex given a storage context and embed model.
+    """
+    return VectorStoreIndex(
+        storage_context=storage_context,
+        embed_model=embed_model,
+    )
+
+
+def build_document_node(document) -> list[BaseNode]:
+    """
+    Given a Document, returns parsed Nodes ready for indexing.
+    """
+    if not document.content:
+        return []
+
+    text = build_llm_index_text(document)
+    metadata = {
+        "document_id": document.id,
+        "title": document.title,
+        "tags": [t.name for t in document.tags.all()],
+        "correspondent": document.correspondent.name
+        if document.correspondent
+        else None,
+        "document_type": document.document_type.name
+        if document.document_type
+        else None,
+        "created": document.created.isoformat() if document.created else None,
+        "added": document.added.isoformat() if document.added else None,
+    }
+    doc = LlamaDocument(text=text, metadata=metadata)
+    parser = SimpleNodeParser()
+    return parser.get_nodes_from_documents([doc])
+
+
+def load_or_build_index(storage_context, embed_model, nodes=None):
+    """
+    Load an existing VectorStoreIndex if present,
+    or build a new one using provided nodes if storage is empty.
+    """
+    try:
+        return VectorStoreIndex(
+            storage_context=storage_context,
+            embed_model=embed_model,
+        )
+    except ValueError as e:
+        if "One of nodes, objects, or index_struct must be provided" in str(e):
+            if not nodes:
+                return None
+            return VectorStoreIndex(
+                nodes=nodes,
+                storage_context=storage_context,
+                embed_model=embed_model,
+            )
+        raise
+
+
+def remove_existing_document_nodes(document, index):
+    """
+    Removes existing documents from docstore for a given document from the index.
+    This is necessary because FAISS IndexFlatL2 is append-only.
+    """
+    all_node_ids = list(index.docstore.docs.keys())
+    existing_nodes = [
+        node.node_id
+        for node in index.docstore.get_nodes(all_node_ids)
+        if node.metadata.get("document_id") == document.id
+    ]
+    for node_id in existing_nodes:
+        # Delete from docstore, FAISS IndexFlatL2 are append-only
+        index.docstore.delete_document(node_id)
+
+
+def rebuild_llm_index(*, progress_bar_disable=False, rebuild=False):
+    """
+    Rebuilds the LLM index from scratch.
+    """
+    embed_model = get_embedding_model()
+    llama_settings.Settings.embed_model = embed_model
+
+    storage_context = get_or_create_storage_context(rebuild=rebuild)
+
+    nodes = []
+
+    for document in tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable):
+        document_nodes = build_document_node(document)
+        nodes.extend(document_nodes)
+
+    if not nodes:
+        raise RuntimeError(
+            "No nodes to index — check that documents are available and have content.",
+        )
+
+    VectorStoreIndex(
+        nodes=nodes,
+        storage_context=storage_context,
+        embed_model=embed_model,
+    )
+    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+
+
+def llm_index_add_or_update_document(document):
+    """
+    Adds or updates a document in the LLM index.
+    If the document already exists, it will be replaced.
+    """
+    embed_model = get_embedding_model()
+    llama_settings.Settings.embed_model = embed_model
+
+    storage_context = get_or_create_storage_context(rebuild=False)
+
+    new_nodes = build_document_node(document)
+
+    index = load_or_build_index(storage_context, embed_model, nodes=new_nodes)
+
+    if index is None:
+        # Nothing to index
+        return
+
+    # Remove old nodes
+    remove_existing_document_nodes(document, index)
+
+    index.insert_nodes(new_nodes)
+
+    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+
+
+def llm_index_remove_document(document):
+    embed_model = get_embedding_model()
+    llama_settings.embed_model = embed_model
+
+    storage_context = get_or_create_storage_context(rebuild=False)
+
+    index = load_or_build_index(storage_context, embed_model)
+    if index is None:
+        return  # Nothing to remove
+
+    # Remove old nodes
+    remove_existing_document_nodes(document, index)
+
+    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
 
 
 def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
-    """Runs a similarity query and returns top-k similar Document objects."""
-    # Load the index
-    index = load_index()
+    """
+    Runs a similarity query and returns top-k similar Document objects.
+    """
+    index = load_or_build_index()
     retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
 
     # Build query from the document text
     query_text = (document.title or "") + "\n" + (document.content or "")
-
-    # Query
     results = retriever.retrieve(query_text)
 
     # Each result.node.metadata["document_id"] should match our stored doc
index 2c38b40561d141218664176aff1c788f76df043c..fc4fe23cf5efb6868b4a44d11ebc1a39008dce3d 100644 (file)
@@ -199,3 +199,10 @@ class AIConfig(BaseConfig):
         self.llm_model = app_config.llm_model or settings.LLM_MODEL
         self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
         self.llm_url = app_config.llm_url or settings.LLM_URL
+
+    def llm_index_enabled(self) -> bool:
+        return (
+            self.ai_enabled
+            and self.llm_embedding_backend
+            and self.llm_embedding_backend
+        )
index 2b792c4c803263efe84ca40ef191aa862a877c67..e70de458ff5d18893e9fd6bfb658ae3fb8f644eb 100644 (file)
@@ -45,7 +45,7 @@ def mock_document():
 def test_stream_chat_with_one_document_full_content(mock_document):
     with (
         patch("paperless.ai.chat.AIClient") as mock_client_cls,
-        patch("paperless.ai.chat.load_index") as mock_load_index,
+        patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
         patch(
             "paperless.ai.chat.RetrieverQueryEngine.from_args",
         ) as mock_query_engine_cls,
@@ -76,7 +76,7 @@ def test_stream_chat_with_one_document_full_content(mock_document):
 def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
     with (
         patch("paperless.ai.chat.AIClient") as mock_client_cls,
-        patch("paperless.ai.chat.load_index") as mock_load_index,
+        patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
         patch(
             "paperless.ai.chat.RetrieverQueryEngine.from_args",
         ) as mock_query_engine_cls,
@@ -126,7 +126,7 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
 def test_stream_chat_no_matching_nodes():
     with (
         patch("paperless.ai.chat.AIClient") as mock_client_cls,
-        patch("paperless.ai.chat.load_index") as mock_load_index,
+        patch("paperless.ai.chat.load_or_build_index") as mock_load_index,
     ):
         mock_client = MagicMock()
         mock_client_cls.return_value = mock_client
diff --git a/src/paperless/tests/test_ai_indexing.py b/src/paperless/tests/test_ai_indexing.py
new file mode 100644 (file)
index 0000000..24cdeda
--- /dev/null
@@ -0,0 +1,144 @@
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+from django.utils import timezone
+from llama_index.core.base.embeddings.base import BaseEmbedding
+
+from documents.models import Document
+from paperless.ai import indexing
+
+
+@pytest.fixture
+def temp_llm_index_dir(tmp_path):
+    original_dir = indexing.settings.LLM_INDEX_DIR
+    indexing.settings.LLM_INDEX_DIR = tmp_path
+    yield tmp_path
+    indexing.settings.LLM_INDEX_DIR = original_dir
+
+
+@pytest.fixture
+def real_document(db):
+    return Document.objects.create(
+        title="Test Document",
+        content="This is some test content.",
+        added=timezone.now(),
+    )
+
+
+@pytest.fixture
+def mock_embed_model():
+    """Mocks the embedding model."""
+    with patch("paperless.ai.indexing.get_embedding_model") as mock:
+        mock.return_value = FakeEmbedding()
+        yield mock
+
+
+class FakeEmbedding(BaseEmbedding):
+    # TODO: maybe a better way to do this?
+    def _aget_query_embedding(self, query: str) -> list[float]:
+        return [0.1] * self.get_query_embedding_dim()
+
+    def _get_query_embedding(self, query: str) -> list[float]:
+        return [0.1] * self.get_query_embedding_dim()
+
+    def _get_text_embedding(self, text: str) -> list[float]:
+        return [0.1] * self.get_query_embedding_dim()
+
+    def get_query_embedding_dim(self) -> int:
+        return 384  # Match your real FAISS config
+
+
+@pytest.mark.django_db
+def test_build_document_node(real_document):
+    nodes = indexing.build_document_node(real_document)
+    assert len(nodes) > 0
+    assert nodes[0].metadata["document_id"] == real_document.id
+
+
+@pytest.mark.django_db
+def test_rebuild_llm_index(
+    temp_llm_index_dir,
+    real_document,
+    mock_embed_model,
+):
+    with patch("documents.models.Document.objects.all") as mock_all:
+        mock_all.return_value = [real_document]
+        indexing.rebuild_llm_index(rebuild=True)
+
+        assert any(temp_llm_index_dir.glob("*.json"))
+
+
+@pytest.mark.django_db
+def test_add_or_update_document_updates_existing_entry(
+    temp_llm_index_dir,
+    real_document,
+    mock_embed_model,
+):
+    indexing.rebuild_llm_index(rebuild=True)
+    indexing.llm_index_add_or_update_document(real_document)
+
+    assert any(temp_llm_index_dir.glob("*.json"))
+
+
+@pytest.mark.django_db
+def test_remove_document_deletes_node_from_docstore(
+    temp_llm_index_dir,
+    real_document,
+    mock_embed_model,
+):
+    indexing.rebuild_llm_index(rebuild=True)
+    indexing.llm_index_add_or_update_document(real_document)
+    indexing.llm_index_remove_document(real_document)
+
+    assert any(temp_llm_index_dir.glob("*.json"))
+
+
+@pytest.mark.django_db
+def test_rebuild_llm_index_no_documents(
+    temp_llm_index_dir,
+    mock_embed_model,
+):
+    with patch("documents.models.Document.objects.all") as mock_all:
+        mock_all.return_value = []
+
+        with pytest.raises(RuntimeError, match="No nodes to index"):
+            indexing.rebuild_llm_index(rebuild=True)
+
+
+def test_query_similar_documents(
+    temp_llm_index_dir,
+    real_document,
+):
+    with (
+        patch("paperless.ai.indexing.load_or_build_index") as mock_load_or_build_index,
+        patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
+        patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
+    ):
+        mock_index = MagicMock()
+        mock_load_or_build_index.return_value = mock_index
+
+        mock_retriever = MagicMock()
+        mock_retriever_cls.return_value = mock_retriever
+
+        mock_node1 = MagicMock()
+        mock_node1.metadata = {"document_id": 1}
+
+        mock_node2 = MagicMock()
+        mock_node2.metadata = {"document_id": 2}
+
+        mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
+
+        mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
+        mock_filter.return_value = mock_filtered_docs
+
+        result = indexing.query_similar_documents(real_document, top_k=3)
+
+        mock_load_or_build_index.assert_called_once()
+        mock_retriever_cls.assert_called_once_with(index=mock_index, similarity_top_k=3)
+        mock_retriever.retrieve.assert_called_once_with(
+            "Test Document\nThis is some test content.",
+        )
+        mock_filter.assert_called_once_with(pk__in=[1, 2])
+
+        assert result == mock_filtered_docs
index 0cd2f9c9c596ceb311eb15d3f7c6bd92bedd881d..3d0a76f97238dccf672174298a4fe882aceb5f72 100644 (file)
@@ -2,14 +2,11 @@ from unittest.mock import MagicMock
 from unittest.mock import patch
 
 import pytest
-from llama_index.core.base.embeddings.base import BaseEmbedding
 
 from documents.models import Document
 from paperless.ai.embedding import build_llm_index_text
 from paperless.ai.embedding import get_embedding_dim
 from paperless.ai.embedding import get_embedding_model
-from paperless.ai.indexing import load_index
-from paperless.ai.indexing import query_similar_documents
 from paperless.ai.rag import get_context_for_document
 from paperless.models import LLMEmbeddingBackend
 
@@ -182,93 +179,3 @@ def test_build_llm_index_text(mock_document):
         assert "Notes: Note1,Note2" in result
         assert "Content:\n\nThis is the document content." in result
         assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result
-
-
-# Indexing
-
-
-@pytest.fixture
-def mock_settings(settings):
-    settings.LLM_INDEX_DIR = "/fake/path"
-    return settings
-
-
-class FakeEmbedding(BaseEmbedding):
-    # TODO: gotta be a better way to do this
-    def _aget_query_embedding(self, query: str) -> list[float]:
-        return [0.1, 0.2, 0.3]
-
-    def _get_query_embedding(self, query: str) -> list[float]:
-        return [0.1, 0.2, 0.3]
-
-    def _get_text_embedding(self, text: str) -> list[float]:
-        return [0.1, 0.2, 0.3]
-
-
-def test_load_index(mock_settings):
-    with (
-        patch("paperless.ai.indexing.FaissVectorStore.from_persist_dir") as mock_faiss,
-        patch("paperless.ai.indexing.get_embedding_model") as mock_get_embed_model,
-        patch(
-            "paperless.ai.indexing.StorageContext.from_defaults",
-        ) as mock_storage_context,
-        patch("paperless.ai.indexing.load_index_from_storage") as mock_load_index,
-    ):
-        # Setup mocks
-        mock_vector_store = MagicMock()
-        mock_storage = MagicMock()
-        mock_index = MagicMock()
-
-        mock_faiss.return_value = mock_vector_store
-        mock_storage_context.return_value = mock_storage
-        mock_load_index.return_value = mock_index
-        mock_get_embed_model.return_value = FakeEmbedding()
-
-        # Act
-        result = load_index()
-
-        # Assert
-        mock_faiss.assert_called_once_with("/fake/path")
-        mock_get_embed_model.assert_called_once()
-        mock_storage_context.assert_called_once_with(
-            vector_store=mock_vector_store,
-            persist_dir="/fake/path",
-        )
-        mock_load_index.assert_called_once_with(mock_storage)
-        assert result == mock_index
-
-
-def test_query_similar_documents(mock_document):
-    with (
-        patch("paperless.ai.indexing.load_index") as mock_load_index_func,
-        patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
-        patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
-    ):
-        # Setup mocks
-        mock_index = MagicMock()
-        mock_load_index_func.return_value = mock_index
-
-        mock_retriever = MagicMock()
-        mock_retriever_cls.return_value = mock_retriever
-
-        mock_node1 = MagicMock()
-        mock_node1.metadata = {"document_id": 1}
-
-        mock_node2 = MagicMock()
-        mock_node2.metadata = {"document_id": 2}
-
-        mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
-
-        mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
-        mock_filter.return_value = mock_filtered_docs
-
-        result = query_similar_documents(mock_document, top_k=3)
-
-        mock_load_index_func.assert_called_once()
-        mock_retriever_cls.assert_called_once_with(index=mock_index, similarity_top_k=3)
-        mock_retriever.retrieve.assert_called_once_with(
-            "Test Title\nThis is the document content.",
-        )
-        mock_filter.assert_called_once_with(pk__in=[1, 2])
-
-        assert result == mock_filtered_docs