]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
RAG into suggestions
authorshamoon <4887959+shamoon@users.noreply.github.com>
Fri, 25 Apr 2025 03:51:19 +0000 (20:51 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Wed, 2 Jul 2025 18:01:53 +0000 (11:01 -0700)
src/paperless/ai/ai_classifier.py
src/paperless/ai/rag.py [new file with mode: 0644]

index 949cfaf696c1488cd42ff6858ff2d97e29ee329d..d5ec88323adb23e4358b68307e6b349dfb535605 100644 (file)
@@ -3,15 +3,13 @@ import logging
 
 from documents.models import Document
 from paperless.ai.client import AIClient
+from paperless.ai.rag import get_context_for_document
+from paperless.config import AIConfig
 
-logger = logging.getLogger("paperless.ai.ai_classifier")
+logger = logging.getLogger("paperless.ai.rag_classifier")
 
 
-def get_ai_document_classification(document: Document) -> dict:
-    """
-    Returns classification suggestions for a given document using an LLM.
-    Output schema matches the API's expected DocumentClassificationSuggestions format.
-    """
+def build_prompt_without_rag(document: Document) -> str:
     filename = document.filename or ""
     content = document.content or ""
 
@@ -41,6 +39,7 @@ def get_ai_document_classification(document: Document) -> dict:
     }}
     ---
 
+
     FILENAME:
     {filename}
 
@@ -48,39 +47,71 @@ def get_ai_document_classification(document: Document) -> dict:
     {content[:8000]}  # Trim to safe size
     """
 
-    try:
-        client = AIClient()
-        result = client.run_llm_query(prompt)
-        suggestions = parse_ai_classification_response(result)
-        return suggestions or {}
-    except Exception:
-        logger.exception("Error during LLM classification: %s", exc_info=True)
-        return {}
+    return prompt
 
 
-def parse_ai_classification_response(text: str) -> dict:
-    """
-    Parses LLM output and ensures it conforms to expected schema.
+def build_prompt_with_rag(document: Document) -> str:
+    context = get_context_for_document(document)
+    content = document.content or ""
+    filename = document.filename or ""
+
+    prompt = f"""
+    You are a helpful assistant that extracts structured information from documents.
+    You have access to similar documents as context to help improve suggestions.
+
+    Only output valid JSON in the format below. No additional explanations.
+
+    The JSON object must contain:
+    - title: A short, descriptive title
+    - tags: A list of relevant topics
+    - correspondents: People or organizations involved
+    - document_types: Type or category of the document
+    - storage_paths: Suggested folder paths
+    - dates: Up to 3 relevant dates in YYYY-MM-DD
+
+    Here is an example document:
+    FILENAME:
+    {filename}
+
+    CONTENT:
+    {content[:4000]}
+
+    CONTEXT FROM SIMILAR DOCUMENTS:
+    {context[:4000]}
     """
+
+    return prompt
+
+
+def parse_ai_response(text: str) -> dict:
     try:
         raw = json.loads(text)
         return {
             "title": raw.get("title"),
             "tags": raw.get("tags", []),
-            "correspondents": [raw["correspondents"]]
-            if isinstance(raw.get("correspondents"), str)
-            else raw.get("correspondents", []),
-            "document_types": [raw["document_types"]]
-            if isinstance(raw.get("document_types"), str)
-            else raw.get("document_types", []),
+            "correspondents": raw.get("correspondents", []),
+            "document_types": raw.get("document_types", []),
             "storage_paths": raw.get("storage_paths", []),
-            "dates": [d for d in raw.get("dates", []) if d],
+            "dates": raw.get("dates", []),
         }
     except json.JSONDecodeError:
-        # fallback: try to extract JSON manually?
-        logger.exception(
-            "Failed to parse LLM classification response: %s",
-            text,
-            exc_info=True,
-        )
+        logger.exception("Invalid JSON in RAG response")
+        return {}
+
+
+def get_ai_document_classification(document: Document) -> dict:
+    ai_config = AIConfig()
+
+    prompt = (
+        build_prompt_with_rag(document)
+        if ai_config.llm_embedding_backend
+        else build_prompt_without_rag(document)
+    )
+
+    try:
+        client = AIClient()
+        result = client.run_llm_query(prompt)
+        return parse_ai_response(result)
+    except Exception:
+        logger.exception("Failed AI classification")
         return {}
diff --git a/src/paperless/ai/rag.py b/src/paperless/ai/rag.py
new file mode 100644 (file)
index 0000000..9b5baf4
--- /dev/null
@@ -0,0 +1,12 @@
+from documents.models import Document
+from paperless.ai.indexing import query_similar_documents
+
+
+def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
+    similar_docs = query_similar_documents(doc)[:max_docs]
+    context_blocks = []
+    for similar in similar_docs:
+        text = similar.content or ""
+        title = similar.title or similar.filename or "Untitled"
+        context_blocks.append(f"TITLE: {title}\n{text}")
+    return "\n\n".join(context_blocks)