git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Super basic doc chat
author shamoon <4887959+shamoon@users.noreply.github.com>
Fri, 25 Apr 2025 06:41:31 +0000 (23:41 -0700)
committer shamoon <4887959+shamoon@users.noreply.github.com>
Wed, 2 Jul 2025 18:01:55 +0000 (11:01 -0700)
[ci skip]

src/documents/views.py
src/paperless/ai/chat.py [new file with mode: 0644]
src/paperless/ai/client.py
src/paperless/ai/indexing.py

diff --git a/src/documents/views.py b/src/documents/views.py
index dad9f560eb5042331aca142a271ecda42b93224a..cbac570b232a136bef0a0c13458e0473c3115993 100644 (file)
--- a/src/documents/views.py
+++ b/src/documents/views.py
@@ -173,6 +173,7 @@ from documents.templating.filepath import validate_filepath_template_and_render
 from documents.utils import get_boolean
 from paperless import version
 from paperless.ai.ai_classifier import get_ai_document_classification
+from paperless.ai.chat import chat_with_documents
 from paperless.ai.matching import extract_unmatched_names
 from paperless.ai.matching import match_correspondents_by_name
 from paperless.ai.matching import match_document_types_by_name
@@ -1167,6 +1168,17 @@ class DocumentViewSet(
                 "Error emailing document, check logs for more detail.",
             )
 
+    @action(methods=["post"], detail=False, url_path="chat")
+    def chat(self, request):
+        ai_config = AIConfig()
+        if not ai_config.ai_enabled:
+            return HttpResponseBadRequest("AI is required for this feature")
+
+        question = request.data["q"]
+        result = chat_with_documents(question, request.user)
+
+        return Response({"answer": result})
+
 
 @extend_schema_view(
     list=extend_schema(
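
For orientation, here is a rough usage sketch of the new action (not part of the commit). It assumes the DocumentViewSet is registered under /api/documents/ as in a stock Paperless-ngx install and that token authentication is configured, so the endpoint resolves to /api/documents/chat/. Note that the view reads request.data["q"] directly, so the "q" field is required.

import requests

# Hypothetical host and token; adjust for your deployment.
response = requests.post(
    "http://localhost:8000/api/documents/chat/",
    headers={"Authorization": "Token <api-token>"},
    data={"q": "What is the total on my latest electricity bill?"},
)
print(response.json()["answer"])  # the view wraps the LLM answer as {"answer": ...}
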
diff --git a/src/paperless/ai/chat.py b/src/paperless/ai/chat.py
new file mode 100644 (file)
index 0000000..eb485b6
--- /dev/null
+++ b/src/paperless/ai/chat.py
@@ -0,0 +1,24 @@
+import logging
+
+from django.contrib.auth.models import User
+from llama_index.core.query_engine import RetrieverQueryEngine
+
+from paperless.ai.client import AIClient
+from paperless.ai.indexing import get_document_retriever
+
+logger = logging.getLogger("paperless.ai.chat")
+
+
+def chat_with_documents(prompt: str, user: User) -> str:
+    retriever = get_document_retriever(top_k=5)
+    client = AIClient()
+
+    query_engine = RetrieverQueryEngine.from_args(
+        retriever=retriever,
+        llm=client.llm,
+    )
+
+    logger.debug("Document chat prompt: %s", prompt)
+    response = query_engine.query(prompt)
+    logger.debug("Document chat response: %s", response)
+    return str(response)
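
The helper added in chat.py can also be exercised outside the HTTP layer, e.g. from python manage.py shell. A minimal sketch follows; it assumes an LLM backend is configured via AIConfig and that the vector index has already been built and persisted. The "admin" username is only an example.

from django.contrib.auth.models import User

from paperless.ai.chat import chat_with_documents

user = User.objects.get(username="admin")
answer = chat_with_documents("Summarise my 2023 tax documents", user)
print(answer)
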
diff --git a/src/paperless/ai/client.py b/src/paperless/ai/client.py
index cf3b0b0eb016200f3293d952fe3dd64ceec46349..2ebb2b48d64652c9ae68c4a0342255e73ed1d770 100644 (file)
--- a/src/paperless/ai/client.py
+++ b/src/paperless/ai/client.py
@@ -14,6 +14,10 @@ class AIClient:
     A client for interacting with an LLM backend.
     """
 
+    def __init__(self):
+        self.settings = AIConfig()
+        self.llm = self.get_llm()
+
     def get_llm(self):
         if self.settings.llm_backend == "ollama":
             return OllamaLLM(
@@ -28,10 +32,6 @@ class AIClient:
         else:
             raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")
 
-    def __init__(self):
-        self.settings = AIConfig()
-        self.llm = self.get_llm()
-
     def run_llm_query(self, prompt: str) -> str:
         logger.debug(
             "Running LLM query against %s with model %s",
diff --git a/src/paperless/ai/indexing.py b/src/paperless/ai/indexing.py
index 271b5f3cd3cabae9ed57e8ca7ba712748b416bb2..9ed09daa1648c0948a289d4d520789417599021c 100644 (file)
--- a/src/paperless/ai/indexing.py
+++ b/src/paperless/ai/indexing.py
@@ -14,6 +14,11 @@ from paperless.ai.embedding import get_embedding_model
 logger = logging.getLogger("paperless.ai.indexing")
 
 
+def get_document_retriever(top_k: int = 5) -> VectorIndexRetriever:
+    index = load_index()
+    return VectorIndexRetriever(index=index, similarity_top_k=top_k)
+
+
 def load_index() -> VectorStoreIndex:
     """Loads the persisted LlamaIndex from disk."""
     vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
@@ -31,10 +36,8 @@ def load_index() -> VectorStoreIndex:
 
 def query_similar_documents(document: Document, top_k: int = 5) -> list[Document]:
     """Runs a similarity query and returns top-k similar Document objects."""
-
-    # Load index
-    index = load_index()
-    retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
+    # Load the index
+    retriever = get_document_retriever(top_k=top_k)
 
     # Build query from the document text
     query_text = (document.title or "") + "\n" + (document.content or "")
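
Finally, a hedged sketch of using the new shared retriever helper on its own. It assumes the FAISS index under settings.LLM_INDEX_DIR has already been persisted, and that node metadata carries a document_id (an assumption about how the index is populated, not shown in this diff):

from paperless.ai.indexing import get_document_retriever

retriever = get_document_retriever(top_k=3)
nodes = retriever.retrieve("invoices from ACME GmbH")  # llama-index NodeWithScore results
for node_with_score in nodes:
    print(node_with_score.score, node_with_score.node.metadata.get("document_id"))
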