]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Refactor and consolidate rag / embedding and tests
authorshamoon <4887959+shamoon@users.noreply.github.com>
Tue, 29 Apr 2025 00:36:23 +0000 (17:36 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Wed, 2 Jul 2025 18:04:01 +0000 (11:04 -0700)
src/paperless/ai/ai_classifier.py
src/paperless/ai/embedding.py
src/paperless/ai/indexing.py
src/paperless/ai/rag.py [deleted file]
src/paperless/tests/test_ai_chat.py
src/paperless/tests/test_ai_classifier.py
src/paperless/tests/test_ai_embedding.py [moved from src/paperless/tests/test_ai_rag.py with 78% similarity]
src/paperless/tests/test_ai_indexing.py

index ab349c81beb159420b1804212d0f80b6cce0339b..6b34c089994398ea58ee11290c76e7ad0d366155 100644 (file)
@@ -5,7 +5,7 @@ from llama_index.core.base.llms.types import CompletionResponse
 
 from documents.models import Document
 from paperless.ai.client import AIClient
-from paperless.ai.rag import get_context_for_document
+from paperless.ai.indexing import query_similar_documents
 from paperless.config import AIConfig
 
 logger = logging.getLogger("paperless.ai.rag_classifier")
@@ -65,6 +65,16 @@ def build_prompt_with_rag(document: Document) -> str:
     return prompt
 
 
+def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
+    similar_docs = query_similar_documents(doc)[:max_docs]
+    context_blocks = []
+    for similar in similar_docs:
+        text = similar.content or ""
+        title = similar.title or similar.filename or "Untitled"
+        context_blocks.append(f"TITLE: {title}\n{text}")
+    return "\n\n".join(context_blocks)
+
+
 def parse_ai_response(response: CompletionResponse) -> dict:
     try:
         raw = json.loads(response.text)
index 9d6a5faef65f8a327de24c9fac65b7e439942fab..e151a58867c3ad00161dc32d6f2f495efa21971d 100644 (file)
@@ -1,3 +1,4 @@
+from llama_index.core.base.embeddings.base import BaseEmbedding
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from llama_index.embeddings.openai import OpenAIEmbedding
 
@@ -12,7 +13,7 @@ EMBEDDING_DIMENSIONS = {
 }
 
 
-def get_embedding_model():
+def get_embedding_model() -> BaseEmbedding:
     config = AIConfig()
 
     match config.llm_embedding_backend:
index 95442e55b9c8a2a5ea6f0cca5577d816b6a4a6b2..bc275c83feae48587a4b7e467dbe1e80c69928b5 100644 (file)
@@ -223,7 +223,10 @@ def query_similar_documents(document: Document, top_k: int = 5) -> list[Document
     """
     Runs a similarity query and returns top-k similar Document objects.
     """
-    index = load_or_build_index()
+    storage_context = get_or_create_storage_context(rebuild=False)
+    embed_model = get_embedding_model()
+    llama_settings.embed_model = embed_model
+    index = load_or_build_index(storage_context, embed_model)
     retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
 
     query_text = (document.title or "") + "\n" + (document.content or "")
diff --git a/src/paperless/ai/rag.py b/src/paperless/ai/rag.py
deleted file mode 100644 (file)
index 9b5baf4..0000000
+++ /dev/null
@@ -1,12 +0,0 @@
-from documents.models import Document
-from paperless.ai.indexing import query_similar_documents
-
-
-def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
-    similar_docs = query_similar_documents(doc)[:max_docs]
-    context_blocks = []
-    for similar in similar_docs:
-        text = similar.content or ""
-        title = similar.title or similar.filename or "Untitled"
-        context_blocks.append(f"TITLE: {title}\n{text}")
-    return "\n\n".join(context_blocks)
index e70de458ff5d18893e9fd6bfb658ae3fb8f644eb..d7f05b8d3e46f6d00f2634139b81ab1646da6560 100644 (file)
@@ -56,7 +56,7 @@ def test_stream_chat_with_one_document_full_content(mock_document):
 
         mock_node = TextNode(
             text="This is node content.",
-            metadata={"document_id": mock_document.pk, "title": "Test Document"},
+            metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
         )
         mock_index = MagicMock()
         mock_index.docstore.docs.values.return_value = [mock_node]
@@ -90,11 +90,11 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
         # Create two real TextNodes
         mock_node1 = TextNode(
             text="Content for doc 1.",
-            metadata={"document_id": 1, "title": "Document 1"},
+            metadata={"document_id": "1", "title": "Document 1"},
         )
         mock_node2 = TextNode(
             text="Content for doc 2.",
-            metadata={"document_id": 2, "title": "Document 2"},
+            metadata={"document_id": "2", "title": "Document 2"},
         )
         mock_index = MagicMock()
         mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
index a29f1a07edd189f9a0ab7c3fad0236cfda42fb1c..ef749cda9a504017b634b22ed17db4c54ea051f9 100644 (file)
@@ -9,12 +9,43 @@ from documents.models import Document
 from paperless.ai.ai_classifier import build_prompt_with_rag
 from paperless.ai.ai_classifier import build_prompt_without_rag
 from paperless.ai.ai_classifier import get_ai_document_classification
+from paperless.ai.ai_classifier import get_context_for_document
 from paperless.ai.ai_classifier import parse_ai_response
 
 
 @pytest.fixture
 def mock_document():
-    return Document(filename="test.pdf", content="This is a test document content.")
+    doc = MagicMock(spec=Document)
+    doc.title = "Test Title"
+    doc.filename = "test_file.pdf"
+    doc.created = "2023-01-01"
+    doc.added = "2023-01-02"
+    doc.modified = "2023-01-03"
+
+    tag1 = MagicMock()
+    tag1.name = "Tag1"
+    tag2 = MagicMock()
+    tag2.name = "Tag2"
+    doc.tags.all = MagicMock(return_value=[tag1, tag2])
+
+    doc.document_type = MagicMock()
+    doc.document_type.name = "Invoice"
+    doc.correspondent = MagicMock()
+    doc.correspondent.name = "Test Correspondent"
+    doc.archive_serial_number = "12345"
+    doc.content = "This is the document content."
+
+    cf1 = MagicMock(__str__=lambda x: "Value1")
+    cf1.field = MagicMock()
+    cf1.field.name = "Field1"
+    cf1.value = "Value1"
+    cf2 = MagicMock(__str__=lambda x: "Value2")
+    cf2.field = MagicMock()
+    cf2.field.name = "Field2"
+    cf2.value = "Value2"
+    doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])
+
+    return doc
 
 
 @pytest.mark.django_db
@@ -105,13 +136,63 @@ def test_use_without_rag_if_not_configured(
     mock_build_prompt_without_rag.assert_called_once()
 
 
+@pytest.mark.django_db
 @override_settings(
+    LLM_EMBEDDING_BACKEND="huggingface",
     LLM_BACKEND="ollama",
     LLM_MODEL="some_model",
 )
 def test_prompt_with_without_rag(mock_document):
-    prompt = build_prompt_without_rag(mock_document)
-    assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
+    with patch(
+        "paperless.ai.ai_classifier.get_context_for_document",
+        return_value="Context from similar documents",
+    ):
+        prompt = build_prompt_without_rag(mock_document)
+        assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
+
+        prompt = build_prompt_with_rag(mock_document)
+        assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
+
+
+@pytest.fixture
+def mock_similar_documents():
+    doc1 = MagicMock()
+    doc1.content = "Content of document 1"
+    doc1.title = "Title 1"
+    doc1.filename = "file1.txt"
+
+    doc2 = MagicMock()
+    doc2.content = "Content of document 2"
+    doc2.title = None
+    doc2.filename = "file2.txt"
+
+    doc3 = MagicMock()
+    doc3.content = None
+    doc3.title = None
+    doc3.filename = None
+
+    return [doc1, doc2, doc3]
+
+
+@patch("paperless.ai.ai_classifier.query_similar_documents")
+def test_get_context_for_document(
+    mock_query_similar_documents,
+    mock_document,
+    mock_similar_documents,
+):
+    mock_query_similar_documents.return_value = mock_similar_documents
+
+    result = get_context_for_document(mock_document, max_docs=2)
+
+    expected_result = (
+        "TITLE: Title 1\nContent of document 1\n\n"
+        "TITLE: file2.txt\nContent of document 2"
+    )
+    assert result == expected_result
+    mock_query_similar_documents.assert_called_once()
+
 
-    prompt = build_prompt_with_rag(mock_document)
-    assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
+def test_get_context_for_document_no_similar_docs(mock_document):
+    with patch("paperless.ai.ai_classifier.query_similar_documents", return_value=[]):
+        result = get_context_for_document(mock_document)
+        assert result == ""
similarity index 78%
rename from src/paperless/tests/test_ai_rag.py
rename to src/paperless/tests/test_ai_embedding.py
index 3d0a76f97238dccf672174298a4fe882aceb5f72..75e2d791e4c7f9292fd805c44e88ca7b7777314e 100644 (file)
@@ -7,10 +7,15 @@ from documents.models import Document
 from paperless.ai.embedding import build_llm_index_text
 from paperless.ai.embedding import get_embedding_dim
 from paperless.ai.embedding import get_embedding_model
-from paperless.ai.rag import get_context_for_document
 from paperless.models import LLMEmbeddingBackend
 
 
+@pytest.fixture
+def mock_ai_config():
+    with patch("paperless.ai.embedding.AIConfig") as MockAIConfig:
+        yield MockAIConfig
+
+
 @pytest.fixture
 def mock_document():
     doc = MagicMock(spec=Document)
@@ -46,59 +51,6 @@ def mock_document():
     return doc
 
 
-@pytest.fixture
-def mock_similar_documents():
-    doc1 = MagicMock()
-    doc1.content = "Content of document 1"
-    doc1.title = "Title 1"
-    doc1.filename = "file1.txt"
-
-    doc2 = MagicMock()
-    doc2.content = "Content of document 2"
-    doc2.title = None
-    doc2.filename = "file2.txt"
-
-    doc3 = MagicMock()
-    doc3.content = None
-    doc3.title = None
-    doc3.filename = None
-
-    return [doc1, doc2, doc3]
-
-
-@patch("paperless.ai.rag.query_similar_documents")
-def test_get_context_for_document(
-    mock_query_similar_documents,
-    mock_document,
-    mock_similar_documents,
-):
-    mock_query_similar_documents.return_value = mock_similar_documents
-
-    result = get_context_for_document(mock_document, max_docs=2)
-
-    expected_result = (
-        "TITLE: Title 1\nContent of document 1\n\n"
-        "TITLE: file2.txt\nContent of document 2"
-    )
-    assert result == expected_result
-    mock_query_similar_documents.assert_called_once()
-
-
-def test_get_context_for_document_no_similar_docs(mock_document):
-    with patch("paperless.ai.rag.query_similar_documents", return_value=[]):
-        result = get_context_for_document(mock_document)
-        assert result == ""
-
-
-# Embedding
-
-
-@pytest.fixture
-def mock_ai_config():
-    with patch("paperless.ai.embedding.AIConfig") as MockAIConfig:
-        yield MockAIConfig
-
-
 def test_get_embedding_model_openai(mock_ai_config):
     mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI
     mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
index 73df742b1e075ff8db59d2f4e9f7c9ac11a19bed..c0279171a049753df5f2bb961ae73224b8b6d392 100644 (file)
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock
 from unittest.mock import patch
 
 import pytest
+from django.test import override_settings
 from django.utils import timezone
 from llama_index.core.base.embeddings.base import BaseEmbedding
 
@@ -162,15 +163,23 @@ def test_update_llm_index_no_documents(
             )
 
 
+@override_settings(
+    LLM_EMBEDDING_BACKEND="huggingface",
+    LLM_BACKEND="ollama",
+)
 def test_query_similar_documents(
     temp_llm_index_dir,
     real_document,
 ):
     with (
+        patch("paperless.ai.indexing.get_or_create_storage_context") as mock_storage,
         patch("paperless.ai.indexing.load_or_build_index") as mock_load_or_build_index,
         patch("paperless.ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
         patch("paperless.ai.indexing.Document.objects.filter") as mock_filter,
     ):
+        mock_storage.return_value = MagicMock()
+        mock_storage.return_value.persist_dir = temp_llm_index_dir
+
         mock_index = MagicMock()
         mock_load_or_build_index.return_value = mock_index