from paperless_ai.ai_classifier import build_prompt_without_rag
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.ai_classifier import get_context_for_document
-from paperless_ai.ai_classifier import parse_ai_response
@pytest.fixture
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
- mock_run_llm_query.return_value.text = json.dumps(
- {
- "title": "Test Title",
- "tags": ["test", "document"],
- "correspondents": ["John Doe"],
- "document_types": ["report"],
- "storage_paths": ["Reports"],
- "dates": ["2023-01-01"],
- },
- )
-
- result = get_ai_document_classification(mock_document)
-
- assert result["title"] == "Test Title"
- assert result["tags"] == ["test", "document"]
- assert result["correspondents"] == ["John Doe"]
- assert result["document_types"] == ["report"]
- assert result["storage_paths"] == ["Reports"]
- assert result["dates"] == ["2023-01-01"]
-
-
-@pytest.mark.django_db
-@patch("paperless_ai.client.AIClient.run_llm_query")
-@override_settings(
- LLM_BACKEND="ollama",
- LLM_MODEL="some_model",
-)
-def test_get_ai_document_classification_fallback_parse_success(
- mock_run_llm_query,
- mock_document,
-):
- mock_run_llm_query.return_value.text = """
- There is some text before the JSON.
- ```json
- {
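+    # run_llm_query is mocked to return already-parsed classification data rather than raw text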
+ mock_run_llm_query.return_value = {
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
- "dates": ["2023-01-01"]
+ "dates": ["2023-01-01"],
}
- ```
- """
result = get_ai_document_classification(mock_document)
assert result["dates"] == ["2023-01-01"]
-@pytest.mark.django_db
-@patch("paperless_ai.client.AIClient.run_llm_query")
-@override_settings(
- LLM_BACKEND="ollama",
- LLM_MODEL="some_model",
-)
-def test_get_ai_document_classification_parse_failure(
- mock_run_llm_query,
- mock_document,
-):
- mock_run_llm_query.return_value.text = "Invalid JSON response"
-
- result = get_ai_document_classification(mock_document)
- assert result == {}
-
-
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
get_ai_document_classification(mock_document)
-def test_parse_llm_classification_response_invalid_json():
- mock_response = MagicMock()
- mock_response.text = "Invalid JSON response"
-
- result = parse_ai_response(mock_response)
-
- assert result == {}
-
-
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
return_value="Context from similar documents",
):
prompt = build_prompt_without_rag(mock_document)
- assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
+ assert "Additional context from similar documents:" not in prompt
prompt = build_prompt_with_rag(mock_document)
- assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
+ assert "Additional context from similar documents:" in prompt
@patch("paperless_ai.ai_classifier.query_similar_documents")
import pytest
from llama_index.core.llms import ChatMessage
+from llama_index.core.llms.llm import ToolSelection
from paperless_ai.client import AIClient
mock_ai_config.llm_url = "http://test-url"
mock_llm_instance = mock_ollama_llm.return_value
- mock_llm_instance.complete.return_value = "test_result"
+
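+    # Simulate a function-calling response: one tool call whose kwargs hold the classification fields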
+ tool_selection = ToolSelection(
+ tool_id="call_test",
+ tool_name="DocumentClassifierSchema",
+ tool_kwargs={
+ "title": "Test Title",
+ "tags": ["test", "document"],
+ "correspondents": ["John Doe"],
+ "document_types": ["report"],
+ "storage_paths": ["Reports"],
+ "dates": ["2023-01-01"],
+ },
+ )
+
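+    # The chat response itself is opaque; get_tool_calls_from_response supplies the tool call to extract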
+ mock_llm_instance.chat_with_tools.return_value = MagicMock()
+ mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection]
client = AIClient()
result = client.run_llm_query("test_prompt")
- mock_llm_instance.complete.assert_called_once_with("test_prompt")
- assert result == "test_result"
+ assert result["title"] == "Test Title"
def test_run_chat(mock_ai_config, mock_ollama_llm):