git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Fixup some tests
author      shamoon <4887959+shamoon@users.noreply.github.com>
            Fri, 25 Apr 2025 07:59:46 +0000 (00:59 -0700)
committer   shamoon <4887959+shamoon@users.noreply.github.com>
            Wed, 2 Jul 2025 18:01:56 +0000 (11:01 -0700)
src/documents/tests/test_api_app_config.py
src/paperless/ai/llms.py
src/paperless/settings.py
src/paperless/tests/test_ai_classifier.py
src/paperless/tests/test_ai_client.py

diff --git a/src/documents/tests/test_api_app_config.py b/src/documents/tests/test_api_app_config.py
index 502a22fcd14b1ca17808eb91c23b174acc9e729d..e87802cf05a10e09268c8091552c970324deb38d 100644
@@ -65,6 +65,8 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
                 "barcode_enable_tag": None,
                 "barcode_tag_mapping": None,
                 "ai_enabled": False,
+                "llm_embedding_backend": None,
+                "llm_embedding_model": None,
                 "llm_backend": None,
                 "llm_model": None,
                 "llm_api_key": None,
diff --git a/src/paperless/ai/llms.py b/src/paperless/ai/llms.py
index b51045d451718db109e5b6cad29e09ff5eb1df3a..c4b56f36da621ac3a86ba705f86c06e2559133f5 100644
@@ -37,28 +37,65 @@ class OllamaLLM(LLM):
             data = response.json()
             return CompletionResponse(text=data["response"])
 
+    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        with httpx.Client(timeout=120.0) as client:
+            response = client.post(
+                f"{self.base_url}/api/generate",
+                json={
+                    "model": self.model,
+                    "messages": [
+                        {
+                            "role": message.role,
+                            "content": message.content,
+                        }
+                        for message in messages
+                    ],
+                    "stream": False,
+                },
+            )
+            response.raise_for_status()
+            data = response.json()
+            return ChatResponse(message=ChatMessage(role="assistant", content=data["message"]["content"]))
+
     # -- Required stubs for ABC:
-    def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+    def stream_complete(
+        self,
+        prompt: str,
+        **kwargs,
+    ) -> CompletionResponseGen:  # pragma: no cover
         raise NotImplementedError("stream_complete not supported")
 
-    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
-        raise NotImplementedError("chat not supported")
-
-    def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
+    def stream_chat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponseGen:  # pragma: no cover
         raise NotImplementedError("stream_chat not supported")
 
-    async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+    async def achat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponse:  # pragma: no cover
         raise NotImplementedError("async chat not supported")
 
     async def astream_chat(
         self,
         messages: list[ChatMessage],
         **kwargs,
-    ) -> ChatResponseGen:
+    ) -> ChatResponseGen:  # pragma: no cover
         raise NotImplementedError("async stream_chat not supported")
 
-    async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse:
+    async def acomplete(
+        self,
+        prompt: str,
+        **kwargs,
+    ) -> CompletionResponse:  # pragma: no cover
         raise NotImplementedError("async complete not supported")
 
-    async def astream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+    async def astream_complete(
+        self,
+        prompt: str,
+        **kwargs,
+    ) -> CompletionResponseGen:  # pragma: no cover
         raise NotImplementedError("async stream_complete not supported")
diff --git a/src/paperless/settings.py b/src/paperless/settings.py
index 98b65c7dcdf9ed11cb7c76f6d589d2eacbbb44ea..fb8b161d44e7b27ff074ae1a80d9bb564e9bc1f2 100644
@@ -1419,10 +1419,9 @@ OUTLOOK_OAUTH_ENABLED = bool(
 AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO")
 LLM_EMBEDDING_BACKEND = os.getenv(
     "PAPERLESS_LLM_EMBEDDING_BACKEND",
-    "local",
-)  # or "openai"
+)  # "local" or "openai"
 LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_LLM_EMBEDDING_MODEL")
-LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "ollama")  # or "openai"
+LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND")  # "ollama" or "openai"
 LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
 LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
 LLM_URL = os.getenv("PAPERLESS_LLM_URL")
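
The net effect of the settings change: with no environment variables set, both backends now resolve to None instead of silently defaulting to "local"/"ollama", so the AI features are opt-in. A quick sketch of the resulting semantics:

    import os

    # With no PAPERLESS_LLM_* variables exported, os.getenv falls back to
    # None, so neither an embedding backend nor an LLM backend is assumed.
    assert os.getenv("PAPERLESS_LLM_EMBEDDING_BACKEND") is None
    assert os.getenv("PAPERLESS_LLM_BACKEND") is None

    # Opting in means exporting the variables explicitly, e.g.
    #   PAPERLESS_LLM_BACKEND=ollama
    #   PAPERLESS_LLM_MODEL=llama3
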
diff --git a/src/paperless/tests/test_ai_classifier.py b/src/paperless/tests/test_ai_classifier.py
index edb086bbee297f0607c6d56d0cdbdce2c81654e2..a473652fc2078e31be51a95f423546f1020b6781 100644
@@ -1,11 +1,13 @@
 import json
+from unittest.mock import MagicMock
 from unittest.mock import patch
 
 import pytest
+from django.test import override_settings
 
 from documents.models import Document
 from paperless.ai.ai_classifier import get_ai_document_classification
-from paperless.ai.ai_classifier import parse_ai_classification_response
+from paperless.ai.ai_classifier import parse_ai_response
 
 
 @pytest.fixture
@@ -15,8 +17,12 @@ def mock_document():
 
 @pytest.mark.django_db
 @patch("paperless.ai.client.AIClient.run_llm_query")
+@override_settings(
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
 def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
-    mock_response = json.dumps(
+    mock_run_llm_query.return_value.text = json.dumps(
         {
             "title": "Test Title",
             "tags": ["test", "document"],
@@ -26,7 +32,6 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
             "dates": ["2023-01-01"],
         },
     )
-    mock_run_llm_query.return_value = mock_response
 
     result = get_ai_document_classification(mock_document)
 
@@ -43,58 +48,56 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
 def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
     mock_run_llm_query.side_effect = Exception("LLM query failed")
 
-    result = get_ai_document_classification(mock_document)
-
-    assert result == {}
-
-
-def test_parse_llm_classification_response_valid():
-    mock_response = json.dumps(
-        {
-            "title": "Test Title",
-            "tags": ["test", "document"],
-            "correspondents": ["John Doe"],
-            "document_types": ["report"],
-            "storage_paths": ["Reports"],
-            "dates": ["2023-01-01"],
-        },
-    )
-
-    result = parse_ai_classification_response(mock_response)
-
-    assert result["title"] == "Test Title"
-    assert result["tags"] == ["test", "document"]
-    assert result["correspondents"] == ["John Doe"]
-    assert result["document_types"] == ["report"]
-    assert result["storage_paths"] == ["Reports"]
-    assert result["dates"] == ["2023-01-01"]
+    # the exception from the failed LLM query should propagate to the caller
+    with pytest.raises(Exception):
+        get_ai_document_classification(mock_document)
 
 
 def test_parse_llm_classification_response_invalid_json():
-    mock_response = "Invalid JSON"
+    mock_response = MagicMock()
+    mock_response.text = "Invalid JSON response"
 
-    result = parse_ai_classification_response(mock_response)
+    result = parse_ai_response(mock_response)
 
     assert result == {}
 
 
-def test_parse_llm_classification_response_partial_data():
-    mock_response = json.dumps(
-        {
-            "title": "Partial Data",
-            "tags": ["partial"],
-            "correspondents": "Jane Doe",
-            "document_types": "note",
-            "storage_paths": [],
-            "dates": [],
-        },
-    )
+@pytest.mark.django_db
+@patch("paperless.ai.client.AIClient.run_llm_query")
+@patch("paperless.ai.ai_classifier.build_prompt_with_rag")
+@override_settings(
+    LLM_EMBEDDING_BACKEND="local",
+    LLM_EMBEDDING_MODEL="some_model",
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
+def test_use_rag_if_configured(
+    mock_build_prompt_with_rag,
+    mock_run_llm_query,
+    mock_document,
+):
+    mock_build_prompt_with_rag.return_value = "Prompt with RAG"
+    mock_run_llm_query.return_value.text = json.dumps({})
+    get_ai_document_classification(mock_document)
+    mock_build_prompt_with_rag.assert_called_once()
 
-    result = parse_ai_classification_response(mock_response)
 
-    assert result["title"] == "Partial Data"
-    assert result["tags"] == ["partial"]
-    assert result["correspondents"] == ["Jane Doe"]
-    assert result["document_types"] == ["note"]
-    assert result["storage_paths"] == []
-    assert result["dates"] == []
+@pytest.mark.django_db
+@patch("paperless.ai.client.AIClient.run_llm_query")
+@patch("paperless.ai.ai_classifier.build_prompt_without_rag")
+@patch("paperless.config.AIConfig")
+@override_settings(
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
+def test_use_without_rag_if_not_configured(
+    mock_ai_config,
+    mock_build_prompt_without_rag,
+    mock_run_llm_query,
+    mock_document,
+):
+    mock_ai_config.llm_embedding_backend = None
+    mock_build_prompt_without_rag.return_value = "Prompt without RAG"
+    mock_run_llm_query.return_value.text = json.dumps({})
+    get_ai_document_classification(mock_document)
+    mock_build_prompt_without_rag.assert_called_once()
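
The rewritten tests pin down the contract of the renamed parse_ai_response. A reconstruction of that contract (an editorial sketch inferred from the test expectations, not the actual implementation in paperless/ai/ai_classifier.py): the response object carries the raw model output in .text, which is parsed as JSON, with an empty dict as the failure fallback:

    import json

    def parse_ai_response(response) -> dict:
        # Reconstructed from the tests: .text holds the raw LLM output;
        # anything that is not valid JSON yields {}.
        try:
            return json.loads(response.text)
        except json.JSONDecodeError:
            return {}
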
diff --git a/src/paperless/tests/test_ai_client.py b/src/paperless/tests/test_ai_client.py
index 6a239279ec00b7e720ba8b43c26c948434303296..27b160d23ae6bd6f264c60667c4e9eec2cd8fb37 100644
@@ -1,95 +1,93 @@
-import json
+from unittest.mock import MagicMock
 from unittest.mock import patch
 
 import pytest
-from django.conf import settings
+from llama_index.core.llms import ChatMessage
 
 from paperless.ai.client import AIClient
 
 
 @pytest.fixture
-def mock_settings():
-    settings.LLM_BACKEND = "openai"
-    settings.LLM_MODEL = "gpt-3.5-turbo"
-    settings.LLM_API_KEY = "test-api-key"
-    yield settings
-
-
-@pytest.mark.django_db
-@patch("paperless.ai.client.AIClient._run_openai_query")
-@patch("paperless.ai.client.AIClient._run_ollama_query")
-def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings):
-    mock_settings.LLM_BACKEND = "openai"
-    mock_openai_query.return_value = "OpenAI response"
-    client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "OpenAI response"
-    mock_openai_query.assert_called_once_with("Test prompt")
-    mock_ollama_query.assert_not_called()
-
-
-@pytest.mark.django_db
-@patch("paperless.ai.client.AIClient._run_openai_query")
-@patch("paperless.ai.client.AIClient._run_ollama_query")
-def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings):
-    mock_settings.LLM_BACKEND = "ollama"
-    mock_ollama_query.return_value = "Ollama response"
+def mock_ai_config():
+    with patch("paperless.ai.client.AIConfig") as MockAIConfig:
+        mock_config = MagicMock()
+        MockAIConfig.return_value = mock_config
+        yield mock_config
+
+
+@pytest.fixture
+def mock_ollama_llm():
+    with patch("paperless.ai.client.OllamaLLM") as MockOllamaLLM:
+        yield MockOllamaLLM
+
+
+@pytest.fixture
+def mock_openai_llm():
+    with patch("paperless.ai.client.OpenAI") as MockOpenAI:
+        yield MockOpenAI
+
+
+def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
+    mock_ai_config.llm_backend = "ollama"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.llm_url = "http://test-url"
+
     client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "Ollama response"
-    mock_ollama_query.assert_called_once_with("Test prompt")
-    mock_openai_query.assert_not_called()
 
+    mock_ollama_llm.assert_called_once_with(
+        model="test_model",
+        base_url="http://test-url",
+    )
+    assert client.llm == mock_ollama_llm.return_value
+
+
+def test_get_llm_openai(mock_ai_config, mock_openai_llm):
+    mock_ai_config.llm_backend = "openai"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.openai_api_key = "test_api_key"
 
-@pytest.mark.django_db
-def test_run_llm_query_unsupported_backend(mock_settings):
-    mock_settings.LLM_BACKEND = "unsupported"
     client = AIClient()
+
+    mock_openai_llm.assert_called_once_with(
+        model="test_model",
+        api_key="test_api_key",
+    )
+    assert client.llm == mock_openai_llm.return_value
+
+
+def test_get_llm_unsupported_backend(mock_ai_config):
+    mock_ai_config.llm_backend = "unsupported"
+
     with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
-        client.run_llm_query("Test prompt")
+        AIClient()
 
 
-@pytest.mark.django_db
-def test_run_openai_query(httpx_mock, mock_settings):
-    mock_settings.LLM_BACKEND = "openai"
-    httpx_mock.add_response(
-        url="https://api.openai.com/v1/chat/completions",
-        json={
-            "choices": [{"message": {"content": "OpenAI response"}}],
-        },
-    )
+def test_run_llm_query(mock_ai_config, mock_ollama_llm):
+    mock_ai_config.llm_backend = "ollama"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.llm_url = "http://test-url"
+
+    mock_llm_instance = mock_ollama_llm.return_value
+    mock_llm_instance.complete.return_value = "test_result"
 
     client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "OpenAI response"
-
-    request = httpx_mock.get_request()
-    assert request.method == "POST"
-    assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}"
-    assert request.headers["Content-Type"] == "application/json"
-    assert json.loads(request.content) == {
-        "model": mock_settings.LLM_MODEL,
-        "messages": [{"role": "user", "content": "Test prompt"}],
-        "temperature": 0.3,
-    }
-
-
-@pytest.mark.django_db
-def test_run_ollama_query(httpx_mock, mock_settings):
-    mock_settings.LLM_BACKEND = "ollama"
-    httpx_mock.add_response(
-        url="http://localhost:11434/api/chat",
-        json={"message": {"content": "Ollama response"}},
-    )
+    result = client.run_llm_query("test_prompt")
+
+    mock_llm_instance.complete.assert_called_once_with("test_prompt")
+    assert result == "test_result"
+
+
+def test_run_chat(mock_ai_config, mock_ollama_llm):
+    mock_ai_config.llm_backend = "ollama"
+    mock_ai_config.llm_model = "test_model"
+    mock_ai_config.llm_url = "http://test-url"
+
+    mock_llm_instance = mock_ollama_llm.return_value
+    mock_llm_instance.chat.return_value = "test_chat_result"
 
     client = AIClient()
-    result = client.run_llm_query("Test prompt")
-    assert result == "Ollama response"
-
-    request = httpx_mock.get_request()
-    assert request.method == "POST"
-    assert json.loads(request.content) == {
-        "model": mock_settings.LLM_MODEL,
-        "messages": [{"role": "user", "content": "Test prompt"}],
-        "stream": False,
-    }
+    messages = [ChatMessage(role="user", content="Hello")]
+    result = client.run_chat(messages)
+
+    mock_llm_instance.chat.assert_called_once_with(messages)
+    assert result == "test_chat_result"
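
Taken together, these tests fix the AIClient surface: config-driven LLM selection at construction time, complete() for one-shot queries, chat() for message lists. A sketch of the contract they encode (a reconstruction; the AIConfig attribute names and the OpenAI import source are assumptions based on the patch targets above, not confirmed by this diff):

    from llama_index.llms.openai import OpenAI  # assumed import source

    from paperless.ai.llms import OllamaLLM
    from paperless.config import AIConfig

    class AIClient:
        def __init__(self):
            # Backend selection happens here, so an unsupported backend
            # raises ValueError at construction, as the tests require.
            self.settings = AIConfig()
            self.llm = self.get_llm()

        def get_llm(self):
            if self.settings.llm_backend == "ollama":
                return OllamaLLM(
                    model=self.settings.llm_model,
                    base_url=self.settings.llm_url,
                )
            if self.settings.llm_backend == "openai":
                return OpenAI(
                    model=self.settings.llm_model,
                    api_key=self.settings.openai_api_key,
                )
            raise ValueError(
                f"Unsupported LLM backend: {self.settings.llm_backend}",
            )

        def run_llm_query(self, prompt: str):
            # Delegates to the backend's completion API.
            return self.llm.complete(prompt)

        def run_chat(self, messages):
            # Delegates to the backend's chat API.
            return self.llm.chat(messages)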