Chat coverage
author     shamoon <4887959+shamoon@users.noreply.github.com>
           Sat, 26 Apr 2025 08:18:37 +0000 (01:18 -0700)
committer  shamoon <4887959+shamoon@users.noreply.github.com>
           Wed, 2 Jul 2025 18:03:58 +0000 (11:03 -0700)
src/paperless/ai/chat.py
src/paperless/tests/test_ai_chat.py [new file with mode: 0644]

diff --git a/src/paperless/ai/chat.py b/src/paperless/ai/chat.py
index ad14bda4db10f69db95a279c97ddc07852672b14..7141177d7d2e9890a25cd41a070749e7d35bbd71 100644
--- a/src/paperless/ai/chat.py
+++ b/src/paperless/ai/chat.py
@@ -37,7 +37,8 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
 
     if len(nodes) == 0:
         logger.warning("No nodes found for the given documents.")
-        return "Sorry, I couldn't find any content to answer your question."
+        yield "Sorry, I couldn't find any content to answer your question."
+        return
 
     local_index = VectorStoreIndex(nodes=nodes)
     retriever = local_index.as_retriever(
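
The change above matters because stream_chat_with_documents is a generator: a bare "return value" inside a generator function does not deliver the value to the caller, it only attaches it to StopIteration, so the fallback message was silently dropped and callers saw an empty stream. A minimal standalone sketch of the difference (plain Python, independent of paperless-ngx):

    def broken():
        # return in a generator raises StopIteration; the string is never yielded
        return "sorry"
        yield  # unreachable, but its presence makes this function a generator

    def fixed():
        yield "sorry"  # the caller's iteration receives the message
        return         # then the generator finishes cleanly

    assert list(broken()) == []
    assert list(fixed()) == ["sorry"]

test_stream_chat_no_matching_nodes in the new test file below locks in the fixed behavior.
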
diff --git a/src/paperless/tests/test_ai_chat.py b/src/paperless/tests/test_ai_chat.py
new file mode 100644
index 0000000..2b792c4
--- /dev/null
+++ b/src/paperless/tests/test_ai_chat.py
@@ -0,0 +1,142 @@
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+from llama_index.core import VectorStoreIndex
+from llama_index.core.schema import TextNode
+
+from paperless.ai.chat import stream_chat_with_documents
+
+
+@pytest.fixture(autouse=True)
+def patch_embed_model():
+    from llama_index.core import settings as llama_settings
+
+    mock_embed_model = MagicMock()
+    mock_embed_model._get_text_embedding_batch.return_value = [
+        [0.1] * 1536,
+    ]  # 1 vector per input
+    llama_settings.Settings._embed_model = mock_embed_model
+    yield
+    llama_settings.Settings._embed_model = None
+
+
+@pytest.fixture(autouse=True)
+def patch_embed_nodes():
+    with patch(
+        "llama_index.core.indices.vector_store.base.embed_nodes",
+    ) as mock_embed_nodes:
+        mock_embed_nodes.side_effect = lambda nodes, *_args, **_kwargs: {
+            node.node_id: [0.1] * 1536 for node in nodes
+        }
+        yield
+
+
+@pytest.fixture
+def mock_document():
+    doc = MagicMock()
+    doc.pk = 1
+    doc.title = "Test Document"
+    doc.filename = "test_file.pdf"
+    doc.content = "This is the document content."
+    return doc
+
+
+def test_stream_chat_with_one_document_full_content(mock_document):
+    with (
+        patch("paperless.ai.chat.AIClient") as mock_client_cls,
+        patch("paperless.ai.chat.load_index") as mock_load_index,
+        patch(
+            "paperless.ai.chat.RetrieverQueryEngine.from_args",
+        ) as mock_query_engine_cls,
+    ):
+        mock_client = MagicMock()
+        mock_client_cls.return_value = mock_client
+        mock_client.llm = MagicMock()
+
+        mock_node = TextNode(
+            text="This is node content.",
+            metadata={"document_id": mock_document.pk, "title": "Test Document"},
+        )
+        mock_index = MagicMock()
+        mock_index.docstore.docs.values.return_value = [mock_node]
+        mock_load_index.return_value = mock_index
+
+        mock_response_stream = MagicMock()
+        mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
+        mock_query_engine = MagicMock()
+        mock_query_engine_cls.return_value = mock_query_engine
+        mock_query_engine.query.return_value = mock_response_stream
+
+        output = list(stream_chat_with_documents("What is this?", [mock_document]))
+
+        assert output == ["chunk1", "chunk2"]
+
+
+def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
+    with (
+        patch("paperless.ai.chat.AIClient") as mock_client_cls,
+        patch("paperless.ai.chat.load_index") as mock_load_index,
+        patch(
+            "paperless.ai.chat.RetrieverQueryEngine.from_args",
+        ) as mock_query_engine_cls,
+        patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
+    ):
+        # Mock AIClient and LLM
+        mock_client = MagicMock()
+        mock_client_cls.return_value = mock_client
+        mock_client.llm = MagicMock()
+
+        # Create two real TextNodes
+        mock_node1 = TextNode(
+            text="Content for doc 1.",
+            metadata={"document_id": 1, "title": "Document 1"},
+        )
+        mock_node2 = TextNode(
+            text="Content for doc 2.",
+            metadata={"document_id": 2, "title": "Document 2"},
+        )
+        mock_index = MagicMock()
+        mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
+        mock_load_index.return_value = mock_index
+
+        # Patch as_retriever to return a retriever whose retrieve() returns mock_node1 and mock_node2
+        mock_retriever = MagicMock()
+        mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
+        mock_as_retriever.return_value = mock_retriever
+
+        # Mock response stream
+        mock_response_stream = MagicMock()
+        mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
+
+        # Mock RetrieverQueryEngine
+        mock_query_engine = MagicMock()
+        mock_query_engine_cls.return_value = mock_query_engine
+        mock_query_engine.query.return_value = mock_response_stream
+
+        # Fake documents
+        doc1 = MagicMock(pk=1)
+        doc2 = MagicMock(pk=2)
+
+        output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
+
+        assert output == ["chunk1", "chunk2"]
+
+
+def test_stream_chat_no_matching_nodes():
+    with (
+        patch("paperless.ai.chat.AIClient") as mock_client_cls,
+        patch("paperless.ai.chat.load_index") as mock_load_index,
+    ):
+        mock_client = MagicMock()
+        mock_client_cls.return_value = mock_client
+        mock_client.llm = MagicMock()
+
+        mock_index = MagicMock()
+        # No matching nodes
+        mock_index.docstore.docs.values.return_value = []
+        mock_load_index.return_value = mock_index
+
+        output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
+
+        assert output == ["Sorry, I couldn't find any content to answer your question."]
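
For orientation, the patch targets in these tests (AIClient, load_index, RetrieverQueryEngine.from_args, VectorStoreIndex.as_retriever) imply roughly the following shape for the generator under test. This is a hedged sketch reconstructed from the hunk and the mocks above, not a copy of src/paperless/ai/chat.py; the paperless import paths, the document_id filter, and the argument-free as_retriever() call are assumptions (the real retriever arguments are cut off in the hunk).

    from llama_index.core import VectorStoreIndex
    from llama_index.core.query_engine import RetrieverQueryEngine

    from paperless.ai.client import AIClient       # assumed import path
    from paperless.ai.indexing import load_index   # assumed import path

    def stream_chat_with_documents(query_str, documents):
        client = AIClient()    # patched in every test
        index = load_index()   # patched to return a mock index
        doc_ids = {doc.pk for doc in documents}
        nodes = [
            node
            for node in index.docstore.docs.values()
            if node.metadata.get("document_id") in doc_ids
        ]
        if len(nodes) == 0:
            # the fixed path: yield the message, then return (see the hunk above)
            yield "Sorry, I couldn't find any content to answer your question."
            return
        local_index = VectorStoreIndex(nodes=nodes)  # context line in the hunk
        retriever = local_index.as_retriever()       # patched in the multi-document test
        query_engine = RetrieverQueryEngine.from_args(retriever, llm=client.llm)
        response = query_engine.query(query_str)
        yield from response.response_gen             # what list(...) collects in the tests

Under that shape, the autouse embedding fixtures exist so that VectorStoreIndex(nodes=...) can be constructed in the first two tests without real embedding calls.
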