]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Enhancement: add basic filters for listing custom fields (#5257)
authorshamoon <4887959+shamoon@users.noreply.github.com>
Sat, 6 Jan 2024 03:04:31 +0000 (19:04 -0800)
committerGitHub <noreply@github.com>
Sat, 6 Jan 2024 03:04:31 +0000 (03:04 +0000)
src/documents/filters.py
src/documents/tests/test_api_custom_fields.py
src/documents/views.py

index c63484ee28d2956e9c6222c01bf10624297c5095..bab20a4dc0fa53db52ff9dfd397c6248a697b719 100644 (file)
@@ -12,6 +12,7 @@ from guardian.utils import get_user_obj_perms_model
 from rest_framework_guardian.filters import ObjectPermissionsFilter
 
 from documents.models import Correspondent
+from documents.models import CustomField
 from documents.models import Document
 from documents.models import DocumentType
 from documents.models import Log
@@ -141,6 +142,15 @@ class SharedByUser(Filter):
         )
 
 
+class CustomFieldFilterSet(FilterSet):
+    class Meta:
+        model = CustomField
+        fields = {
+            "id": ID_KWARGS,
+            "name": CHAR_KWARGS,
+        }
+
+
 class CustomFieldsFilter(Filter):
     def filter(self, qs, value):
         if value:
index af16d12b1e78d983bae25924b938040adc46cf7a..cf33e2800c818dd0f6e4a463a4e9bdabc649ccce 100644 (file)
@@ -662,3 +662,39 @@ class TestCustomField(DirectoriesMixin, APITestCase):
 
         self.assertEqual(resp.status_code, status.HTTP_200_OK)
         self.assertEqual(doc5.custom_fields.first().value, [1])
+
+    def test_custom_field_filters(self):
+        custom_field_string = CustomField.objects.create(
+            name="Test Custom Field String",
+            data_type=CustomField.FieldDataType.STRING,
+        )
+        custom_field_date = CustomField.objects.create(
+            name="Test Custom Field Date",
+            data_type=CustomField.FieldDataType.DATE,
+        )
+        custom_field_int = CustomField.objects.create(
+            name="Test Custom Field Int",
+            data_type=CustomField.FieldDataType.INT,
+        )
+
+        response = self.client.get(
+            f"{self.ENDPOINT}?id={custom_field_string.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 1)
+
+        response = self.client.get(
+            f"{self.ENDPOINT}?id__in={custom_field_string.id},{custom_field_date.id}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 2)
+
+        response = self.client.get(
+            f"{self.ENDPOINT}?name__icontains=Int",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        results = response.data["results"]
+        self.assertEqual(len(results), 1)
+        self.assertEqual(results[0]["name"], custom_field_int.name)
index 83f2fc321dcccfc6eb68d894203993900f463e2d..d6b90cbfd9a77768a46e4d12b32e372b512b39cd 100644 (file)
@@ -66,6 +66,7 @@ from documents.data_models import ConsumableDocument
 from documents.data_models import DocumentMetadataOverrides
 from documents.data_models import DocumentSource
 from documents.filters import CorrespondentFilterSet
+from documents.filters import CustomFieldFilterSet
 from documents.filters import DocumentFilterSet
 from documents.filters import DocumentTypeFilterSet
 from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
@@ -1438,6 +1439,11 @@ class CustomFieldViewSet(ModelViewSet):
 
     serializer_class = CustomFieldSerializer
     pagination_class = StandardPagination
+    filter_backends = (
+        DjangoFilterBackend,
+        OrderingFilter,
+    )
+    filterset_class = CustomFieldFilterSet
 
     model = CustomField