import logging
from django.contrib.auth.models import User
-from llama_index.core.base.llms.types import CompletionResponse
from documents.models import Document
from documents.permissions import get_objects_for_user_owner_aware
filename = document.filename or ""
content = truncate_content(document.content[:4000] or "")
- prompt = f"""
- You are an assistant that extracts structured information from documents.
- Only respond with the JSON object as described below.
- Never ask for further information, additional content or ask questions. Never include any other text.
- Suggested tags and document types must be strictly based on the content of the document.
- Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
- Each field must be a list of plain strings.
-
- The JSON object must contain the following fields:
- - title: A short, descriptive title
- - tags: A list of simple tags like ["insurance", "medical", "receipts"]
- - correspondents: A list of names or organizations mentioned in the document
- - document_types: The type/category of the document (e.g. "invoice", "medical record")
- - storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
- - dates: List up to 3 relevant dates in YYYY-MM-DD format
-
- The format of the JSON object is as follows:
- {{
- "title": "xxxxx",
- "tags": ["xxxx", "xxxx"],
- "correspondents": ["xxxx", "xxxx"],
- "document_types": ["xxxx", "xxxx"],
- "storage_paths": ["xxxx", "xxxx"],
- "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
- }}
- ---------
-
- FILENAME:
+ return f"""
+ You are a document classification assistant.
+
+ Analyze the following document and extract the following information:
+ - A short descriptive title
+ - Tags that reflect the content
+ - Names of people or organizations mentioned
+ - The type or category of the document
+ - Suggested folder paths for storing the document
+ - Up to 3 relevant dates in YYYY-MM-DD format
+
+ Filename:
{filename}
- CONTENT:
+ Content:
{content}
- """
-
- return prompt
+ """.strip()
# Build the classification prompt for `document`, extended with context taken
# from similar documents.
#
# This span is shown as a diff: lines starting with `+` are the revised
# implementation, lines starting with `-` are the previous one.
#
# Revised (`+`) behaviour:
#   * the base prompt comes from build_prompt_without_rag(document);
#   * the context is get_context_for_document(document, user) passed through
#     truncate_content();
#   * the result is the base prompt followed by an
#     "Additional context from similar documents:" section, with leading and
#     trailing whitespace removed by .strip().
#
# Previous (`-`) behaviour appended a "CONTEXT FROM SIMILAR DOCUMENTS:" section
# and a closing instruction to respond with nothing but the JSON object; the
# revised prompt no longer contains that instruction.
#
# `user` is forwarded unchanged to get_context_for_document(); it defaults to
# None.
#
# NOTE(review): indentation and blank lines inside the f-string are not
# recoverable from this view - confirm the literal's exact whitespace against
# the real file before relying on the prompt layout.
def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
+ base_prompt = build_prompt_without_rag(document)
context = truncate_content(get_context_for_document(document, user))
- prompt = build_prompt_without_rag(document)
- prompt += f"""
+ return f"""{base_prompt}
- CONTEXT FROM SIMILAR DOCUMENTS:
+ Additional context from similar documents:
{context}
-
- ---------
-
- DO NOT RESPOND WITH ANYTHING OTHER THAN THE JSON OBJECT.
- """
-
- return prompt
+ """.strip()
def get_context_for_document(
return "\n\n".join(context_blocks)
def parse_ai_response(raw: dict) -> dict:
    """
    Normalize an already-parsed AI classification result into a suggestion dict.

    Args:
        raw: Mapping of field names to values as produced by the LLM tool call
            (title, tags, correspondents, document_types, storage_paths, dates).

    Returns:
        A dict with exactly the keys ``title``, ``tags``, ``correspondents``,
        ``document_types``, ``storage_paths`` and ``dates``. A missing or
        ``None`` title becomes ``""``; a missing or ``None`` list field becomes
        ``[]``. If ``raw`` is not a dict, a warning is logged and ``{}`` is
        returned.
    """
    # The input is no longer JSON text, so nothing here can raise
    # ValueError / json.JSONDecodeError. The only realistic failure is a
    # non-dict payload, which is checked explicitly instead of relying on an
    # except clause that could never fire.
    if not isinstance(raw, dict):
        logger.warning(
            "Failed to parse AI response: expected dict, got %s",
            type(raw).__name__,
        )
        return {}

    list_fields = (
        "tags",
        "correspondents",
        "document_types",
        "storage_paths",
        "dates",
    )

    # `or` (rather than a .get() default) also covers keys that are present
    # but explicitly null, so callers always receive a str and lists.
    parsed = {"title": raw.get("title") or ""}
    parsed.update({field: raw.get(field) or [] for field in list_fields})
    return parsed
def get_ai_document_classification(
import logging
from llama_index.core.llms import ChatMessage
+from llama_index.core.program.function_program import get_function_tool
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from paperless.config import AIConfig
+from paperless_ai.tools import DocumentClassifierSchema
logger = logging.getLogger("paperless_ai.client")
self.settings = AIConfig()
self.llm = self.get_llm()
- def get_llm(self):
+ def get_llm(self) -> Ollama | OpenAI:
if self.settings.llm_backend == "ollama":
return Ollama(
model=self.settings.llm_model or "llama3",
self.settings.llm_backend,
self.settings.llm_model,
)
- result = self.llm.complete(prompt)
- logger.debug("LLM query result: %s", result)
- return result
+
+ user_msg = ChatMessage(role="user", content=prompt)
+ tool = get_function_tool(DocumentClassifierSchema)
+ result = self.llm.chat_with_tools(
+ tools=[tool],
+ user_msg=user_msg,
+ chat_history=[],
+ )
+ tool_calls = self.llm.get_tool_calls_from_response(
+ result,
+ error_on_no_tool_calls=True,
+ )
+ logger.debug("LLM query result: %s", tool_calls)
+ parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
+ return parsed.model_dump()
def run_chat(self, messages: list[ChatMessage]) -> str:
logger.debug(