]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Move to structured output
authorshamoon <4887959+shamoon@users.noreply.github.com>
Tue, 15 Jul 2025 21:27:29 +0000 (14:27 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Tue, 15 Jul 2025 21:27:29 +0000 (14:27 -0700)
src/paperless_ai/ai_classifier.py
src/paperless_ai/client.py
src/paperless_ai/tools.py [new file with mode: 0644]

index 55c7c77046a3841082f558ee2dc1f6617f54508c..3b251da2b1802e7d1cfd6e2235043c229c1cc4d0 100644 (file)
@@ -2,7 +2,6 @@ import json
 import logging
 
 from django.contrib.auth.models import User
-from llama_index.core.base.llms.types import CompletionResponse
 
 from documents.models import Document
 from documents.permissions import get_objects_for_user_owner_aware
@@ -18,58 +17,34 @@ def build_prompt_without_rag(document: Document) -> str:
     filename = document.filename or ""
     content = truncate_content(document.content[:4000] or "")
 
-    prompt = f"""
-    You are an assistant that extracts structured information from documents.
-    Only respond with the JSON object as described below.
-    Never ask for further information, additional content or ask questions. Never include any other text.
-    Suggested tags and document types must be strictly based on the content of the document.
-    Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
-    Each field must be a list of plain strings.
-
-    The JSON object must contain the following fields:
-    - title: A short, descriptive title
-    - tags: A list of simple tags like ["insurance", "medical", "receipts"]
-    - correspondents: A list of names or organizations mentioned in the document
-    - document_types: The type/category of the document (e.g. "invoice", "medical record")
-    - storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
-    - dates: List up to 3 relevant dates in YYYY-MM-DD format
-
-    The format of the JSON object is as follows:
-    {{
-        "title": "xxxxx",
-        "tags": ["xxxx", "xxxx"],
-        "correspondents": ["xxxx", "xxxx"],
-        "document_types": ["xxxx", "xxxx"],
-        "storage_paths": ["xxxx", "xxxx"],
-        "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
-    }}
-    ---------
-
-    FILENAME:
+    return f"""
+    You are a document classification assistant.
+
+    Analyze the following document and extract the following information:
+    - A short descriptive title
+    - Tags that reflect the content
+    - Names of people or organizations mentioned
+    - The type or category of the document
+    - Suggested folder paths for storing the document
+    - Up to 3 relevant dates in YYYY-MM-DD format
+
+    Filename:
     {filename}
 
-    CONTENT:
+    Content:
     {content}
-    """
-
-    return prompt
+    """.strip()
 
 
 def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
+    base_prompt = build_prompt_without_rag(document)
     context = truncate_content(get_context_for_document(document, user))
-    prompt = build_prompt_without_rag(document)
 
-    prompt += f"""
+    return f"""{base_prompt}
 
-    CONTEXT FROM SIMILAR DOCUMENTS:
+    Additional context from similar documents:
     {context}
-
-    ---------
-
-    DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
-    """
-
-    return prompt
+    """.strip()
 
 
 def get_context_for_document(
@@ -100,36 +75,19 @@ def get_context_for_document(
     return "\n\n".join(context_blocks)
 
 
-def parse_ai_response(response: CompletionResponse) -> dict:
+def parse_ai_response(raw: dict) -> dict:
     try:
-        raw = json.loads(response.text)
         return {
-            "title": raw.get("title"),
+            "title": raw.get("title", ""),
             "tags": raw.get("tags", []),
             "correspondents": raw.get("correspondents", []),
             "document_types": raw.get("document_types", []),
             "storage_paths": raw.get("storage_paths", []),
             "dates": raw.get("dates", []),
         }
-    except json.JSONDecodeError:
-        logger.warning("Invalid JSON in AI response, attempting modified parsing...")
-        try:
-            # search for a valid json string like { ... } in the response
-            start = response.text.index("{")
-            end = response.text.rindex("}") + 1
-            json_str = response.text[start:end]
-            raw = json.loads(json_str)
-            return {
-                "title": raw.get("title"),
-                "tags": raw.get("tags", []),
-                "correspondents": raw.get("correspondents", []),
-                "document_types": raw.get("document_types", []),
-                "storage_paths": raw.get("storage_paths", []),
-                "dates": raw.get("dates", []),
-            }
-        except (ValueError, json.JSONDecodeError):
-            logger.exception("Failed to parse AI response")
-            return {}
+    except (ValueError, json.JSONDecodeError):
+        logger.exception("Failed to parse AI response")
+        return {}
 
 
 def get_ai_document_classification(
index 67023cfb5053e5b4848d071a1f2c4869c1b819fe..651ca70229a261b2d06228e238f4f9e84aae392d 100644 (file)
@@ -1,10 +1,12 @@
 import logging
 
 from llama_index.core.llms import ChatMessage
+from llama_index.core.program.function_program import get_function_tool
 from llama_index.llms.ollama import Ollama
 from llama_index.llms.openai import OpenAI
 
 from paperless.config import AIConfig
+from paperless_ai.tools import DocumentClassifierSchema
 
 logger = logging.getLogger("paperless_ai.client")
 
@@ -18,7 +20,7 @@ class AIClient:
         self.settings = AIConfig()
         self.llm = self.get_llm()
 
-    def get_llm(self):
+    def get_llm(self) -> Ollama | OpenAI:
         if self.settings.llm_backend == "ollama":
             return Ollama(
                 model=self.settings.llm_model or "llama3",
@@ -39,9 +41,21 @@ class AIClient:
             self.settings.llm_backend,
             self.settings.llm_model,
         )
-        result = self.llm.complete(prompt)
-        logger.debug("LLM query result: %s", result)
-        return result
+
+        user_msg = ChatMessage(role="user", content=prompt)
+        tool = get_function_tool(DocumentClassifierSchema)
+        result = self.llm.chat_with_tools(
+            tools=[tool],
+            user_msg=user_msg,
+            chat_history=[],
+        )
+        tool_calls = self.llm.get_tool_calls_from_response(
+            result,
+            error_on_no_tool_calls=True,
+        )
+        logger.debug("LLM query result: %s", tool_calls)
+        parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
+        return parsed.model_dump()
 
     def run_chat(self, messages: list[ChatMessage]) -> str:
         logger.debug(
diff --git a/src/paperless_ai/tools.py b/src/paperless_ai/tools.py
new file mode 100644 (file)
index 0000000..2924f2c
--- /dev/null
@@ -0,0 +1,10 @@
+from llama_index.core.bridge.pydantic import BaseModel
+
+
+class DocumentClassifierSchema(BaseModel):
+    title: str
+    tags: list[str]
+    correspondents: list[str]
+    document_types: list[str]
+    storage_paths: list[str]
+    dates: list[str]