]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Test views, caching
authorshamoon <4887959+shamoon@users.noreply.github.com>
Tue, 22 Apr 2025 06:32:38 +0000 (23:32 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Wed, 2 Jul 2025 18:01:49 +0000 (11:01 -0700)
src/documents/tests/test_views.py
src/documents/views.py

index 4c987e3af361dbdf2c4e7a2cc2ea6aabb544f558..ef5a71e010077db46c44da67fd26925b3aeaeac7 100644 (file)
@@ -1,6 +1,8 @@
 import tempfile
 from datetime import timedelta
 from pathlib import Path
+from unittest.mock import MagicMock
+from unittest.mock import patch
 
 from django.conf import settings
 from django.contrib.auth.models import Permission
@@ -10,8 +12,15 @@ from django.test import override_settings
 from django.utils import timezone
 from rest_framework import status
 
+from documents.caching import get_llm_suggestion_cache
+from documents.caching import set_llm_suggestions_cache
+from documents.models import Correspondent
 from documents.models import Document
+from documents.models import DocumentType
 from documents.models import ShareLink
+from documents.models import StoragePath
+from documents.models import Tag
+from documents.signals.handlers import update_llm_suggestions_cache
 from documents.tests.utils import DirectoriesMixin
 from paperless.models import ApplicationConfiguration
 
@@ -154,3 +163,104 @@ class TestViews(DirectoriesMixin, TestCase):
         response.render()
         self.assertEqual(response.request["PATH_INFO"], "/accounts/login/")
         self.assertContains(response, b"Share link has expired")
+
+
+class TestAISuggestions(DirectoriesMixin, TestCase):
+    def setUp(self):
+        self.user = User.objects.create_superuser(username="testuser")
+        self.document = Document.objects.create(
+            title="Test Document",
+            filename="test.pdf",
+            mime_type="application/pdf",
+        )
+        self.tag1 = Tag.objects.create(name="tag1")
+        self.correspondent1 = Correspondent.objects.create(name="correspondent1")
+        self.document_type1 = DocumentType.objects.create(name="type1")
+        self.path1 = StoragePath.objects.create(name="path1")
+        super().setUp()
+
+    @patch("documents.views.get_llm_suggestion_cache")
+    @patch("documents.views.refresh_suggestions_cache")
+    @override_settings(
+        AI_ENABLED=True,
+        LLM_BACKEND="mock_backend",
+    )
+    def test_suggestions_with_cached_llm(self, mock_refresh_cache, mock_get_cache):
+        mock_get_cache.return_value = MagicMock(suggestions={"tags": ["tag1", "tag2"]})
+
+        self.client.force_login(user=self.user)
+        response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/")
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(response.json(), {"tags": ["tag1", "tag2"]})
+        mock_refresh_cache.assert_called_once_with(self.document.pk)
+
+    @patch("documents.views.get_ai_document_classification")
+    @override_settings(
+        AI_ENABLED=True,
+        LLM_BACKEND="mock_backend",
+    )
+    def test_suggestions_with_ai_enabled(
+        self,
+        mock_get_ai_classification,
+    ):
+        mock_get_ai_classification.return_value = {
+            "title": "AI Title",
+            "tags": ["tag1", "tag2"],
+            "correspondents": ["correspondent1"],
+            "document_types": ["type1"],
+            "storage_paths": ["path1"],
+            "dates": ["2023-01-01"],
+        }
+
+        self.client.force_login(user=self.user)
+        response = self.client.get(f"/api/documents/{self.document.pk}/suggestions/")
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(
+            response.json(),
+            {
+                "title": "AI Title",
+                "tags": [self.tag1.pk],
+                "suggested_tags": ["tag2"],
+                "correspondents": [self.correspondent1.pk],
+                "suggested_correspondents": [],
+                "document_types": [self.document_type1.pk],
+                "suggested_document_types": [],
+                "storage_paths": [self.path1.pk],
+                "suggested_storage_paths": [],
+                "dates": ["2023-01-01"],
+            },
+        )
+
+    def test_invalidate_suggestions_cache(self):
+        self.client.force_login(user=self.user)
+        suggestions = {
+            "title": "AI Title",
+            "tags": ["tag1", "tag2"],
+            "correspondents": ["correspondent1"],
+            "document_types": ["type1"],
+            "storage_paths": ["path1"],
+            "dates": ["2023-01-01"],
+        }
+        set_llm_suggestions_cache(
+            self.document.pk,
+            suggestions,
+            backend="mock_backend",
+        )
+        self.assertEqual(
+            get_llm_suggestion_cache(
+                self.document.pk,
+                backend="mock_backend",
+            ).suggestions,
+            suggestions,
+        )
+        # post_save signal triggered
+        update_llm_suggestions_cache(
+            sender=None,
+            instance=self.document,
+        )
+        self.assertIsNone(
+            get_llm_suggestion_cache(
+                self.document.pk,
+                backend="mock_backend",
+            ),
+        )
index 41ce6bf9d54e8f1c98cb56d86055bc6851b95234..d2bf03f0c2513a42833b6cc084c8d57f1c31e29a 100644 (file)
@@ -772,51 +772,57 @@ class DocumentViewSet(
             return HttpResponseForbidden("Insufficient permissions")
 
         if settings.AI_ENABLED:
-            cached = get_llm_suggestion_cache(doc.pk, backend=settings.LLM_BACKEND)
+            cached_llm_suggestions = get_llm_suggestion_cache(
+                doc.pk,
+                backend=settings.LLM_BACKEND,
+            )
 
-            if cached:
+            if cached_llm_suggestions:
                 refresh_suggestions_cache(doc.pk)
-                return Response(cached.suggestions)
+                return Response(cached_llm_suggestions.suggestions)
 
-            llm_resp = get_ai_document_classification(doc)
+            llm_suggestions = get_ai_document_classification(doc)
 
-            matched_tags = match_tags_by_name(llm_resp.get("tags", []), request.user)
+            matched_tags = match_tags_by_name(
+                llm_suggestions.get("tags", []),
+                request.user,
+            )
             matched_correspondents = match_correspondents_by_name(
-                llm_resp.get("correspondents", []),
+                llm_suggestions.get("correspondents", []),
                 request.user,
             )
             matched_types = match_document_types_by_name(
-                llm_resp.get("document_types", []),
+                llm_suggestions.get("document_types", []),
                 request.user,
             )
             matched_paths = match_storage_paths_by_name(
-                llm_resp.get("storage_paths", []),
+                llm_suggestions.get("storage_paths", []),
                 request.user,
             )
 
             resp_data = {
-                "title": llm_resp.get("title"),
+                "title": llm_suggestions.get("title"),
                 "tags": [t.id for t in matched_tags],
                 "suggested_tags": extract_unmatched_names(
-                    llm_resp.get("tags", []),
+                    llm_suggestions.get("tags", []),
                     matched_tags,
                 ),
                 "correspondents": [c.id for c in matched_correspondents],
                 "suggested_correspondents": extract_unmatched_names(
-                    llm_resp.get("correspondents", []),
+                    llm_suggestions.get("correspondents", []),
                     matched_correspondents,
                 ),
                 "document_types": [d.id for d in matched_types],
                 "suggested_document_types": extract_unmatched_names(
-                    llm_resp.get("document_types", []),
+                    llm_suggestions.get("document_types", []),
                     matched_types,
                 ),
                 "storage_paths": [s.id for s in matched_paths],
                 "suggested_storage_paths": extract_unmatched_names(
-                    llm_resp.get("storage_paths", []),
+                    llm_suggestions.get("storage_paths", []),
                     matched_paths,
                 ),
-                "dates": llm_resp.get("dates", []),
+                "dates": llm_suggestions.get("dates", []),
             }
 
             set_llm_suggestions_cache(doc.pk, resp_data, backend=settings.LLM_BACKEND)