git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Better respect perms for ai suggestions
author    shamoon <4887959+shamoon@users.noreply.github.com>
          Tue, 29 Apr 2025 05:12:41 +0000 (22:12 -0700)
committer shamoon <4887959+shamoon@users.noreply.github.com>
          Wed, 2 Jul 2025 18:04:03 +0000 (11:04 -0700)
src/documents/views.py
src/paperless/ai/ai_classifier.py
src/paperless/ai/indexing.py

src/documents/views.py
index 3a392b98d3f58e9efce556933e0073450fdcbd03..8afd9b383cd2d4e959eb6d4ddbb6df9c107cbfb4 100644 (file)
@@ -789,7 +789,7 @@ class DocumentViewSet(
                 refresh_suggestions_cache(doc.pk)
                 return Response(cached_llm_suggestions.suggestions)
 
-            llm_suggestions = get_ai_document_classification(doc)
+            llm_suggestions = get_ai_document_classification(doc, request.user)
 
             matched_tags = match_tags_by_name(
                 llm_suggestions.get("tags", []),
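
A minimal sketch of the call pattern this hunk introduces; the wrapper below is illustrative, not the real DocumentViewSet code. The point of the change is that the requesting user is now forwarded, so the classifier can restrict RAG context to documents that user may view.

    from paperless.ai.ai_classifier import get_ai_document_classification

    def suggestions_for(doc, request):
        # Forward request.user so downstream retrieval is permission-aware;
        # passing None keeps the previous, unfiltered behaviour.
        return get_ai_document_classification(doc, request.user)
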
src/paperless/ai/ai_classifier.py
index 6b34c089994398ea58ee11290c76e7ad0d366155..33101718de1e6aee6c45958e42adf5437e7e45d1 100644 (file)
@@ -1,9 +1,11 @@
 import json
 import logging
 
+from django.contrib.auth.models import User
 from llama_index.core.base.llms.types import CompletionResponse
 
 from documents.models import Document
+from documents.permissions import get_objects_for_user_owner_aware
 from paperless.ai.client import AIClient
 from paperless.ai.indexing import query_similar_documents
 from paperless.config import AIConfig
@@ -52,8 +54,8 @@ def build_prompt_without_rag(document: Document) -> str:
     return prompt
 
 
-def build_prompt_with_rag(document: Document) -> str:
-    context = get_context_for_document(document)
+def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
+    context = get_context_for_document(document, user)
     prompt = build_prompt_without_rag(document)
 
     prompt += f"""
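
For orientation, a hedged sketch of how the RAG prompt is assembled after this change; the exact wording of the base prompt and of the context header lives elsewhere in ai_classifier.py and is not shown in this hunk.

    # Illustrative only: the base prompt plus a context section whose
    # contents are now scoped to the given user.
    def build_prompt_with_rag_sketch(document, user=None):
        context = get_context_for_document(document, user)  # user-filtered
        prompt = build_prompt_without_rag(document)
        prompt += f"\nContext from similar documents:\n{context}\n"
        return prompt
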
@@ -65,8 +67,26 @@ def build_prompt_with_rag(document: Document) -> str:
     return prompt
 
 
-def get_context_for_document(doc: Document, max_docs: int = 5) -> str:
-    similar_docs = query_similar_documents(doc)[:max_docs]
+def get_context_for_document(
+    doc: Document,
+    user: User | None = None,
+    max_docs: int = 5,
+) -> str:
+    visible_documents = (
+        get_objects_for_user_owner_aware(
+            user,
+            "view_document",
+            Document,
+        )
+        if user
+        else None
+    )
+    similar_docs = query_similar_documents(
+        document=doc,
+        document_ids=[document.pk for document in visible_documents]
+        if visible_documents
+        else None,
+    )[:max_docs]
     context_blocks = []
     for similar in similar_docs:
         text = similar.content or ""
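
A hedged sketch of the permission filter added above: get_objects_for_user_owner_aware (the paperless-ngx helper imported in the first hunk of this file) yields the documents the user owns or has view permission on, and only their primary keys are handed to similarity retrieval. When no user is supplied, the filter is skipped and retrieval stays unconstrained, matching the else-branch in the diff.

    from documents.models import Document
    from documents.permissions import get_objects_for_user_owner_aware

    def visible_document_ids(user):
        # Mirrors the logic in get_context_for_document above (sketch only).
        if user is None:
            return None  # no user: do not constrain retrieval
        visible = get_objects_for_user_owner_aware(user, "view_document", Document)
        return [d.pk for d in visible]

query_similar_documents then receives this list (or None) as document_ids, so the RAG context can no longer include a document the requesting user cannot open.
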
@@ -91,11 +111,14 @@ def parse_ai_response(response: CompletionResponse) -> dict:
         return {}
 
 
-def get_ai_document_classification(document: Document) -> dict:
+def get_ai_document_classification(
+    document: Document,
+    user: User | None = None,
+) -> dict:
     ai_config = AIConfig()
 
     prompt = (
-        build_prompt_with_rag(document)
+        build_prompt_with_rag(document, user)
         if ai_config.llm_embedding_backend
         else build_prompt_without_rag(document)
     )
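
A hypothetical usage example, e.g. from a Django shell or a test; the username and document lookup are illustrative. The user argument only matters when an embedding backend is configured, i.e. when the RAG branch above is taken.

    from django.contrib.auth.models import User
    from documents.models import Document
    from paperless.ai.ai_classifier import get_ai_document_classification

    user = User.objects.get(username="alice")   # hypothetical user
    doc = Document.objects.first()              # hypothetical document
    suggestions = get_ai_document_classification(doc, user)
    print(suggestions.get("tags", []))          # suggestions come back as a dict
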
src/paperless/ai/indexing.py
index 9a32409cac1207cc48c4336daa3a0a22a4487ac8..3e354ba6de52c481eddf2e543248434616b8cb25 100644 (file)
@@ -206,12 +206,32 @@ def llm_index_remove_document(document: Document):
     index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
 
 
-def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
+def query_similar_documents(
+    document: Document,
+    top_k: int = 5,
+    document_ids: list[int] | None = None,
+) -> list[Document]:
     """
     Runs a similarity query and returns top-k similar Document objects.
     """
     index = load_or_build_index()
-    retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
+
+    # constrain only the node(s) that match the document IDs, if given
+    doc_node_ids = (
+        [
+            node.node_id
+            for node in index.docstore.docs.values()
+            if node.metadata.get("document_id") in document_ids
+        ]
+        if document_ids
+        else None
+    )
+
+    retriever = VectorIndexRetriever(
+        index=index,
+        similarity_top_k=top_k,
+        doc_ids=doc_node_ids,
+    )
 
     query_text = (document.title or "") + "\n" + (document.content or "")
     results = retriever.retrieve(query_text)
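
A minimal sketch of the node-id constraint introduced above, assuming (as this commit does) that each node in the llama-index docstore carries the originating Paperless document's primary key under metadata["document_id"]. The resulting list is passed to VectorIndexRetriever as doc_ids, exactly as in the hunk; when document_ids is None the retriever is built without a constraint, so unfiltered callers behave as before.

    def nodes_for_documents(index, allowed_document_ids):
        # Map allowed Paperless document pks to llama-index node ids (sketch).
        allowed = set(allowed_document_ids)
        return [
            node.node_id
            for node in index.docstore.docs.values()
            if node.metadata.get("document_id") in allowed
        ]
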