git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Refactor load_or_build_index
author: shamoon <4887959+shamoon@users.noreply.github.com>
Tue, 29 Apr 2025 04:39:39 +0000 (21:39 -0700)
committer: shamoon <4887959+shamoon@users.noreply.github.com>
Wed, 2 Jul 2025 18:04:02 +0000 (11:04 -0700)
src/paperless/ai/indexing.py
src/paperless/tests/test_ai_indexing.py

index 840d58f3709c684da4023bb3b467dcf270304021..9a32409cac1207cc48c4336daa3a0a22a4487ac8 100644 (file)
@@ -76,11 +76,14 @@ def build_document_node(document: Document) -> list[BaseNode]:
     return parser.get_nodes_from_documents([doc])
 
 
-def load_or_build_index(storage_context: StorageContext, embed_model, nodes=None):
+def load_or_build_index(nodes=None):
     """
     Load an existing VectorStoreIndex if present,
     or build a new one using provided nodes if storage is empty.
     """
+    embed_model = get_embedding_model()
+    llama_settings.Settings.embed_model = embed_model
+    storage_context = get_or_create_storage_context()
     try:
         return load_index_from_storage(storage_context=storage_context)
     except ValueError as e:
@@ -115,10 +118,6 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
     """
     Rebuild or update the LLM index.
     """
-    embed_model = get_embedding_model()
-    llama_settings.Settings.embed_model = embed_model
-    storage_context = get_or_create_storage_context(rebuild=rebuild)
-
     nodes = []
 
     documents = Document.objects.all()
@@ -127,12 +126,15 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
         return
 
     if rebuild:
+        embed_model = get_embedding_model()
+        llama_settings.Settings.embed_model = embed_model
+        storage_context = get_or_create_storage_context(rebuild=rebuild)
         # Rebuild index from scratch
         for document in tqdm.tqdm(documents, disable=progress_bar_disable):
             document_nodes = build_document_node(document)
             nodes.extend(document_nodes)
 
-        VectorStoreIndex(
+        index = VectorStoreIndex(
             nodes=nodes,
             storage_context=storage_context,
             embed_model=embed_model,
@@ -140,7 +142,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
         )
     else:
         # Update existing index
-        index = load_or_build_index(storage_context, embed_model)
+        index = load_or_build_index()
         all_node_ids = list(index.docstore.docs.keys())
         existing_nodes = {
             node.metadata.get("document_id"): node
@@ -174,7 +176,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False):
         else:
             logger.info("No changes detected, skipping llm index rebuild.")
 
-    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
 
 
 def llm_index_add_or_update_document(document: Document):
@@ -182,46 +184,33 @@ def llm_index_add_or_update_document(document: Document):
     Adds or updates a document in the LLM index.
     If the document already exists, it will be replaced.
     """
-    embed_model = get_embedding_model()
-    llama_settings.Settings.embed_model = embed_model
-
-    storage_context = get_or_create_storage_context(rebuild=False)
-
     new_nodes = build_document_node(document)
 
-    index = load_or_build_index(storage_context, embed_model, nodes=new_nodes)
+    index = load_or_build_index(nodes=new_nodes)
 
     remove_document_docstore_nodes(document, index)
 
     index.insert_nodes(new_nodes)
 
-    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
 
 
 def llm_index_remove_document(document: Document):
     """
     Removes a document from the LLM index.
     """
-    embed_model = get_embedding_model()
-    llama_settings.embed_model = embed_model
-
-    storage_context = get_or_create_storage_context(rebuild=False)
-
-    index = load_or_build_index(storage_context, embed_model)
+    index = load_or_build_index()
 
     remove_document_docstore_nodes(document, index)
 
-    storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
+    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
 
 
 def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
     """
     Runs a similarity query and returns top-k similar Document objects.
     """
-    storage_context = get_or_create_storage_context(rebuild=False)
-    embed_model = get_embedding_model()
-    llama_settings.embed_model = embed_model
-    index = load_or_build_index(storage_context, embed_model)
+    index = load_or_build_index()
     retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
 
     query_text = (document.title or "") + "\n" + (document.content or "")
index d7b83316db882758aa3d33448b399e9d5f931ebe..101fdfb9ef3ba4ce2a182551ef43522daa6f8370 100644 (file)
@@ -131,12 +131,13 @@ def test_get_or_create_storage_context_raises_exception(
         indexing.get_or_create_storage_context(rebuild=False)
 
 
+@override_settings(
+    LLM_EMBEDDING_BACKEND="huggingface",
+)
 def test_load_or_build_index_builds_when_nodes_given(
     temp_llm_index_dir,
-    mock_embed_model,
     real_document,
 ):
-    storage_context = MagicMock()
     with patch(
         "paperless.ai.indexing.load_index_from_storage",
         side_effect=ValueError("Index not found"),
@@ -145,25 +146,26 @@ def test_load_or_build_index_builds_when_nodes_given(
             "paperless.ai.indexing.VectorStoreIndex",
             return_value=MagicMock(),
         ) as mock_index_cls:
-            indexing.load_or_build_index(
-                storage_context,
-                mock_embed_model,
-                nodes=[indexing.build_document_node(real_document)],
-            )
-            mock_index_cls.assert_called_once()
+            with patch(
+                "paperless.ai.indexing.get_or_create_storage_context",
+                return_value=MagicMock(),
+            ) as mock_storage:
+                mock_storage.return_value.persist_dir = temp_llm_index_dir
+                indexing.load_or_build_index(
+                    nodes=[indexing.build_document_node(real_document)],
+                )
+                mock_index_cls.assert_called_once()
 
 
 def test_load_or_build_index_raises_exception_when_no_nodes(
     temp_llm_index_dir,
-    mock_embed_model,
 ):
-    storage_context = MagicMock()
     with patch(
         "paperless.ai.indexing.load_index_from_storage",
         side_effect=ValueError("Index not found"),
     ):
         with pytest.raises(Exception):
-            indexing.load_or_build_index(storage_context, mock_embed_model)
+            indexing.load_or_build_index()
 
 
 @pytest.mark.django_db
@@ -185,13 +187,11 @@ def test_remove_document_deletes_node_from_docstore(
     mock_embed_model,
 ):
     indexing.update_llm_index(rebuild=True)
-    storage_context = indexing.get_or_create_storage_context()
-    index = indexing.load_or_build_index(storage_context, mock_embed_model)
+    index = indexing.load_or_build_index()
     assert len(index.docstore.docs) == 1
 
     indexing.llm_index_remove_document(real_document)
-    storage_context = indexing.get_or_create_storage_context()
-    index = indexing.load_or_build_index(storage_context, mock_embed_model)
+    index = indexing.load_or_build_index()
     assert len(index.docstore.docs) == 0