"barcode_enable_tag": None,
"barcode_tag_mapping": None,
"ai_enabled": False,
+ "llm_embedding_backend": None,
+ "llm_embedding_model": None,
"llm_backend": None,
"llm_model": None,
"llm_api_key": None,
data = response.json()
return CompletionResponse(text=data["response"])
+    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        with httpx.Client(timeout=120.0) as client:
+            # Ollama's chat endpoint is /api/chat; /api/generate only takes a plain prompt.
+            response = client.post(
+                f"{self.base_url}/api/chat",
+                json={
+                    "model": self.model,
+                    "messages": [
+                        {
+                            "role": message.role,
+                            "content": message.content,
+                        }
+                        for message in messages
+                    ],
+                    "stream": False,
+                },
+            )
+            response.raise_for_status()
+            data = response.json()
+            # /api/chat returns the reply under "message", not "response"
+            return ChatResponse(
+                message=ChatMessage(
+                    role=data["message"].get("role", "assistant"),
+                    content=data["message"]["content"],
+                ),
+            )
+
# -- Required stubs for ABC:
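+    # These only satisfy the LLM base class; streaming and async calls raise NotImplementedError.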
- def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+ def stream_complete(
+ self,
+ prompt: str,
+ **kwargs,
+ ) -> CompletionResponseGen: # pragma: no cover
raise NotImplementedError("stream_complete not supported")
- def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
- raise NotImplementedError("chat not supported")
-
- def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
+ def stream_chat(
+ self,
+ messages: list[ChatMessage],
+ **kwargs,
+ ) -> ChatResponseGen: # pragma: no cover
raise NotImplementedError("stream_chat not supported")
- async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+ async def achat(
+ self,
+ messages: list[ChatMessage],
+ **kwargs,
+ ) -> ChatResponse: # pragma: no cover
raise NotImplementedError("async chat not supported")
async def astream_chat(
self,
messages: list[ChatMessage],
**kwargs,
- ) -> ChatResponseGen:
+ ) -> ChatResponseGen: # pragma: no cover
raise NotImplementedError("async stream_chat not supported")
- async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse:
+ async def acomplete(
+ self,
+ prompt: str,
+ **kwargs,
+ ) -> CompletionResponse: # pragma: no cover
raise NotImplementedError("async complete not supported")
- async def astream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+ async def astream_complete(
+ self,
+ prompt: str,
+ **kwargs,
+ ) -> CompletionResponseGen: # pragma: no cover
raise NotImplementedError("async stream_complete not supported")
AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO")
LLM_EMBEDDING_BACKEND = os.getenv(
"PAPERLESS_LLM_EMBEDDING_BACKEND",
- "local",
-) # or "openai"
+) # "local" or "openai"
LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_LLM_EMBEDDING_MODEL")
-LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "ollama") # or "openai"
+LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND")  # "ollama" or "openai"
LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
LLM_URL = os.getenv("PAPERLESS_LLM_URL")
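+# Example configuration (illustrative values only):
+#   PAPERLESS_AI_ENABLED=yes
+#   PAPERLESS_LLM_EMBEDDING_BACKEND=local
+#   PAPERLESS_LLM_BACKEND=ollama
+#   PAPERLESS_LLM_MODEL=llama3
+#   PAPERLESS_LLM_URL=http://localhost:11434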
import json
+from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
+from django.test import override_settings
from documents.models import Document
from paperless.ai.ai_classifier import get_ai_document_classification
-from paperless.ai.ai_classifier import parse_ai_classification_response
+from paperless.ai.ai_classifier import parse_ai_response
@pytest.fixture
@pytest.mark.django_db
@patch("paperless.ai.client.AIClient.run_llm_query")
+@override_settings(
+ LLM_BACKEND="ollama",
+ LLM_MODEL="some_model",
+)
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
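+    # run_llm_query now returns a response object, so the mocked payload is exposed via .text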
- mock_response = json.dumps(
+ mock_run_llm_query.return_value.text = json.dumps(
{
"title": "Test Title",
"tags": ["test", "document"],
"dates": ["2023-01-01"],
},
)
- mock_run_llm_query.return_value = mock_response
result = get_ai_document_classification(mock_document)
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
mock_run_llm_query.side_effect = Exception("LLM query failed")
- result = get_ai_document_classification(mock_document)
-
- assert result == {}
-
-
-def test_parse_llm_classification_response_valid():
- mock_response = json.dumps(
- {
- "title": "Test Title",
- "tags": ["test", "document"],
- "correspondents": ["John Doe"],
- "document_types": ["report"],
- "storage_paths": ["Reports"],
- "dates": ["2023-01-01"],
- },
- )
-
- result = parse_ai_classification_response(mock_response)
-
- assert result["title"] == "Test Title"
- assert result["tags"] == ["test", "document"]
- assert result["correspondents"] == ["John Doe"]
- assert result["document_types"] == ["report"]
- assert result["storage_paths"] == ["Reports"]
- assert result["dates"] == ["2023-01-01"]
+    # the failure should propagate to the caller instead of being swallowed
+ with pytest.raises(Exception):
+ get_ai_document_classification(mock_document)
def test_parse_llm_classification_response_invalid_json():
- mock_response = "Invalid JSON"
+ mock_response = MagicMock()
+ mock_response.text = "Invalid JSON response"
- result = parse_ai_classification_response(mock_response)
+ result = parse_ai_response(mock_response)
assert result == {}
-def test_parse_llm_classification_response_partial_data():
- mock_response = json.dumps(
- {
- "title": "Partial Data",
- "tags": ["partial"],
- "correspondents": "Jane Doe",
- "document_types": "note",
- "storage_paths": [],
- "dates": [],
- },
- )
+@pytest.mark.django_db
+@patch("paperless.ai.client.AIClient.run_llm_query")
+@patch("paperless.ai.ai_classifier.build_prompt_with_rag")
+@override_settings(
+ LLM_EMBEDDING_BACKEND="local",
+ LLM_EMBEDDING_MODEL="some_model",
+ LLM_BACKEND="ollama",
+ LLM_MODEL="some_model",
+)
+def test_use_rag_if_configured(
+ mock_build_prompt_with_rag,
+ mock_run_llm_query,
+ mock_document,
+):
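+    # an embedding backend is configured, so the classifier should build its prompt via RAG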
+ mock_build_prompt_with_rag.return_value = "Prompt with RAG"
+ mock_run_llm_query.return_value.text = json.dumps({})
+ get_ai_document_classification(mock_document)
+ mock_build_prompt_with_rag.assert_called_once()
- result = parse_ai_classification_response(mock_response)
- assert result["title"] == "Partial Data"
- assert result["tags"] == ["partial"]
- assert result["correspondents"] == ["Jane Doe"]
- assert result["document_types"] == ["note"]
- assert result["storage_paths"] == []
- assert result["dates"] == []
+@pytest.mark.django_db
+@patch("paperless.ai.client.AIClient.run_llm_query")
+@patch("paperless.ai.ai_classifier.build_prompt_without_rag")
+@patch("paperless.config.AIConfig")
+@override_settings(
+ LLM_BACKEND="ollama",
+ LLM_MODEL="some_model",
+)
+def test_use_without_rag_if_not_configured(
+ mock_ai_config,
+ mock_build_prompt_without_rag,
+ mock_run_llm_query,
+ mock_document,
+):
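+    # no embedding backend is configured, so the plain (non-RAG) prompt builder should be used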
+ mock_ai_config.llm_embedding_backend = None
+ mock_build_prompt_without_rag.return_value = "Prompt without RAG"
+ mock_run_llm_query.return_value.text = json.dumps({})
+ get_ai_document_classification(mock_document)
+ mock_build_prompt_without_rag.assert_called_once()
-import json
+from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
-from django.conf import settings
+from llama_index.core.llms import ChatMessage
from paperless.ai.client import AIClient
@pytest.fixture
-def mock_settings():
- settings.LLM_BACKEND = "openai"
- settings.LLM_MODEL = "gpt-3.5-turbo"
- settings.LLM_API_KEY = "test-api-key"
- yield settings
-
-
-@pytest.mark.django_db
-@patch("paperless.ai.client.AIClient._run_openai_query")
-@patch("paperless.ai.client.AIClient._run_ollama_query")
-def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings):
- mock_settings.LLM_BACKEND = "openai"
- mock_openai_query.return_value = "OpenAI response"
- client = AIClient()
- result = client.run_llm_query("Test prompt")
- assert result == "OpenAI response"
- mock_openai_query.assert_called_once_with("Test prompt")
- mock_ollama_query.assert_not_called()
-
-
-@pytest.mark.django_db
-@patch("paperless.ai.client.AIClient._run_openai_query")
-@patch("paperless.ai.client.AIClient._run_ollama_query")
-def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings):
- mock_settings.LLM_BACKEND = "ollama"
- mock_ollama_query.return_value = "Ollama response"
+def mock_ai_config():
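+    # Patch AIConfig where AIClient looks it up so each test can choose backend and model freely.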
+ with patch("paperless.ai.client.AIConfig") as MockAIConfig:
+ mock_config = MagicMock()
+ MockAIConfig.return_value = mock_config
+ yield mock_config
+
+
+@pytest.fixture
+def mock_ollama_llm():
+ with patch("paperless.ai.client.OllamaLLM") as MockOllamaLLM:
+ yield MockOllamaLLM
+
+
+@pytest.fixture
+def mock_openai_llm():
+ with patch("paperless.ai.client.OpenAI") as MockOpenAI:
+ yield MockOpenAI
+
+
+def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
+ mock_ai_config.llm_backend = "ollama"
+ mock_ai_config.llm_model = "test_model"
+ mock_ai_config.llm_url = "http://test-url"
+
client = AIClient()
- result = client.run_llm_query("Test prompt")
- assert result == "Ollama response"
- mock_ollama_query.assert_called_once_with("Test prompt")
- mock_openai_query.assert_not_called()
+ mock_ollama_llm.assert_called_once_with(
+ model="test_model",
+ base_url="http://test-url",
+ )
+ assert client.llm == mock_ollama_llm.return_value
+
+
+def test_get_llm_openai(mock_ai_config, mock_openai_llm):
+ mock_ai_config.llm_backend = "openai"
+ mock_ai_config.llm_model = "test_model"
+ mock_ai_config.openai_api_key = "test_api_key"
-@pytest.mark.django_db
-def test_run_llm_query_unsupported_backend(mock_settings):
- mock_settings.LLM_BACKEND = "unsupported"
client = AIClient()
+
+ mock_openai_llm.assert_called_once_with(
+ model="test_model",
+ api_key="test_api_key",
+ )
+ assert client.llm == mock_openai_llm.return_value
+
+
+def test_get_llm_unsupported_backend(mock_ai_config):
+ mock_ai_config.llm_backend = "unsupported"
+
with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
- client.run_llm_query("Test prompt")
+ AIClient()
-@pytest.mark.django_db
-def test_run_openai_query(httpx_mock, mock_settings):
- mock_settings.LLM_BACKEND = "openai"
- httpx_mock.add_response(
- url="https://api.openai.com/v1/chat/completions",
- json={
- "choices": [{"message": {"content": "OpenAI response"}}],
- },
- )
+def test_run_llm_query(mock_ai_config, mock_ollama_llm):
+ mock_ai_config.llm_backend = "ollama"
+ mock_ai_config.llm_model = "test_model"
+ mock_ai_config.llm_url = "http://test-url"
+
+ mock_llm_instance = mock_ollama_llm.return_value
+ mock_llm_instance.complete.return_value = "test_result"
client = AIClient()
- result = client.run_llm_query("Test prompt")
- assert result == "OpenAI response"
-
- request = httpx_mock.get_request()
- assert request.method == "POST"
- assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}"
- assert request.headers["Content-Type"] == "application/json"
- assert json.loads(request.content) == {
- "model": mock_settings.LLM_MODEL,
- "messages": [{"role": "user", "content": "Test prompt"}],
- "temperature": 0.3,
- }
-
-
-@pytest.mark.django_db
-def test_run_ollama_query(httpx_mock, mock_settings):
- mock_settings.LLM_BACKEND = "ollama"
- httpx_mock.add_response(
- url="http://localhost:11434/api/chat",
- json={"message": {"content": "Ollama response"}},
- )
+ result = client.run_llm_query("test_prompt")
+
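+    # run_llm_query should delegate the prompt directly to the wrapped LLM's complete()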
+ mock_llm_instance.complete.assert_called_once_with("test_prompt")
+ assert result == "test_result"
+
+
+def test_run_chat(mock_ai_config, mock_ollama_llm):
+ mock_ai_config.llm_backend = "ollama"
+ mock_ai_config.llm_model = "test_model"
+ mock_ai_config.llm_url = "http://test-url"
+
+ mock_llm_instance = mock_ollama_llm.return_value
+ mock_llm_instance.chat.return_value = "test_chat_result"
client = AIClient()
- result = client.run_llm_query("Test prompt")
- assert result == "Ollama response"
-
- request = httpx_mock.get_request()
- assert request.method == "POST"
- assert json.loads(request.content) == {
- "model": mock_settings.LLM_MODEL,
- "messages": [{"role": "user", "content": "Test prompt"}],
- "stream": False,
- }
+ messages = [ChatMessage(role="user", content="Hello")]
+ result = client.run_chat(messages)
+
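+    # run_chat should forward the message list to the wrapped LLM's chat()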
+ mock_llm_instance.chat.assert_called_once_with(messages)
+ assert result == "test_chat_result"