]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
API support for id args for documents & objects (#4519)
authorshamoon <4887959+shamoon@users.noreply.github.com>
Mon, 6 Nov 2023 20:31:10 +0000 (12:31 -0800)
committerGitHub <noreply@github.com>
Mon, 6 Nov 2023 20:31:10 +0000 (12:31 -0800)
src/documents/filters.py
src/documents/tests/test_api.py

index 21d120047273e1f3b0279617aca66ae0cbfa6d30..c6abff4de201e0ed21227622ff2a2e7bb5588bfa 100644 (file)
@@ -21,19 +21,38 @@ DATE_KWARGS = ["year", "month", "day", "date__gt", "gt", "date__lt", "lt"]
 class CorrespondentFilterSet(FilterSet):
     class Meta:
         model = Correspondent
-        fields = {"name": CHAR_KWARGS}
+        fields = {
+            "id": ID_KWARGS,
+            "name": CHAR_KWARGS,
+        }
 
 
 class TagFilterSet(FilterSet):
     class Meta:
         model = Tag
-        fields = {"name": CHAR_KWARGS}
+        fields = {
+            "id": ID_KWARGS,
+            "name": CHAR_KWARGS,
+        }
 
 
 class DocumentTypeFilterSet(FilterSet):
     class Meta:
         model = DocumentType
-        fields = {"name": CHAR_KWARGS}
+        fields = {
+            "id": ID_KWARGS,
+            "name": CHAR_KWARGS,
+        }
+
+
+class StoragePathFilterSet(FilterSet):
+    class Meta:
+        model = StoragePath
+        fields = {
+            "id": ID_KWARGS,
+            "name": CHAR_KWARGS,
+            "path": CHAR_KWARGS,
+        }
 
 
 class ObjectFilter(Filter):
@@ -128,6 +147,7 @@ class DocumentFilterSet(FilterSet):
     class Meta:
         model = Document
         fields = {
+            "id": ID_KWARGS,
             "title": CHAR_KWARGS,
             "content": CHAR_KWARGS,
             "archive_serial_number": INT_KWARGS,
@@ -159,15 +179,6 @@ class LogFilterSet(FilterSet):
         fields = {"level": INT_KWARGS, "created": DATE_KWARGS, "group": ID_KWARGS}
 
 
-class StoragePathFilterSet(FilterSet):
-    class Meta:
-        model = StoragePath
-        fields = {
-            "name": CHAR_KWARGS,
-            "path": CHAR_KWARGS,
-        }
-
-
 class ShareLinkFilterSet(FilterSet):
     class Meta:
         model = ShareLink
index d9359ef3cd3b844b5256618758183a2da55c6b51..5dda30a8ec72f11685d7e3a8b3b5787147f1bed5 100644 (file)
@@ -450,6 +450,27 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
         results = response.data["results"]
         self.assertEqual(len(results), 0)
 
+        response = self.client.get(
+            f"/api/documents/?id__in={doc1.id},{doc2.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 2)
+
+        response = self.client.get(
+            f"/api/documents/?id__range={doc1.id},{doc3.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 3)
+
+        response = self.client.get(
+            f"/api/documents/?id={doc2.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 1)
+
         # custom field name
         response = self.client.get(
             f"/api/documents/?custom_fields__icontains={cf1.name}",
@@ -4655,6 +4676,82 @@ class TestApiRemoteVersion(DirectoriesMixin, APITestCase):
         )
 
 
+class TestApiObjects(DirectoriesMixin, APITestCase):
+    def setUp(self) -> None:
+        super().setUp()
+
+        user = User.objects.create_superuser(username="temp_admin")
+        self.client.force_authenticate(user=user)
+
+        self.tag1 = Tag.objects.create(name="t1", is_inbox_tag=True)
+        self.tag2 = Tag.objects.create(name="t2")
+        self.tag3 = Tag.objects.create(name="t3")
+        self.c1 = Correspondent.objects.create(name="c1")
+        self.c2 = Correspondent.objects.create(name="c2")
+        self.c3 = Correspondent.objects.create(name="c3")
+        self.dt1 = DocumentType.objects.create(name="dt1")
+        self.dt2 = DocumentType.objects.create(name="dt2")
+        self.sp1 = StoragePath.objects.create(name="sp1", path="Something/{title}")
+        self.sp2 = StoragePath.objects.create(name="sp2", path="Something2/{title}")
+
+    def test_object_filters(self):
+        response = self.client.get(
+            f"/api/tags/?id={self.tag2.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 1)
+
+        response = self.client.get(
+            f"/api/tags/?id__in={self.tag1.id},{self.tag3.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 2)
+
+        response = self.client.get(
+            f"/api/correspondents/?id={self.c2.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 1)
+
+        response = self.client.get(
+            f"/api/correspondents/?id__in={self.c1.id},{self.c3.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 2)
+
+        response = self.client.get(
+            f"/api/document_types/?id={self.dt1.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 1)
+
+        response = self.client.get(
+            f"/api/document_types/?id__in={self.dt1.id},{self.dt2.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 2)
+
+        response = self.client.get(
+            f"/api/storage_paths/?id={self.sp1.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 1)
+
+        response = self.client.get(
+            f"/api/storage_paths/?id__in={self.sp1.id},{self.sp2.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 2)
+
+
 class TestApiStoragePaths(DirectoriesMixin, APITestCase):
     ENDPOINT = "/api/storage_paths/"