git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Use a frontend config
author    shamoon <4887959+shamoon@users.noreply.github.com>
Thu, 24 Apr 2025 02:24:32 +0000 (19:24 -0700)
committer shamoon <4887959+shamoon@users.noreply.github.com>
Wed, 2 Jul 2025 18:01:51 +0000 (11:01 -0700)
src-ui/src/app/data/paperless-config.ts
src/documents/tests/test_api_app_config.py
src/documents/views.py
src/paperless/ai/ai_classifier.py
src/paperless/ai/client.py
src/paperless/config.py
src/paperless/migrations/0004_applicationconfiguration_ai_enabled_and_more.py [new file with mode: 0644]
src/paperless/models.py
src/paperless/settings.py
src/paperless/tests/test_ai_classifier.py
src/paperless/tests/test_ai_client.py

diff --git a/src-ui/src/app/data/paperless-config.ts b/src-ui/src/app/data/paperless-config.ts
index 3afca66ffae33ffe8f9d7d72e3e2a144ef30f87a..1d8f27b33f0c3757b1a568bc46f2651581a2dbae 100644 (file)
@@ -50,6 +50,7 @@ export const ConfigCategory = {
   General: $localize`General Settings`,
   OCR: $localize`OCR Settings`,
   Barcode: $localize`Barcode Settings`,
+  AI: $localize`AI Settings`,
 }
 
 export interface ConfigOption {
@@ -257,6 +258,41 @@ export const PaperlessConfigOptions: ConfigOption[] = [
     type: ConfigOptionType.JSON,
     config_key: 'PAPERLESS_CONSUMER_TAG_BARCODE_MAPPING',
     category: ConfigCategory.Barcode,
+  },
+  {
+    key: 'ai_enabled',
+    title: $localize`AI Enabled`,
+    type: ConfigOptionType.Boolean,
+    config_key: 'PAPERLESS_AI_ENABLED',
+    category: ConfigCategory.AI,
+  },
+  {
+    key: 'llm_backend',
+    title: $localize`LLM Backend`,
+    type: ConfigOptionType.String,
+    config_key: 'PAPERLESS_LLM_BACKEND',
+    category: ConfigCategory.AI,
+  },
+  {
+    key: 'llm_model',
+    title: $localize`LLM Model`,
+    type: ConfigOptionType.String,
+    config_key: 'PAPERLESS_LLM_MODEL',
+    category: ConfigCategory.AI,
+  },
+  {
+    key: 'llm_api_key',
+    title: $localize`LLM API Key`,
+    type: ConfigOptionType.String,
+    config_key: 'PAPERLESS_LLM_API_KEY',
+    category: ConfigCategory.AI,
+  },
+  {
+    key: 'llm_url',
+    title: $localize`LLM URL`,
+    type: ConfigOptionType.String,
+    config_key: 'PAPERLESS_LLM_URL',
+    category: ConfigCategory.AI,
   },
 ]
 
@@ -287,4 +321,9 @@ export interface PaperlessConfig extends ObjectWithId {
   barcode_max_pages: number
   barcode_enable_tag: boolean
   barcode_tag_mapping: object
+  ai_enabled: boolean
+  llm_backend: string
+  llm_model: string
+  llm_api_key: string
+  llm_url: string
 }
diff --git a/src/documents/tests/test_api_app_config.py b/src/documents/tests/test_api_app_config.py
index 5968b16701f2b63b8286f85532457c1a85862ef2..502a22fcd14b1ca17808eb91c23b174acc9e729d 100644 (file)
@@ -64,6 +64,11 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
                 "barcode_max_pages": None,
                 "barcode_enable_tag": None,
                 "barcode_tag_mapping": None,
+                "ai_enabled": False,
+                "llm_backend": None,
+                "llm_model": None,
+                "llm_api_key": None,
+                "llm_url": None,
             },
         )
 
diff --git a/src/documents/views.py b/src/documents/views.py
index ff721512d2bacec990be2d7dd1bcb782c6e95bca..dad9f560eb5042331aca142a271ecda42b93224a 100644 (file)
@@ -179,6 +179,7 @@ from paperless.ai.matching import match_document_types_by_name
 from paperless.ai.matching import match_storage_paths_by_name
 from paperless.ai.matching import match_tags_by_name
 from paperless.celery import app as celery_app
+from paperless.config import AIConfig
 from paperless.config import GeneralConfig
 from paperless.db import GnuPG
 from paperless.serialisers import GroupSerializer
@@ -771,10 +772,12 @@ class DocumentViewSet(
         ):
             return HttpResponseForbidden("Insufficient permissions")
 
-        if settings.AI_ENABLED:
+        ai_config = AIConfig()
+
+        if ai_config.ai_enabled:
             cached_llm_suggestions = get_llm_suggestion_cache(
                 doc.pk,
-                backend=settings.LLM_BACKEND,
+                backend=ai_config.llm_backend,
             )
 
             if cached_llm_suggestions:
@@ -825,7 +828,7 @@ class DocumentViewSet(
                 "dates": llm_suggestions.get("dates", []),
             }
 
-            set_llm_suggestions_cache(doc.pk, resp_data, backend=settings.LLM_BACKEND)
+            set_llm_suggestions_cache(doc.pk, resp_data, backend=ai_config.llm_backend)
         else:
             document_suggestions = get_suggestion_cache(doc.pk)
 
@@ -2279,7 +2282,10 @@ class UiSettingsView(GenericAPIView):
                 request.session["oauth_state"] = manager.state
 
         ui_settings["email_enabled"] = settings.EMAIL_ENABLED
-        ui_settings["ai_enabled"] = settings.AI_ENABLED
+
+        ai_config = AIConfig()
+
+        ui_settings["ai_enabled"] = ai_config.ai_enabled
 
         user_resp = {
             "id": user.id,
diff --git a/src/paperless/ai/ai_classifier.py b/src/paperless/ai/ai_classifier.py
index 71eae8bacaa731ea494619ef66da4cdc4e52765d..949cfaf696c1488cd42ff6858ff2d97e29ee329d 100644 (file)
@@ -2,7 +2,7 @@ import json
 import logging
 
 from documents.models import Document
-from paperless.ai.client import run_llm_query
+from paperless.ai.client import AIClient
 
 logger = logging.getLogger("paperless.ai.ai_classifier")
 
@@ -49,7 +49,8 @@ def get_ai_document_classification(document: Document) -> dict:
     """
 
     try:
-        result = run_llm_query(prompt)
+        client = AIClient()
+        result = client.run_llm_query(prompt)
         suggestions = parse_ai_classification_response(result)
         return suggestions or {}
     except Exception:
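
The classifier keeps its module-level entry point; only the transport moved into AIClient. A minimal usage sketch (the Document values are illustrative, mirroring the test fixture further down; a configured Django environment is assumed, since AIConfig reads the database):

    # Illustrative only: Document values mirror the test fixture below.
    from documents.models import Document
    from paperless.ai.ai_classifier import get_ai_document_classification

    doc = Document(filename="invoice.pdf", content="ACME invoice, due 2023-01-01")
    # Builds the prompt, queries the configured backend via AIClient, and
    # returns the parsed suggestions dict (falling back to {} when empty).
    suggestions = get_ai_document_classification(doc)
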
diff --git a/src/paperless/ai/client.py b/src/paperless/ai/client.py
index 13bf680bc354f9634105b6979388776767156fcf..03012844f8eff213ecee04af4b432b0313a3c458 100644 (file)
@@ -1,58 +1,70 @@
 import logging
 
 import httpx
-from django.conf import settings
+
+from paperless.config import AIConfig
 
 logger = logging.getLogger("paperless.ai.client")
 
 
-def run_llm_query(prompt: str) -> str:
-    logger.debug(
-        "Running LLM query against %s with model %s",
-        settings.LLM_BACKEND,
-        settings.LLM_MODEL,
-    )
-    match settings.LLM_BACKEND:
-        case "openai":
-            result = _run_openai_query(prompt)
-        case "ollama":
-            result = _run_ollama_query(prompt)
-        case _:
-            raise ValueError(f"Unsupported LLM backend: {settings.LLM_BACKEND}")
-    logger.debug("LLM query result: %s", result)
-    return result
-
-
-def _run_ollama_query(prompt: str) -> str:
-    with httpx.Client(timeout=30.0) as client:
-        response = client.post(
-            f"{settings.OLLAMA_URL}/api/chat",
-            json={
-                "model": settings.LLM_MODEL,
-                "messages": [{"role": "user", "content": prompt}],
-                "stream": False,
-            },
-        )
-        response.raise_for_status()
-        return response.json()["message"]["content"]
-
-
-def _run_openai_query(prompt: str) -> str:
-    if not settings.LLM_API_KEY:
-        raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")
-
-    with httpx.Client(timeout=30.0) as client:
-        response = client.post(
-            f"{settings.OPENAI_URL}/v1/chat/completions",
-            headers={
-                "Authorization": f"Bearer {settings.LLM_API_KEY}",
-                "Content-Type": "application/json",
-            },
-            json={
-                "model": settings.LLM_MODEL,
-                "messages": [{"role": "user", "content": prompt}],
-                "temperature": 0.3,
-            },
+class AIClient:
+    """
+    A client for interacting with an LLM backend.
+    """
+
+    def __init__(self):
+        self.settings = AIConfig()
+
+    def run_llm_query(self, prompt: str) -> str:
+        logger.debug(
+            "Running LLM query against %s with model %s",
+            self.settings.llm_backend,
+            self.settings.llm_model,
         )
-        response.raise_for_status()
-        return response.json()["choices"][0]["message"]["content"]
+        match self.settings.llm_backend:
+            case "openai":
+                result = self._run_openai_query(prompt)
+            case "ollama":
+                result = self._run_ollama_query(prompt)
+            case _:
+                raise ValueError(
+                    f"Unsupported LLM backend: {self.settings.llm_backend}",
+                )
+        logger.debug("LLM query result: %s", result)
+        return result
+
+    def _run_ollama_query(self, prompt: str) -> str:
+        url = self.settings.llm_url or "http://localhost:11434"
+        with httpx.Client(timeout=30.0) as client:
+            response = client.post(
+                f"{url}/api/chat",
+                json={
+                    "model": self.settings.llm_model,
+                    "messages": [{"role": "user", "content": prompt}],
+                    "stream": False,
+                },
+            )
+            response.raise_for_status()
+            return response.json()["message"]["content"]
+
+    def _run_openai_query(self, prompt: str) -> str:
+        if not self.settings.llm_api_key:
+            raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")
+
+        url = self.settings.llm_url or "https://api.openai.com"
+
+        with httpx.Client(timeout=30.0) as client:
+            response = client.post(
+                f"{url}/v1/chat/completions",
+                headers={
+                    "Authorization": f"Bearer {self.settings.llm_api_key}",
+                    "Content-Type": "application/json",
+                },
+                json={
+                    "model": self.settings.llm_model,
+                    "messages": [{"role": "user", "content": prompt}],
+                    "temperature": 0.3,
+                },
+            )
+            response.raise_for_status()
+            return response.json()["choices"][0]["message"]["content"]
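
With the free functions folded into AIClient, callers construct a client and let AIConfig supply the backend, model, key, and URL; because the config is read at construction time, database overrides take effect without a restart. A minimal sketch (the prompt is illustrative):

    from paperless.ai.client import AIClient

    client = AIClient()  # __init__ instantiates AIConfig()
    # run_llm_query dispatches on llm_backend: "openai" or "ollama";
    # any other value raises ValueError.
    answer = client.run_llm_query("Classify this document: ...")
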
diff --git a/src/paperless/config.py b/src/paperless/config.py
index fb3139d7943e8916e999a75164cd65c36da7933e..4a20ea461e21b1fa217edeac117bad72b47e1379 100644 (file)
@@ -169,3 +169,25 @@ class GeneralConfig(BaseConfig):
 
         self.app_title = app_config.app_title or None
         self.app_logo = app_config.app_logo.url if app_config.app_logo else None
+
+
+@dataclasses.dataclass
+class AIConfig(BaseConfig):
+    """
+    AI related settings that require global scope
+    """
+
+    ai_enabled: bool = dataclasses.field(init=False)
+    llm_backend: str = dataclasses.field(init=False)
+    llm_model: str = dataclasses.field(init=False)
+    llm_api_key: str = dataclasses.field(init=False)
+    llm_url: str = dataclasses.field(init=False)
+
+    def __post_init__(self) -> None:
+        app_config = self._get_config_instance()
+
+        self.ai_enabled = app_config.ai_enabled or settings.AI_ENABLED
+        self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND
+        self.llm_model = app_config.llm_model or settings.LLM_MODEL
+        self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
+        self.llm_url = app_config.llm_url or settings.LLM_URL
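
Each AIConfig field prefers the value stored on the ApplicationConfiguration singleton and falls back to the corresponding environment-derived setting. Note that the "or" fallback means a falsy database value (False, empty string) defers to the environment. A short sketch of the resolution order:

    from paperless.config import AIConfig

    config = AIConfig()  # __post_init__ loads the singleton row
    # Each attribute resolves as: database row value, else settings.* default,
    # e.g. config.llm_model == (app_config.llm_model or settings.LLM_MODEL)
    if config.ai_enabled:
        print(config.llm_backend, config.llm_model, config.llm_url)
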
diff --git a/src/paperless/migrations/0004_applicationconfiguration_ai_enabled_and_more.py b/src/paperless/migrations/0004_applicationconfiguration_ai_enabled_and_more.py
new file mode 100644 (file)
index 0000000..55833df
--- /dev/null
+++ b/src/paperless/migrations/0004_applicationconfiguration_ai_enabled_and_more.py
@@ -0,0 +1,63 @@
+# Generated by Django 5.1.7 on 2025-04-24 02:09
+
+from django.db import migrations
+from django.db import models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("paperless", "0003_alter_applicationconfiguration_max_image_pixels"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="applicationconfiguration",
+            name="ai_enabled",
+            field=models.BooleanField(
+                default=False,
+                null=True,
+                verbose_name="Enables AI features",
+            ),
+        ),
+        migrations.AddField(
+            model_name="applicationconfiguration",
+            name="llm_api_key",
+            field=models.CharField(
+                blank=True,
+                max_length=128,
+                null=True,
+                verbose_name="Sets the LLM API key",
+            ),
+        ),
+        migrations.AddField(
+            model_name="applicationconfiguration",
+            name="llm_backend",
+            field=models.CharField(
+                blank=True,
+                choices=[("openai", "OpenAI"), ("ollama", "Ollama")],
+                max_length=32,
+                null=True,
+                verbose_name="Sets the LLM backend",
+            ),
+        ),
+        migrations.AddField(
+            model_name="applicationconfiguration",
+            name="llm_model",
+            field=models.CharField(
+                blank=True,
+                max_length=32,
+                null=True,
+                verbose_name="Sets the LLM model",
+            ),
+        ),
+        migrations.AddField(
+            model_name="applicationconfiguration",
+            name="llm_url",
+            field=models.CharField(
+                blank=True,
+                max_length=128,
+                null=True,
+                verbose_name="Sets the LLM URL, optional",
+            ),
+        ),
+    ]
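
The migration only adds nullable columns to the ApplicationConfiguration singleton, so it should be safe to apply to an existing database via the standard step:

    python manage.py migrate paperless
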
diff --git a/src/paperless/models.py b/src/paperless/models.py
index 1c44f1414f748f143de9b382fd5018acc6a74179..d2a15d9ac180ddfd71226da0ccd91da572596e6a 100644 (file)
@@ -74,6 +74,15 @@ class ColorConvertChoices(models.TextChoices):
     CMYK = ("CMYK", _("CMYK"))
 
 
+class LLMBackend(models.TextChoices):
+    """
+    Matches to --llm-backend
+    """
+
+    OPENAI = ("openai", _("OpenAI"))
+    OLLAMA = ("ollama", _("Ollama"))
+
+
 class ApplicationConfiguration(AbstractSingletonModel):
     """
     Settings which are common across more than 1 parser
@@ -265,6 +274,45 @@ class ApplicationConfiguration(AbstractSingletonModel):
         null=True,
     )
 
+    """
+    AI related settings
+    """
+
+    ai_enabled = models.BooleanField(
+        verbose_name=_("Enables AI features"),
+        null=True,
+        default=False,
+    )
+
+    llm_backend = models.CharField(
+        verbose_name=_("Sets the LLM backend"),
+        null=True,
+        blank=True,
+        max_length=32,
+        choices=LLMBackend.choices,
+    )
+
+    llm_model = models.CharField(
+        verbose_name=_("Sets the LLM model"),
+        null=True,
+        blank=True,
+        max_length=32,
+    )
+
+    llm_api_key = models.CharField(
+        verbose_name=_("Sets the LLM API key"),
+        null=True,
+        blank=True,
+        max_length=128,
+    )
+
+    llm_url = models.CharField(
+        verbose_name=_("Sets the LLM URL, optional"),
+        null=True,
+        blank=True,
+        max_length=128,
+    )
+
     class Meta:
         verbose_name = _("paperless application settings")
 
diff --git a/src/paperless/settings.py b/src/paperless/settings.py
index c1281af432484c8e1b103946553fd1dac0cbaeaf..6c4b58e4bc0103ab56399cb55be20da84c423cc1 100644 (file)
@@ -1419,5 +1419,4 @@ AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO")
 LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "openai")  # or "ollama"
 LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
 LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
-OPENAI_URL = os.getenv("PAPERLESS_OPENAI_URL", "https://api.openai.com")
-OLLAMA_URL = os.getenv("PAPERLESS_OLLAMA_URL", "http://localhost:11434")
+LLM_URL = os.getenv("PAPERLESS_LLM_URL")
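
The two backend-specific URL settings collapse into a single PAPERLESS_LLM_URL; the old defaults now live in AIClient (https://api.openai.com for openai, http://localhost:11434 for ollama) and apply when no URL is configured. An illustrative environment configuration (values are examples only):

    # Example values only; PAPERLESS_LLM_URL is optional and falls back
    # to the per-backend default when unset.
    PAPERLESS_AI_ENABLED=true
    PAPERLESS_LLM_BACKEND=ollama
    PAPERLESS_LLM_MODEL=llama3
    PAPERLESS_LLM_URL=http://localhost:11434
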
diff --git a/src/paperless/tests/test_ai_classifier.py b/src/paperless/tests/test_ai_classifier.py
index 57686fee601b1c81987080e578d1c25288a36332..edb086bbee297f0607c6d56d0cdbdce2c81654e2 100644 (file)
@@ -13,7 +13,8 @@ def mock_document():
     return Document(filename="test.pdf", content="This is a test document content.")
 
 
-@patch("paperless.ai.ai_classifier.run_llm_query")
+@pytest.mark.django_db
+@patch("paperless.ai.client.AIClient.run_llm_query")
 def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
     mock_response = json.dumps(
         {
@@ -37,7 +38,8 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
     assert result["dates"] == ["2023-01-01"]
 
 
-@patch("paperless.ai.ai_classifier.run_llm_query")
+@pytest.mark.django_db
+@patch("paperless.ai.client.AIClient.run_llm_query")
 def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
     mock_run_llm_query.side_effect = Exception("LLM query failed")
 
diff --git a/src/paperless/tests/test_ai_client.py b/src/paperless/tests/test_ai_client.py
index 6a332de27decba73be0e24c4d74beca233675315..6a239279ec00b7e720ba8b43c26c948434303296 100644 (file)
@@ -4,9 +4,7 @@ from unittest.mock import patch
 import pytest
 from django.conf import settings
 
-from paperless.ai.client import _run_ollama_query
-from paperless.ai.client import _run_openai_query
-from paperless.ai.client import run_llm_query
+from paperless.ai.client import AIClient
 
 
 @pytest.fixture
@@ -14,52 +12,59 @@ def mock_settings():
     settings.LLM_BACKEND = "openai"
     settings.LLM_MODEL = "gpt-3.5-turbo"
     settings.LLM_API_KEY = "test-api-key"
-    settings.OPENAI_URL = "https://api.openai.com"
-    settings.OLLAMA_URL = "https://ollama.example.com"
     yield settings
 
 
-@patch("paperless.ai.client._run_openai_query")
-@patch("paperless.ai.client._run_ollama_query")
+@pytest.mark.django_db
+@patch("paperless.ai.client.AIClient._run_openai_query")
+@patch("paperless.ai.client.AIClient._run_ollama_query")
 def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings):
+    mock_settings.LLM_BACKEND = "openai"
     mock_openai_query.return_value = "OpenAI response"
-    result = run_llm_query("Test prompt")
+    client = AIClient()
+    result = client.run_llm_query("Test prompt")
     assert result == "OpenAI response"
     mock_openai_query.assert_called_once_with("Test prompt")
     mock_ollama_query.assert_not_called()
 
 
-@patch("paperless.ai.client._run_openai_query")
-@patch("paperless.ai.client._run_ollama_query")
+@pytest.mark.django_db
+@patch("paperless.ai.client.AIClient._run_openai_query")
+@patch("paperless.ai.client.AIClient._run_ollama_query")
 def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings):
     mock_settings.LLM_BACKEND = "ollama"
     mock_ollama_query.return_value = "Ollama response"
-    result = run_llm_query("Test prompt")
+    client = AIClient()
+    result = client.run_llm_query("Test prompt")
     assert result == "Ollama response"
     mock_ollama_query.assert_called_once_with("Test prompt")
     mock_openai_query.assert_not_called()
 
 
+@pytest.mark.django_db
 def test_run_llm_query_unsupported_backend(mock_settings):
     mock_settings.LLM_BACKEND = "unsupported"
+    client = AIClient()
     with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
-        run_llm_query("Test prompt")
+        client.run_llm_query("Test prompt")
 
 
+@pytest.mark.django_db
 def test_run_openai_query(httpx_mock, mock_settings):
+    mock_settings.LLM_BACKEND = "openai"
     httpx_mock.add_response(
-        url=f"{mock_settings.OPENAI_URL}/v1/chat/completions",
+        url="https://api.openai.com/v1/chat/completions",
         json={
             "choices": [{"message": {"content": "OpenAI response"}}],
         },
     )
 
-    result = _run_openai_query("Test prompt")
+    client = AIClient()
+    result = client.run_llm_query("Test prompt")
     assert result == "OpenAI response"
 
     request = httpx_mock.get_request()
     assert request.method == "POST"
-    assert request.url == f"{mock_settings.OPENAI_URL}/v1/chat/completions"
     assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}"
     assert request.headers["Content-Type"] == "application/json"
     assert json.loads(request.content) == {
@@ -69,18 +74,20 @@ def test_run_openai_query(httpx_mock, mock_settings):
     }
 
 
+@pytest.mark.django_db
 def test_run_ollama_query(httpx_mock, mock_settings):
+    mock_settings.LLM_BACKEND = "ollama"
     httpx_mock.add_response(
-        url=f"{mock_settings.OLLAMA_URL}/api/chat",
+        url="http://localhost:11434/api/chat",
         json={"message": {"content": "Ollama response"}},
     )
 
-    result = _run_ollama_query("Test prompt")
+    client = AIClient()
+    result = client.run_llm_query("Test prompt")
     assert result == "Ollama response"
 
     request = httpx_mock.get_request()
     assert request.method == "POST"
-    assert request.url == f"{mock_settings.OLLAMA_URL}/api/chat"
     assert json.loads(request.content) == {
         "model": mock_settings.LLM_MODEL,
         "messages": [{"role": "user", "content": "Test prompt"}],