]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Fix ollama, fix RAG
authorshamoon <4887959+shamoon@users.noreply.github.com>
Fri, 25 Apr 2025 05:03:21 +0000 (22:03 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Wed, 2 Jul 2025 18:01:53 +0000 (11:01 -0700)
[ci skip]

src/documents/management/commands/document_llmindex.py
src/documents/tasks.py
src/paperless/ai/ai_classifier.py
src/paperless/ai/client.py

index 2985a61e4239cc93a991f826170b18e2ef044dbe..09ea477c2210c6bfdf6cc09b19af6137dae48cbc 100644 (file)
@@ -15,5 +15,7 @@ class Command(ProgressBarMixin, BaseCommand):
     def handle(self, *args, **options):
         self.handle_progress_bar_mixin(**options)
         with transaction.atomic():
-            if options["command"] == "rebuild":
-                llm_index_rebuild(progress_bar_disable=self.no_progress_bar)
+            llm_index_rebuild(
+                progress_bar_disable=self.no_progress_bar,
+                rebuild=options["command"] == "rebuild",
+            )
index a7856897678f732b8cbef3a642c94ee0dd97c77c..c2427929a1616074d7ea191961b607001db64f44 100644 (file)
@@ -7,6 +7,7 @@ from pathlib import Path
 from tempfile import TemporaryDirectory
 
 import faiss
+import llama_index.core.settings as llama_settings
 import tqdm
 from celery import Task
 from celery import shared_task
@@ -21,7 +22,9 @@ from filelock import FileLock
 from llama_index.core import Document as LlamaDocument
 from llama_index.core import StorageContext
 from llama_index.core import VectorStoreIndex
-from llama_index.core.settings import Settings
+from llama_index.core.node_parser import SimpleNodeParser
+from llama_index.core.storage.docstore import SimpleDocumentStore
+from llama_index.core.storage.index_store import SimpleIndexStore
 from llama_index.vector_stores.faiss import FaissVectorStore
 from whoosh.writing import AsyncWriter
 
@@ -533,45 +536,56 @@ def llm_index_rebuild(*, progress_bar_disable=False, rebuild=False):
         shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
         settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
 
-    documents = Document.objects.all()
-
     embed_model = get_embedding_model()
+    llama_settings.Settings.embed_model = embed_model
 
     if rebuild or not settings.LLM_INDEX_DIR.exists():
         embedding_dim = get_embedding_dim()
         faiss_index = faiss.IndexFlatL2(embedding_dim)
-        vector_store = FaissVectorStore(faiss_index)
+        vector_store = FaissVectorStore(faiss_index=faiss_index)
     else:
         vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
-    storage_context = StorageContext.from_defaults(vector_store=vector_store)
-    Settings.embed_model = embed_model
 
-    llm_docs = []
-    for document in tqdm.tqdm(documents, disable=progress_bar_disable):
+    docstore = SimpleDocumentStore()
+    index_store = SimpleIndexStore()
+
+    storage_context = StorageContext.from_defaults(
+        docstore=docstore,
+        index_store=index_store,
+        persist_dir=settings.LLM_INDEX_DIR,
+        vector_store=vector_store,
+    )
+
+    parser = SimpleNodeParser()
+    nodes = []
+
+    for document in tqdm.tqdm(Document.objects.all(), disable=progress_bar_disable):
         if not document.content:
             continue
-        llm_docs.append(
-            LlamaDocument(
-                text=build_llm_index_text(document),
-                metadata={
-                    "id": document.id,
-                    "title": document.title,
-                    "tags": [t.name for t in document.tags.all()],
-                    "correspondent": document.correspondent.name
-                    if document.correspondent
-                    else None,
-                    "document_type": document.document_type.name
-                    if document.document_type
-                    else None,
-                    "created": document.created.isoformat(),
-                    "added": document.added.isoformat(),
-                },
-            ),
-        )
 
-    index = VectorStoreIndex.from_documents(
-        llm_docs,
+        text = build_llm_index_text(document)
+        metadata = {
+            "document_id": document.id,
+            "title": document.title,
+            "tags": [t.name for t in document.tags.all()],
+            "correspondent": document.correspondent.name
+            if document.correspondent
+            else None,
+            "document_type": document.document_type.name
+            if document.document_type
+            else None,
+            "created": document.created.isoformat() if document.created else None,
+            "added": document.added.isoformat() if document.added else None,
+        }
+
+        doc = LlamaDocument(text=text, metadata=metadata)
+        doc_nodes = parser.get_nodes_from_documents([doc])
+        nodes.extend(doc_nodes)
+
+    index = VectorStoreIndex(
+        nodes=nodes,
         storage_context=storage_context,
+        embed_model=embed_model,
     )
-    settings.LLM_INDEX_DIR.mkdir(exist_ok=True)
+
     index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
index d5ec88323adb23e4358b68307e6b349dfb535605..f52548b6258e89eebd53a1472204d2572107408e 100644 (file)
@@ -62,7 +62,7 @@ def build_prompt_with_rag(document: Document) -> str:
     Only output valid JSON in the format below. No additional explanations.
 
     The JSON object must contain:
-    - title: A short, descriptive title
+    - title: A short, descriptive title based on the content
     - tags: A list of relevant topics
     - correspondents: People or organizations involved
     - document_types: Type or category of the document
@@ -112,6 +112,6 @@ def get_ai_document_classification(document: Document) -> dict:
         client = AIClient()
         result = client.run_llm_query(prompt)
         return parse_ai_response(result)
-    except Exception:
+    except Exception as e:
         logger.exception("Failed AI classification")
-        return {}
+        raise e
index 03012844f8eff213ecee04af4b432b0313a3c458..d37468b4e3e3421d7a297f9fefbd2264e13cd6ad 100644 (file)
@@ -37,15 +37,15 @@ class AIClient:
         url = self.settings.llm_url or "http://localhost:11434"
         with httpx.Client(timeout=30.0) as client:
             response = client.post(
-                f"{url}/api/chat",
+                f"{url}/api/generate",
                 json={
                     "model": self.settings.llm_model,
-                    "messages": [{"role": "user", "content": prompt}],
+                    "prompt": prompt,
                     "stream": False,
                 },
             )
             response.raise_for_status()
-            return response.json()["message"]["content"]
+            return response.json()["response"]
 
     def _run_openai_query(self, prompt: str) -> str:
         if not self.settings.llm_api_key: