git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Fix tests for change to structured output
author shamoon <4887959+shamoon@users.noreply.github.com>
Tue, 15 Jul 2025 21:34:54 +0000 (14:34 -0700)
committer shamoon <4887959+shamoon@users.noreply.github.com>
Tue, 15 Jul 2025 21:34:54 +0000 (14:34 -0700)
src/paperless_ai/tests/test_ai_classifier.py
src/paperless_ai/tests/test_client.py

diff --git a/src/paperless_ai/tests/test_ai_classifier.py b/src/paperless_ai/tests/test_ai_classifier.py
index 548acbc6c70237129569e2a0583eae66288199ce..115d51cd4bca558d70cebb4abb1c7bfe9246701e 100644
--- a/src/paperless_ai/tests/test_ai_classifier.py
+++ b/src/paperless_ai/tests/test_ai_classifier.py
@@ -10,7 +10,6 @@ from paperless_ai.ai_classifier import build_prompt_with_rag
 from paperless_ai.ai_classifier import build_prompt_without_rag
 from paperless_ai.ai_classifier import get_ai_document_classification
 from paperless_ai.ai_classifier import get_context_for_document
-from paperless_ai.ai_classifier import parse_ai_response
 
 
 @pytest.fixture
@@ -75,50 +74,14 @@ def mock_similar_documents():
     LLM_MODEL="some_model",
 )
 def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
-    mock_run_llm_query.return_value.text = json.dumps(
-        {
-            "title": "Test Title",
-            "tags": ["test", "document"],
-            "correspondents": ["John Doe"],
-            "document_types": ["report"],
-            "storage_paths": ["Reports"],
-            "dates": ["2023-01-01"],
-        },
-    )
-
-    result = get_ai_document_classification(mock_document)
-
-    assert result["title"] == "Test Title"
-    assert result["tags"] == ["test", "document"]
-    assert result["correspondents"] == ["John Doe"]
-    assert result["document_types"] == ["report"]
-    assert result["storage_paths"] == ["Reports"]
-    assert result["dates"] == ["2023-01-01"]
-
-
-@pytest.mark.django_db
-@patch("paperless_ai.client.AIClient.run_llm_query")
-@override_settings(
-    LLM_BACKEND="ollama",
-    LLM_MODEL="some_model",
-)
-def test_get_ai_document_classification_fallback_parse_success(
-    mock_run_llm_query,
-    mock_document,
-):
-    mock_run_llm_query.return_value.text = """
-    There is some text before the JSON.
-    ```json
-    {
+    mock_run_llm_query.return_value = {
         "title": "Test Title",
         "tags": ["test", "document"],
         "correspondents": ["John Doe"],
         "document_types": ["report"],
         "storage_paths": ["Reports"],
-        "dates": ["2023-01-01"]
+        "dates": ["2023-01-01"],
     }
-    ```
-    """
 
     result = get_ai_document_classification(mock_document)
 
@@ -130,22 +93,6 @@ def test_get_ai_document_classification_fallback_parse_success(
     assert result["dates"] == ["2023-01-01"]
 
 
-@pytest.mark.django_db
-@patch("paperless_ai.client.AIClient.run_llm_query")
-@override_settings(
-    LLM_BACKEND="ollama",
-    LLM_MODEL="some_model",
-)
-def test_get_ai_document_classification_parse_failure(
-    mock_run_llm_query,
-    mock_document,
-):
-    mock_run_llm_query.return_value.text = "Invalid JSON response"
-
-    result = get_ai_document_classification(mock_document)
-    assert result == {}
-
-
 @pytest.mark.django_db
 @patch("paperless_ai.client.AIClient.run_llm_query")
 def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
@@ -156,15 +103,6 @@ def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
         get_ai_document_classification(mock_document)
 
 
-def test_parse_llm_classification_response_invalid_json():
-    mock_response = MagicMock()
-    mock_response.text = "Invalid JSON response"
-
-    result = parse_ai_response(mock_response)
-
-    assert result == {}
-
-
 @pytest.mark.django_db
 @patch("paperless_ai.client.AIClient.run_llm_query")
 @patch("paperless_ai.ai_classifier.build_prompt_with_rag")
@@ -218,10 +156,10 @@ def test_prompt_with_without_rag(mock_document):
         return_value="Context from similar documents",
     ):
         prompt = build_prompt_without_rag(mock_document)
-        assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
+        assert "Additional context from similar documents:" not in prompt
 
         prompt = build_prompt_with_rag(mock_document)
-        assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
+        assert "Additional context from similar documents:" in prompt
 
 
 @patch("paperless_ai.ai_classifier.query_similar_documents")
diff --git a/src/paperless_ai/tests/test_client.py b/src/paperless_ai/tests/test_client.py
index 7cd2b16b07ba6821d9bd8118717d76d0d6714e3a..6ef7b332b44b77327cc9e67cf21cf26e904fa1d8 100644
--- a/src/paperless_ai/tests/test_client.py
+++ b/src/paperless_ai/tests/test_client.py
@@ -3,6 +3,7 @@ from unittest.mock import patch
 
 import pytest
 from llama_index.core.llms import ChatMessage
+from llama_index.core.llms.llm import ToolSelection
 
 from paperless_ai.client import AIClient
 
@@ -69,13 +70,27 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm):
     mock_ai_config.llm_url = "http://test-url"
 
     mock_llm_instance = mock_ollama_llm.return_value
-    mock_llm_instance.complete.return_value = "test_result"
+
+    tool_selection = ToolSelection(
+        tool_id="call_test",
+        tool_name="DocumentClassifierSchema",
+        tool_kwargs={
+            "title": "Test Title",
+            "tags": ["test", "document"],
+            "correspondents": ["John Doe"],
+            "document_types": ["report"],
+            "storage_paths": ["Reports"],
+            "dates": ["2023-01-01"],
+        },
+    )
+
+    mock_llm_instance.chat_with_tools.return_value = MagicMock()
+    mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection]
 
     client = AIClient()
     result = client.run_llm_query("test_prompt")
 
-    mock_llm_instance.complete.assert_called_once_with("test_prompt")
-    assert result == "test_result"
+    assert result["title"] == "Test Title"
 
 
 def test_run_chat(mock_ai_config, mock_ollama_llm):
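In test_client.py the query test now mocks chat_with_tools and get_tool_calls_from_response and expects run_llm_query to hand back the tool_kwargs of the returned ToolSelection. Below is a minimal sketch of such a structured-output query path, assuming a llama-index function-calling LLM; the classify() helper and run_llm_query_sketch are illustrative, not the actual AIClient implementation.

from llama_index.core.tools import FunctionTool


def classify(
    title: str,
    tags: list[str],
    correspondents: list[str],
    document_types: list[str],
    storage_paths: list[str],
    dates: list[str],
) -> dict:
    """Return the structured classification for a document."""
    # Schema-only tool: the LLM fills in these arguments, mirroring the
    # tool_kwargs used in the ToolSelection fixture above.
    return locals()


classifier_tool = FunctionTool.from_defaults(
    fn=classify,
    name="DocumentClassifierSchema",
)


def run_llm_query_sketch(llm, prompt: str) -> dict:
    # Ask the model to answer via a tool call instead of free text,
    # then read the structured arguments straight off the tool call.
    response = llm.chat_with_tools([classifier_tool], user_msg=prompt)
    tool_calls = llm.get_tool_calls_from_response(
        response,
        error_on_no_tool_call=True,
    )
    return tool_calls[0].tool_kwargs

In the updated test, a MagicMock stands in for the chat response and the canned ToolSelection replaces get_tool_calls_from_response, which is why the assertion only needs to check result["title"].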