]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Better encapsulate backends, use llama_index OpenAI
authorshamoon <4887959+shamoon@users.noreply.github.com>
Fri, 25 Apr 2025 06:20:27 +0000 (23:20 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Wed, 2 Jul 2025 18:01:54 +0000 (11:01 -0700)
src/paperless/ai/ai_classifier.py
src/paperless/ai/client.py
src/paperless/ai/llms.py [new file with mode: 0644]

index 704b894a4bc937ba385a83bd68295b135f14a324..69274da56beed2518ad33b425a0158ffae8efbcb 100644 (file)
@@ -1,6 +1,8 @@
 import json
 import logging
 
+from llama_index.core.base.llms.types import CompletionResponse
+
 from documents.models import Document
 from paperless.ai.client import AIClient
 from paperless.ai.rag import get_context_for_document
@@ -28,6 +30,8 @@ def build_prompt_without_rag(document: Document) -> str:
     - storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
     - dates: List up to 3 relevant dates in YYYY-MM-DD format
 
+    Respond ONLY in JSON.
+    Each field must be a list of plain strings.
     The format of the JSON object is as follows:
     {{
         "title": "xxxxx",
@@ -69,6 +73,18 @@ def build_prompt_with_rag(document: Document) -> str:
     - storage_paths: Suggested folder paths
     - dates: Up to 3 relevant dates in YYYY-MM-DD
 
+    Respond ONLY in JSON.
+    Each field must be a list of plain strings.
+    The format of the JSON object is as follows:
+    {{
+        "title": "xxxxx",
+        "tags": ["xxxx", "xxxx"],
+        "correspondents": ["xxxx", "xxxx"],
+        "document_types": ["xxxx", "xxxx"],
+        "storage_paths": ["xxxx", "xxxx"],
+        "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
+    }}
+
     Here is the document:
     FILENAME:
     {filename}
@@ -83,9 +99,9 @@ def build_prompt_with_rag(document: Document) -> str:
     return prompt
 
 
-def parse_ai_response(text: str) -> dict:
+def parse_ai_response(response: CompletionResponse) -> dict:
     try:
-        raw = json.loads(text)
+        raw = json.loads(response.text)
         return {
             "title": raw.get("title"),
             "tags": raw.get("tags", []),
@@ -95,7 +111,7 @@ def parse_ai_response(text: str) -> dict:
             "dates": raw.get("dates", []),
         }
     except json.JSONDecodeError:
-        logger.exception("Invalid JSON in RAG response")
+        logger.exception("Invalid JSON in AI response")
         return {}
 
 
index 514605e91ddfd77daf9e1c3cea286c1820fd0ad8..cf3b0b0eb016200f3293d952fe3dd64ceec46349 100644 (file)
@@ -1,7 +1,9 @@
 import logging
 
-import httpx
+from llama_index.core.llms import ChatMessage
+from llama_index.llms.openai import OpenAI
 
+from paperless.ai.llms import OllamaLLM
 from paperless.config import AIConfig
 
 logger = logging.getLogger("paperless.ai.client")
@@ -12,8 +14,23 @@ class AIClient:
     A client for interacting with an LLM backend.
     """
 
+    def get_llm(self):
+        if self.settings.llm_backend == "ollama":
+            return OllamaLLM(
+                model=self.settings.llm_model or "llama3",
+                base_url=self.settings.llm_url or "http://localhost:11434",
+            )
+        elif self.settings.llm_backend == "openai":
+            return OpenAI(
+                model=self.settings.llm_model or "gpt-3.5-turbo",
+                api_key=self.settings.openai_api_key,
+            )
+        else:
+            raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")
+
     def __init__(self):
         self.settings = AIConfig()
+        self.llm = self.get_llm()
 
     def run_llm_query(self, prompt: str) -> str:
         logger.debug(
@@ -21,50 +38,16 @@ class AIClient:
             self.settings.llm_backend,
             self.settings.llm_model,
         )
-        match self.settings.llm_backend:
-            case "openai":
-                result = self._run_openai_query(prompt)
-            case "ollama":
-                result = self._run_ollama_query(prompt)
-            case _:
-                raise ValueError(
-                    f"Unsupported LLM backend: {self.settings.llm_backend}",
-                )
+        result = self.llm.complete(prompt)
         logger.debug("LLM query result: %s", result)
         return result
 
-    def _run_ollama_query(self, prompt: str) -> str:
-        url = self.settings.llm_url or "http://localhost:11434"
-        with httpx.Client(timeout=60.0) as client:
-            response = client.post(
-                f"{url}/api/generate",
-                json={
-                    "model": self.settings.llm_model,
-                    "prompt": prompt,
-                    "stream": False,
-                },
-            )
-            response.raise_for_status()
-            return response.json()["response"]
-
-    def _run_openai_query(self, prompt: str) -> str:
-        if not self.settings.llm_api_key:
-            raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")
-
-        url = self.settings.llm_url or "https://api.openai.com"
-
-        with httpx.Client(timeout=30.0) as client:
-            response = client.post(
-                f"{url}/v1/chat/completions",
-                headers={
-                    "Authorization": f"Bearer {self.settings.llm_api_key}",
-                    "Content-Type": "application/json",
-                },
-                json={
-                    "model": self.settings.llm_model,
-                    "messages": [{"role": "user", "content": prompt}],
-                    "temperature": 0.3,
-                },
-            )
-            response.raise_for_status()
-            return response.json()["choices"][0]["message"]["content"]
+    def run_chat(self, messages: list[ChatMessage]) -> str:
+        logger.debug(
+            "Running chat query against %s with model %s",
+            self.settings.llm_backend,
+            self.settings.llm_model,
+        )
+        result = self.llm.chat(messages)
+        logger.debug("Chat result: %s", result)
+        return result
diff --git a/src/paperless/ai/llms.py b/src/paperless/ai/llms.py
new file mode 100644 (file)
index 0000000..b51045d
--- /dev/null
@@ -0,0 +1,64 @@
+import httpx
+from llama_index.core.base.llms.types import ChatMessage
+from llama_index.core.base.llms.types import ChatResponse
+from llama_index.core.base.llms.types import ChatResponseGen
+from llama_index.core.base.llms.types import CompletionResponse
+from llama_index.core.base.llms.types import CompletionResponseGen
+from llama_index.core.base.llms.types import LLMMetadata
+from llama_index.core.llms.llm import LLM
+from pydantic import Field
+
+
+class OllamaLLM(LLM):
+    model: str = Field(default="llama3")
+    base_url: str = Field(default="http://localhost:11434")
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        return LLMMetadata(
+            model_name=self.model,
+            is_chat_model=False,
+            context_window=4096,
+            num_output=512,
+            is_function_calling_model=False,
+        )
+
+    def complete(self, prompt: str, **kwargs) -> CompletionResponse:
+        with httpx.Client(timeout=120.0) as client:
+            response = client.post(
+                f"{self.base_url}/api/generate",
+                json={
+                    "model": self.model,
+                    "prompt": prompt,
+                    "stream": False,
+                },
+            )
+            response.raise_for_status()
+            data = response.json()
+            return CompletionResponse(text=data["response"])
+
+    # -- Required stubs for ABC:
+    def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+        raise NotImplementedError("stream_complete not supported")
+
+    def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        raise NotImplementedError("chat not supported")
+
+    def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
+        raise NotImplementedError("stream_chat not supported")
+
+    async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
+        raise NotImplementedError("async chat not supported")
+
+    async def astream_chat(
+        self,
+        messages: list[ChatMessage],
+        **kwargs,
+    ) -> ChatResponseGen:
+        raise NotImplementedError("async stream_chat not supported")
+
+    async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse:
+        raise NotImplementedError("async complete not supported")
+
+    async def astream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
+        raise NotImplementedError("async stream_complete not supported")