from documents.models import Document
from paperless.ai.client import AIClient
+from paperless.ai.rag import get_context_for_document
+from paperless.config import AIConfig
-logger = logging.getLogger("paperless.ai.ai_classifier")
+logger = logging.getLogger("paperless.ai.rag_classifier")
-def get_ai_document_classification(document: Document) -> dict:
- """
- Returns classification suggestions for a given document using an LLM.
- Output schema matches the API's expected DocumentClassificationSuggestions format.
- """
+def build_prompt_without_rag(document: Document) -> str:
filename = document.filename or ""
content = document.content or ""
}}
---
+
FILENAME:
{filename}
{content[:8000]} # Trim to safe size
"""
- try:
- client = AIClient()
- result = client.run_llm_query(prompt)
- suggestions = parse_ai_classification_response(result)
- return suggestions or {}
- except Exception:
- logger.exception("Error during LLM classification: %s", exc_info=True)
- return {}
+ return prompt
-def parse_ai_classification_response(text: str) -> dict:
- """
- Parses LLM output and ensures it conforms to expected schema.
def build_prompt_with_rag(document: Document) -> str:
    """
    Builds the classification prompt for a document, enriched with retrieved context.

    The prompt embeds the document's filename, the first 4000 characters of its
    content, and the first 4000 characters of whatever
    get_context_for_document() returns for it. A missing filename or content is
    rendered as an empty string.
    """
    # Retrieval happens first, before any document field is read
    retrieved = get_context_for_document(document)
    body_excerpt = (document.content or "")[:4000]
    name = document.filename or ""

    return f"""
    You are a helpful assistant that extracts structured information from documents.
    You have access to similar documents as context to help improve suggestions.

    Only output valid JSON in the format below. No additional explanations.

    The JSON object must contain:
    - title: A short, descriptive title
    - tags: A list of relevant topics
    - correspondents: People or organizations involved
    - document_types: Type or category of the document
    - storage_paths: Suggested folder paths
    - dates: Up to 3 relevant dates in YYYY-MM-DD

    Here is an example document:
    FILENAME:
    {name}

    CONTENT:
    {body_excerpt}

    CONTEXT FROM SIMILAR DOCUMENTS:
    {retrieved[:4000]}
    """
+
+
def _coerce_to_list(value) -> list:
    """
    Normalizes one suggestion field to a list.

    None becomes an empty list, an existing list is returned unchanged, and any
    other value (typically a bare string) is wrapped in a single-element list.
    """
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


def parse_ai_response(text: str) -> dict:
    """
    Parses the LLM's JSON output into a suggestions dictionary.

    The result always has the keys "title", "tags", "correspondents",
    "document_types", "storage_paths" and "dates". Every key except "title" is
    guaranteed to hold a list: missing or null fields become empty lists and a
    bare string is wrapped in a list, since models frequently answer with a
    single string where a list was requested. Empty or null entries are removed
    from "dates".

    Returns an empty dict when the text is not valid JSON or when the decoded
    JSON is not an object.
    """
    try:
        raw = json.loads(text)
    except json.JSONDecodeError:
        logger.exception("Invalid JSON in RAG response")
        return {}

    if not isinstance(raw, dict):
        # Valid JSON that is not an object (a list, string, number...) has no
        # named fields to read; treat it like an unusable response.
        logger.warning("Unexpected JSON structure in RAG response")
        return {}

    return {
        "title": raw.get("title"),
        "tags": _coerce_to_list(raw.get("tags")),
        "correspondents": _coerce_to_list(raw.get("correspondents")),
        "document_types": _coerce_to_list(raw.get("document_types")),
        "storage_paths": _coerce_to_list(raw.get("storage_paths")),
        # Drop blank placeholders such as "" or null that carry no date
        "dates": [d for d in _coerce_to_list(raw.get("dates")) if d],
    }
+
+
def get_ai_document_classification(document: Document) -> dict:
    """
    Returns classification suggestions for a given document using an LLM.

    The RAG prompt is used when the AI configuration's llm_embedding_backend is
    set; otherwise the plain prompt is used. The prompt is built before the
    guarded section, so errors raised while building it are not caught here.
    Any exception raised while querying the LLM or parsing its answer is logged
    and an empty dict is returned.
    """
    if AIConfig().llm_embedding_backend:
        prompt = build_prompt_with_rag(document)
    else:
        prompt = build_prompt_without_rag(document)

    try:
        llm_output = AIClient().run_llm_query(prompt)
        return parse_ai_response(llm_output)
    except Exception:
        logger.exception("Failed AI classification")
        return {}