git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Add fallback parsing for invalid ai responses
author shamoon <4887959+shamoon@users.noreply.github.com>
Wed, 30 Apr 2025 20:03:31 +0000 (13:03 -0700)
committer shamoon <4887959+shamoon@users.noreply.github.com>
Wed, 2 Jul 2025 18:04:59 +0000 (11:04 -0700)
src/paperless_ai/ai_classifier.py
src/paperless_ai/tests/test_ai_classifier.py

index b75ceb1e5c61f8f4436923d930628dbc2a57db46..55c7c77046a3841082f558ee2dc1f6617f54508c 100644 (file)
@@ -16,7 +16,7 @@ logger = logging.getLogger("paperless_ai.rag_classifier")
 
 def build_prompt_without_rag(document: Document) -> str:
     filename = document.filename or ""
-    content = truncate_content(document.content or "")
+    content = truncate_content(document.content[:4000] or "")
 
     prompt = f"""
     You are an assistant that extracts structured information from documents.
@@ -43,7 +43,7 @@ def build_prompt_without_rag(document: Document) -> str:
         "storage_paths": ["xxxx", "xxxx"],
         "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
     }}
-    ---
+    ---------
 
     FILENAME:
     {filename}
@@ -63,6 +63,10 @@ def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
 
     CONTEXT FROM SIMILAR DOCUMENTS:
     {context}
+
+    ---------
+
+    DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
     """
 
     return prompt
@@ -108,8 +112,24 @@ def parse_ai_response(response: CompletionResponse) -> dict:
             "dates": raw.get("dates", []),
         }
     except json.JSONDecodeError:
-        logger.exception("Invalid JSON in AI response")
-        return {}
+        logger.warning("Invalid JSON in AI response, attempting modified parsing...")
+        try:
+            # search for a valid json string like { ... } in the response
+            start = response.text.index("{")
+            end = response.text.rindex("}") + 1
+            json_str = response.text[start:end]
+            raw = json.loads(json_str)
+            return {
+                "title": raw.get("title"),
+                "tags": raw.get("tags", []),
+                "correspondents": raw.get("correspondents", []),
+                "document_types": raw.get("document_types", []),
+                "storage_paths": raw.get("storage_paths", []),
+                "dates": raw.get("dates", []),
+            }
+        except (ValueError, json.JSONDecodeError):
+            logger.exception("Failed to parse AI response")
+            return {}
 
 
 def get_ai_document_classification(
index 408678f7b175a8c34d99986ad3b12ae646491b58..548acbc6c70237129569e2a0583eae66288199ce 100644 (file)
@@ -48,6 +48,26 @@ def mock_document():
     return doc
 
 
+@pytest.fixture
+def mock_similar_documents():
+    doc1 = MagicMock()
+    doc1.content = "Content of document 1"
+    doc1.title = "Title 1"
+    doc1.filename = "file1.txt"
+
+    doc2 = MagicMock()
+    doc2.content = "Content of document 2"
+    doc2.title = None
+    doc2.filename = "file2.txt"
+
+    doc3 = MagicMock()
+    doc3.content = None
+    doc3.title = None
+    doc3.filename = None
+
+    return [doc1, doc2, doc3]
+
+
 @pytest.mark.django_db
 @patch("paperless_ai.client.AIClient.run_llm_query")
 @override_settings(
@@ -76,6 +96,56 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
     assert result["dates"] == ["2023-01-01"]
 
 
+@pytest.mark.django_db
+@patch("paperless_ai.client.AIClient.run_llm_query")
+@override_settings(
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
+def test_get_ai_document_classification_fallback_parse_success(
+    mock_run_llm_query,
+    mock_document,
+):
+    mock_run_llm_query.return_value.text = """
+    There is some text before the JSON.
+    ```json
+    {
+        "title": "Test Title",
+        "tags": ["test", "document"],
+        "correspondents": ["John Doe"],
+        "document_types": ["report"],
+        "storage_paths": ["Reports"],
+        "dates": ["2023-01-01"]
+    }
+    ```
+    """
+
+    result = get_ai_document_classification(mock_document)
+
+    assert result["title"] == "Test Title"
+    assert result["tags"] == ["test", "document"]
+    assert result["correspondents"] == ["John Doe"]
+    assert result["document_types"] == ["report"]
+    assert result["storage_paths"] == ["Reports"]
+    assert result["dates"] == ["2023-01-01"]
+
+
+@pytest.mark.django_db
+@patch("paperless_ai.client.AIClient.run_llm_query")
+@override_settings(
+    LLM_BACKEND="ollama",
+    LLM_MODEL="some_model",
+)
+def test_get_ai_document_classification_parse_failure(
+    mock_run_llm_query,
+    mock_document,
+):
+    mock_run_llm_query.return_value.text = "Invalid JSON response"
+
+    result = get_ai_document_classification(mock_document)
+    assert result == {}
+
+
 @pytest.mark.django_db
 @patch("paperless_ai.client.AIClient.run_llm_query")
 def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
@@ -154,26 +224,6 @@ def test_prompt_with_without_rag(mock_document):
         assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
 
 
-@pytest.fixture
-def mock_similar_documents():
-    doc1 = MagicMock()
-    doc1.content = "Content of document 1"
-    doc1.title = "Title 1"
-    doc1.filename = "file1.txt"
-
-    doc2 = MagicMock()
-    doc2.content = "Content of document 2"
-    doc2.title = None
-    doc2.filename = "file2.txt"
-
-    doc3 = MagicMock()
-    doc3.content = None
-    doc3.title = None
-    doc3.filename = None
-
-    return [doc1, doc2, doc3]
-
-
 @patch("paperless_ai.ai_classifier.query_similar_documents")
 def test_get_context_for_document(
     mock_query_similar_documents,