]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Basic start
authorshamoon <4887959+shamoon@users.noreply.github.com>
Sun, 20 Apr 2025 02:51:30 +0000 (19:51 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Wed, 2 Jul 2025 18:00:54 +0000 (11:00 -0700)
src/documents/ai/__init__.py [new file with mode: 0644]
src/documents/ai/client.py [new file with mode: 0644]
src/documents/ai/llm_classifier.py [new file with mode: 0644]
src/documents/ai/matching.py [new file with mode: 0644]
src/documents/caching.py
src/documents/views.py
src/paperless/settings.py

diff --git a/src/documents/ai/__init__.py b/src/documents/ai/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/src/documents/ai/client.py b/src/documents/ai/client.py
new file mode 100644 (file)
index 0000000..588b45b
--- /dev/null
@@ -0,0 +1,43 @@
+import httpx
+from django.conf import settings
+
+
+def run_llm_query(prompt: str) -> str:
+    if settings.LLM_BACKEND == "ollama":
+        return _run_ollama_query(prompt)
+    return _run_openai_query(prompt)
+
+
+def _run_ollama_query(prompt: str) -> str:
+    with httpx.Client(timeout=30.0) as client:
+        response = client.post(
+            f"{settings.OLLAMA_URL}/api/chat",
+            json={
+                "model": settings.LLM_MODEL,
+                "messages": [{"role": "user", "content": prompt}],
+                "stream": False,
+            },
+        )
+        response.raise_for_status()
+        return response.json()["message"]["content"]
+
+
+def _run_openai_query(prompt: str) -> str:
+    if not settings.LLM_API_KEY:
+        raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")
+
+    with httpx.Client(timeout=30.0) as client:
+        response = client.post(
+            f"{settings.OPENAI_URL}/v1/chat/completions",
+            headers={
+                "Authorization": f"Bearer {settings.LLM_API_KEY}",
+                "Content-Type": "application/json",
+            },
+            json={
+                "model": settings.LLM_MODEL,
+                "messages": [{"role": "user", "content": prompt}],
+                "temperature": 0.3,
+            },
+        )
+        response.raise_for_status()
+        return response.json()["choices"][0]["message"]["content"]
diff --git a/src/documents/ai/llm_classifier.py b/src/documents/ai/llm_classifier.py
new file mode 100644 (file)
index 0000000..a6809c1
--- /dev/null
@@ -0,0 +1,64 @@
+import json
+import logging
+
+from documents.ai.client import run_llm_query
+from documents.models import Document
+
+logger = logging.getLogger("paperless.ai.llm_classifier")
+
+
+def get_ai_document_classification(document: Document) -> dict:
+    """
+    Returns classification suggestions for a given document using an LLM.
+    Output schema matches the API's expected DocumentClassificationSuggestions format.
+    """
+    filename = document.filename or ""
+    content = document.content or ""
+
+    prompt = f"""
+    You are a document classification assistant. Based on the content below, return a JSON object suggesting the following classification fields:
+    - title: A descriptive title for the document
+    - tags: A list of tags that describe the document (e.g. ["medical", "insurance"])
+    - correspondent: Who sent or issued this document (e.g. "Kaiser Permanente")
+    - document_types: The type or category (e.g. "invoice", "medical record", "statement")
+    - storage_paths: Suggested storage folders (e.g. "Insurance/2024")
+    - dates: Up to 3 dates in ISO format (YYYY-MM-DD) found in the document, relevant to its content
+
+    Return only a valid JSON object. Do not add commentary.
+
+    FILENAME: {filename}
+
+    CONTENT:
+    {content}
+    """
+
+    try:
+        result = run_llm_query(prompt)
+        suggestions = parse_llm_classification_response(result)
+        return suggestions
+    except Exception as e:
+        logger.error(f"Error during LLM classification: {e}")
+        return None
+
+
+def parse_llm_classification_response(text: str) -> dict:
+    """
+    Parses LLM output and ensures it conforms to expected schema.
+    """
+    try:
+        raw = json.loads(text)
+        return {
+            "title": raw.get("title"),
+            "tags": raw.get("tags", []),
+            "correspondents": [raw["correspondents"]]
+            if isinstance(raw.get("correspondents"), str)
+            else raw.get("correspondents", []),
+            "document_types": [raw["document_types"]]
+            if isinstance(raw.get("document_types"), str)
+            else raw.get("document_types", []),
+            "storage_paths": raw.get("storage_paths", []),
+            "dates": [d for d in raw.get("dates", []) if d],
+        }
+    except json.JSONDecodeError:
+        # fallback: try to extract JSON manually?
+        return {}
diff --git a/src/documents/ai/matching.py b/src/documents/ai/matching.py
new file mode 100644 (file)
index 0000000..900fb8a
--- /dev/null
@@ -0,0 +1,82 @@
+import difflib
+import logging
+import re
+
+from documents.models import Correspondent
+from documents.models import DocumentType
+from documents.models import StoragePath
+from documents.models import Tag
+
+MATCH_THRESHOLD = 0.7
+
+logger = logging.getLogger("paperless.ai.matching")
+
+
+def match_tags_by_name(names: list[str], user) -> list[Tag]:
+    queryset = (
+        Tag.objects.filter(owner=user) if user.is_authenticated else Tag.objects.all()
+    )
+    return _match_names_to_queryset(names, queryset, "name")
+
+
+def match_correspondents_by_name(names: list[str], user) -> list[Correspondent]:
+    queryset = (
+        Correspondent.objects.filter(owner=user)
+        if user.is_authenticated
+        else Correspondent.objects.all()
+    )
+    return _match_names_to_queryset(names, queryset, "name")
+
+
+def match_document_types_by_name(names: list[str]) -> list[DocumentType]:
+    return _match_names_to_queryset(names, DocumentType.objects.all(), "name")
+
+
+def match_storage_paths_by_name(names: list[str], user) -> list[StoragePath]:
+    queryset = (
+        StoragePath.objects.filter(owner=user)
+        if user.is_authenticated
+        else StoragePath.objects.all()
+    )
+    return _match_names_to_queryset(names, queryset, "name")
+
+
+def _normalize(s: str) -> str:
+    s = s.lower()
+    s = re.sub(r"[^\w\s]", "", s)  # remove punctuation
+    s = s.strip()
+    return s
+
+
+def _match_names_to_queryset(names: list[str], queryset, attr: str):
+    results = []
+    objects = list(queryset)
+    object_names = [getattr(obj, attr) for obj in objects]
+    norm_names = [_normalize(name) for name in object_names]
+
+    for name in names:
+        if not name:
+            continue
+        target = _normalize(name)
+
+        # First try exact match
+        if target in norm_names:
+            index = norm_names.index(target)
+            results.append(objects[index])
+            continue
+
+        # Fuzzy match fallback
+        matches = difflib.get_close_matches(
+            target,
+            norm_names,
+            n=1,
+            cutoff=MATCH_THRESHOLD,
+        )
+        if matches:
+            index = norm_names.index(matches[0])
+            results.append(objects[index])
+        else:
+            # Optional: log or store unmatched name
+            logging.debug(f"No match for: '{name}' in {attr} list")
+
+    return results
index 1099a7a73d976176d58e54a85b5f16c38cae86e6..bde21fd9267e3e04593494341362a9e2764ab81d 100644 (file)
@@ -115,6 +115,43 @@ def refresh_suggestions_cache(
     cache.touch(doc_key, timeout)
 
 
+def get_llm_suggestion_cache(
+    document_id: int,
+    backend: str,
+) -> SuggestionCacheData | None:
+    doc_key = get_suggestion_cache_key(document_id)
+    data: SuggestionCacheData = cache.get(doc_key)
+
+    if data and data.classifier_version == 1000 and data.classifier_hash == backend:
+        return data
+
+    return None
+
+
+def set_llm_suggestions_cache(
+    document_id: int,
+    suggestions: dict,
+    *,
+    backend: str,
+    timeout: int = CACHE_50_MINUTES,
+) -> None:
+    """
+    Cache LLM-generated suggestions using a backend-specific identifier (e.g. 'openai:gpt-4').
+    """
+    from documents.caching import SuggestionCacheData
+
+    doc_key = get_suggestion_cache_key(document_id)
+    cache.set(
+        doc_key,
+        SuggestionCacheData(
+            classifier_version=1000,  # Unique marker for LLM-based suggestion
+            classifier_hash=backend,
+            suggestions=suggestions,
+        ),
+        timeout,
+    )
+
+
 def get_metadata_cache_key(document_id: int) -> str:
     """
     Returns the basic key for a document's metadata
index 74d1ff3eac8676eb2c2af189eb4a0d403978cfaf..ea28d070695d82b459a5b2b99f8d1b5b3511e86b 100644 (file)
@@ -77,13 +77,20 @@ from rest_framework.viewsets import ViewSet
 
 from documents import bulk_edit
 from documents import index
+from documents.ai.llm_classifier import get_ai_document_classification
+from documents.ai.matching import match_correspondents_by_name
+from documents.ai.matching import match_document_types_by_name
+from documents.ai.matching import match_storage_paths_by_name
+from documents.ai.matching import match_tags_by_name
 from documents.bulk_download import ArchiveOnlyStrategy
 from documents.bulk_download import OriginalAndArchiveStrategy
 from documents.bulk_download import OriginalsOnlyStrategy
+from documents.caching import get_llm_suggestion_cache
 from documents.caching import get_metadata_cache
 from documents.caching import get_suggestion_cache
 from documents.caching import refresh_metadata_cache
 from documents.caching import refresh_suggestions_cache
+from documents.caching import set_llm_suggestions_cache
 from documents.caching import set_metadata_cache
 from documents.caching import set_suggestions_cache
 from documents.classifier import load_classifier
@@ -763,37 +770,84 @@ class DocumentViewSet(
         ):
             return HttpResponseForbidden("Insufficient permissions")
 
-        document_suggestions = get_suggestion_cache(doc.pk)
+        if settings.AI_CLASSIFICATION_ENABLED:
+            cached = get_llm_suggestion_cache(doc.pk, backend=settings.LLM_BACKEND)
 
-        if document_suggestions is not None:
-            refresh_suggestions_cache(doc.pk)
-            return Response(document_suggestions.suggestions)
+            if cached:
+                refresh_suggestions_cache(doc.pk)
+                return Response(cached.suggestions)
 
-        classifier = load_classifier()
+            llm_resp = get_ai_document_classification(doc)
+            resp_data = {
+                "title": llm_resp.get("title"),
+                "tags": [
+                    t.id
+                    for t in match_tags_by_name(llm_resp.get("tags", []), request.user)
+                ],
+                "correspondents": [
+                    c.id
+                    for c in match_correspondents_by_name(
+                        llm_resp.get("correspondents", []),
+                        request.user,
+                    )
+                ],
+                "document_types": [
+                    d.id
+                    for d in match_document_types_by_name(
+                        llm_resp.get("document_types", []),
+                    )
+                ],
+                "storage_paths": [
+                    s.id
+                    for s in match_storage_paths_by_name(
+                        llm_resp.get("storage_paths", []),
+                        request.user,
+                    )
+                ],
+                "dates": llm_resp.get("dates", []),
+            }
 
-        dates = []
-        if settings.NUMBER_OF_SUGGESTED_DATES > 0:
-            gen = parse_date_generator(doc.filename, doc.content)
-            dates = sorted(
-                {i for i in itertools.islice(gen, settings.NUMBER_OF_SUGGESTED_DATES)},
-            )
+            set_llm_suggestions_cache(doc.pk, resp_data, backend=settings.LLM_BACKEND)
+        else:
+            document_suggestions = get_suggestion_cache(doc.pk)
 
-        resp_data = {
-            "correspondents": [
-                c.id for c in match_correspondents(doc, classifier, request.user)
-            ],
-            "tags": [t.id for t in match_tags(doc, classifier, request.user)],
-            "document_types": [
-                dt.id for dt in match_document_types(doc, classifier, request.user)
-            ],
-            "storage_paths": [
-                dt.id for dt in match_storage_paths(doc, classifier, request.user)
-            ],
-            "dates": [date.strftime("%Y-%m-%d") for date in dates if date is not None],
-        }
+            if document_suggestions is not None:
+                refresh_suggestions_cache(doc.pk)
+                return Response(document_suggestions.suggestions)
+
+            classifier = load_classifier()
+
+            dates = []
+            if settings.NUMBER_OF_SUGGESTED_DATES > 0:
+                gen = parse_date_generator(doc.filename, doc.content)
+                dates = sorted(
+                    {
+                        i
+                        for i in itertools.islice(
+                            gen,
+                            settings.NUMBER_OF_SUGGESTED_DATES,
+                        )
+                    },
+                )
+
+            resp_data = {
+                "correspondents": [
+                    c.id for c in match_correspondents(doc, classifier, request.user)
+                ],
+                "tags": [t.id for t in match_tags(doc, classifier, request.user)],
+                "document_types": [
+                    dt.id for dt in match_document_types(doc, classifier, request.user)
+                ],
+                "storage_paths": [
+                    dt.id for dt in match_storage_paths(doc, classifier, request.user)
+                ],
+                "dates": [
+                    date.strftime("%Y-%m-%d") for date in dates if date is not None
+                ],
+            }
 
-        # Cache the suggestions and the classifier hash for later
-        set_suggestions_cache(doc.pk, resp_data, classifier)
+            # Cache the suggestions and the classifier hash for later
+            set_suggestions_cache(doc.pk, resp_data, classifier)
 
         return Response(resp_data)
 
index 41146e7170f7565ba00081dc8c4d296e51ffb89c..0778effb1b96f699cd3ea2a3a8c81bd6617e3010 100644 (file)
@@ -1411,3 +1411,13 @@ OUTLOOK_OAUTH_ENABLED = bool(
     and OUTLOOK_OAUTH_CLIENT_ID
     and OUTLOOK_OAUTH_CLIENT_SECRET,
 )
+
+################################################################################
+# AI Settings                                                                  #
+################################################################################
+AI_CLASSIFICATION_ENABLED = __get_boolean("PAPERLESS_AI_CLASSIFICATION_ENABLED", "NO")
+LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "openai")  # or "ollama"
+LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
+LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
+OPENAI_URL = os.getenv("PAPERLESS_OPENAI_URL", "https://api.openai.com")
+OLLAMA_URL = os.getenv("PAPERLESS_OLLAMA_URL", "http://localhost:11434")