]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Add "all" property to results 3329/head
authorshamoon <4887959+shamoon@users.noreply.github.com>
Sat, 6 May 2023 16:54:45 +0000 (09:54 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Sat, 6 May 2023 18:31:47 +0000 (11:31 -0700)
src-ui/src/app/data/results.ts
src-ui/src/app/services/document-list-view.service.ts
src/documents/tests/test_api.py
src/paperless/views.py

index dbf99c5a1a117c0ace479fe36d7f7f200b2485e8..d29a55567cf58800f9510e8b7cc8466927aeabff 100644 (file)
@@ -2,4 +2,6 @@ export interface Results<T> {
   count: number
 
   results: T[]
+
+  all: number[]
 }
index 087cf5473b745c7027944c1a2702bcb54fa0f962..6245c6632f951bea5e53aa726076154be9c7fc98 100644 (file)
@@ -230,21 +230,14 @@ export class DocumentListViewService {
           activeListViewState.documents = result.results
 
           this.documentService
-            .listAllFilteredIds(activeListViewState.filterRules)
+            .getSelectionData(result.all)
             .pipe(first())
             .subscribe({
-              next: (ids: number[]) => {
-                this.documentService
-                  .getSelectionData(ids)
-                  .pipe(first())
-                  .subscribe({
-                    next: (selectionData) => {
-                      this.selectionData = selectionData
-                    },
-                    error: () => {
-                      this.selectionData = null
-                    },
-                  })
+              next: (selectionData) => {
+                this.selectionData = selectionData
+              },
+              error: () => {
+                this.selectionData = null
               },
             })
 
index e8c6dee7c4b4248cbaf1c8416f4fa303a3736c52..82f6f7bde4fdc2cf5bb04bb68dd795e93171c8ec 100644 (file)
@@ -499,21 +499,25 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
         results = response.data["results"]
         self.assertEqual(response.data["count"], 3)
         self.assertEqual(len(results), 3)
+        self.assertCountEqual(response.data["all"], [d1.id, d2.id, d3.id])
 
         response = self.client.get("/api/documents/?query=september")
         results = response.data["results"]
         self.assertEqual(response.data["count"], 1)
         self.assertEqual(len(results), 1)
+        self.assertCountEqual(response.data["all"], [d3.id])
 
         response = self.client.get("/api/documents/?query=statement")
         results = response.data["results"]
         self.assertEqual(response.data["count"], 2)
         self.assertEqual(len(results), 2)
+        self.assertCountEqual(response.data["all"], [d2.id, d3.id])
 
         response = self.client.get("/api/documents/?query=sfegdfg")
         results = response.data["results"]
         self.assertEqual(response.data["count"], 0)
         self.assertEqual(len(results), 0)
+        self.assertCountEqual(response.data["all"], [])
 
     def test_search_multi_page(self):
         with AsyncWriter(index.open_index()) as writer:
@@ -1230,6 +1234,31 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
             [d1.id, d3.id, d2.id],
         )
 
+    def test_pagination_all(self):
+        """
+        GIVEN:
+            - A set of 50 documents
+        WHEN:
+            - API reuqest for document filtering
+        THEN:
+            - Results are paginated (25 items) and response["all"] returns all ids (50 items)
+        """
+        t = Tag.objects.create(name="tag")
+        docs = []
+        for i in range(50):
+            d = Document.objects.create(checksum=i, content=f"test{i}")
+            d.tags.add(t)
+            docs.append(d)
+
+        response = self.client.get(
+            f"/api/documents/?tags__id__in={t.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 25)
+        self.assertEqual(len(response.data["all"]), 50)
+        self.assertCountEqual(response.data["all"], [d.id for d in docs])
+
     def test_statistics(self):
         doc1 = Document.objects.create(
             title="none1",
index 588b534e3233d7f519ce756b6201ecd174346b9c..777641f753b7388e165744bdbc422a26f86951ca 100644 (file)
@@ -1,4 +1,5 @@
 import os
+from collections import OrderedDict
 
 from django.contrib.auth.models import Group
 from django.contrib.auth.models import User
@@ -9,6 +10,7 @@ from django_filters.rest_framework import DjangoFilterBackend
 from rest_framework.filters import OrderingFilter
 from rest_framework.pagination import PageNumberPagination
 from rest_framework.permissions import IsAuthenticated
+from rest_framework.response import Response
 from rest_framework.viewsets import ModelViewSet
 
 from documents.permissions import PaperlessObjectPermissions
@@ -23,6 +25,47 @@ class StandardPagination(PageNumberPagination):
     page_size_query_param = "page_size"
     max_page_size = 100000
 
+    def get_paginated_response(self, data):
+        return Response(
+            OrderedDict(
+                [
+                    ("count", self.page.paginator.count),
+                    ("next", self.get_next_link()),
+                    ("previous", self.get_previous_link()),
+                    ("all", self.get_all_result_ids()),
+                    ("results", data),
+                ],
+            ),
+        )
+
+    def get_all_result_ids(self):
+        ids = []
+        if hasattr(self.page.paginator.object_list, "saved_results"):
+            results_page = self.page.paginator.object_list.saved_results[0]
+            if results_page is not None:
+                for i in range(0, len(results_page.results.docs())):
+                    try:
+                        fields = results_page.results.fields(i)
+                        if "id" in fields:
+                            ids.append(fields["id"])
+                    except Exception:
+                        pass
+        else:
+            for obj in self.page.paginator.object_list:
+                if hasattr(obj, "id"):
+                    ids.append(obj.id)
+                elif hasattr(obj, "fields"):
+                    ids.append(obj.fields()["id"])
+        return ids
+
+    def get_paginated_response_schema(self, schema):
+        response_schema = super().get_paginated_response_schema(schema)
+        response_schema["properties"]["all"] = {
+            "type": "array",
+            "example": "[1, 2, 3]",
+        }
+        return response_schema
+
 
 class FaviconView(View):
     def get(self, request, *args, **kwargs):  # pragma: nocover