]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Unify, respect perms
authorshamoon <4887959+shamoon@users.noreply.github.com>
Fri, 25 Apr 2025 07:09:33 +0000 (00:09 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Wed, 2 Jul 2025 18:01:56 +0000 (11:01 -0700)
[ci skip]

src/documents/views.py
src/paperless/ai/chat.py
src/paperless/ai/indexing.py

index a645709090e47b9b3fa1c4413002b1164896c094..d6f15dde32403d0a790e2aadf71e2e3a120fe33d 100644 (file)
@@ -174,7 +174,6 @@ from documents.utils import get_boolean
 from paperless import version
 from paperless.ai.ai_classifier import get_ai_document_classification
 from paperless.ai.chat import chat_with_documents
-from paperless.ai.chat import chat_with_single_document
 from paperless.ai.matching import extract_unmatched_names
 from paperless.ai.matching import match_correspondents_by_name
 from paperless.ai.matching import match_document_types_by_name
@@ -1178,13 +1177,23 @@ class DocumentViewSet(
         question = request.data["q"]
         doc_id = request.data.get("document_id", None)
         if doc_id:
-            document = Document.objects.get(id=doc_id)
+            try:
+                document = Document.objects.get(id=doc_id)
+            except Document.DoesNotExist:
+                return HttpResponseBadRequest("Invalid document ID")
+
             if not has_perms_owner_aware(request.user, "view_document", document):
                 return HttpResponseForbidden("Insufficient permissions")
 
-            result = chat_with_single_document(document, question, request.user)
+            documents = [document]
         else:
-            result = chat_with_documents(question, request.user)
+            documents = get_objects_for_user_owner_aware(
+                request.user,
+                "view_document",
+                Document,
+            )
+
+        result = chat_with_documents(question, documents)
 
         return Response({"answer": result})
 
index 6e75884d99b0e56c3c29323c7a9a42afc52fc305..3ce109a79455b1d070ecfdaa4a90319ce2e0dfb8 100644 (file)
@@ -1,52 +1,44 @@
 import logging
 
-from django.contrib.auth.models import User
 from llama_index.core import VectorStoreIndex
 from llama_index.core.query_engine import RetrieverQueryEngine
 
+from documents.models import Document
 from paperless.ai.client import AIClient
-from paperless.ai.indexing import get_document_retriever
 from paperless.ai.indexing import load_index
 
 logger = logging.getLogger("paperless.ai.chat")
 
 
-def chat_with_documents(prompt: str, user: User) -> str:
-    retriever = get_document_retriever(top_k=5)
+def chat_with_documents(prompt: str, documents: list[Document]) -> str:
     client = AIClient()
 
-    query_engine = RetrieverQueryEngine.from_args(
-        retriever=retriever,
-        llm=client.llm,
-    )
-
-    logger.debug("Document chat prompt: %s", prompt)
-    response = query_engine.query(prompt)
-    logger.debug("Document chat response: %s", response)
-    return str(response)
-
-
-def chat_with_single_document(document, question: str, user):
     index = load_index()
 
-    # Filter only the node(s) belonging to this doc
+    doc_ids = [doc.pk for doc in documents]
+
+    # Filter only the node(s) that match the document IDs
     nodes = [
         node
         for node in index.docstore.docs.values()
-        if node.metadata.get("document_id") == str(document.id)
+        if node.metadata.get("document_id") in doc_ids
     ]
 
-    if not nodes:
-        raise Exception("This document is not indexed yet.")
+    if len(nodes) == 0:
+        logger.warning("No nodes found for the given documents.")
+        return "Sorry, I couldn't find any content to answer your question."
 
     local_index = VectorStoreIndex.from_documents(nodes)
+    retriever = local_index.as_retriever(
+        similarity_top_k=3 if len(documents) == 1 else 5,
+    )
 
-    client = AIClient()
-
-    engine = RetrieverQueryEngine.from_args(
-        retriever=local_index.as_retriever(similarity_top_k=3),
+    query_engine = RetrieverQueryEngine.from_args(
+        retriever=retriever,
         llm=client.llm,
     )
 
-    response = engine.query(question)
+    logger.debug("Document chat prompt: %s", prompt)
+    response = query_engine.query(prompt)
+    logger.debug("Document chat response: %s", response)
     return str(response)
index 9ed09daa1648c0948a289d4d520789417599021c..6d9a59e792be64cb2317072f11036fe2fce82213 100644 (file)
@@ -14,11 +14,6 @@ from paperless.ai.embedding import get_embedding_model
 logger = logging.getLogger("paperless.ai.indexing")
 
 
-def get_document_retriever(top_k: int = 5) -> VectorIndexRetriever:
-    index = load_index()
-    return VectorIndexRetriever(index=index, similarity_top_k=top_k)
-
-
 def load_index() -> VectorStoreIndex:
     """Loads the persisted LlamaIndex from disk."""
     vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
@@ -37,7 +32,8 @@ def load_index() -> VectorStoreIndex:
 def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
     """Runs a similarity query and returns top-k similar Document objects."""
     # Load the index
-    retriever = get_document_retriever(top_k=top_k)
+    index = load_index()
+    retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
 
     # Build query from the document text
     query_text = (document.title or "") + "\n" + (document.content or "")