]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Object creation with owner
authorMichael Shamoon <4887959+shamoon@users.noreply.github.com>
Tue, 6 Dec 2022 07:41:17 +0000 (23:41 -0800)
committerMichael Shamoon <4887959+shamoon@users.noreply.github.com>
Tue, 6 Dec 2022 07:41:17 +0000 (23:41 -0800)
src/documents/serialisers.py
src/documents/views.py

index db282cacd1e4db2fa8e21fcd819a98fe92cbc237..553669e32e6fe5265ee3a057a53a99e661819d0e 100644 (file)
@@ -74,7 +74,18 @@ class MatchingModelSerializer(serializers.ModelSerializer):
         return match
 
 
-class CorrespondentSerializer(MatchingModelSerializer):
+class OwnedObjectSerializer(serializers.ModelSerializer):
+    def __init__(self, *args, **kwargs):
+        self.user = kwargs.pop("user", None)
+        return super().__init__(*args, **kwargs)
+
+    def create(self, validated_data):
+        if self.user and validated_data["owner"] is None:
+            validated_data["owner"] = self.user
+        return super().create(validated_data)
+
+
+class CorrespondentSerializer(MatchingModelSerializer, OwnedObjectSerializer):
 
     last_correspondence = serializers.DateTimeField(read_only=True)
 
@@ -89,10 +100,11 @@ class CorrespondentSerializer(MatchingModelSerializer):
             "is_insensitive",
             "document_count",
             "last_correspondence",
+            "owner",
         )
 
 
-class DocumentTypeSerializer(MatchingModelSerializer):
+class DocumentTypeSerializer(MatchingModelSerializer, OwnedObjectSerializer):
     class Meta:
         model = DocumentType
         fields = (
@@ -103,6 +115,7 @@ class DocumentTypeSerializer(MatchingModelSerializer):
             "matching_algorithm",
             "is_insensitive",
             "document_count",
+            "owner",
         )
 
 
@@ -153,10 +166,11 @@ class TagSerializerVersion1(MatchingModelSerializer):
             "is_insensitive",
             "is_inbox_tag",
             "document_count",
+            "owner",
         )
 
 
-class TagSerializer(MatchingModelSerializer):
+class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
     def get_text_color(self, obj):
         try:
             h = obj.color.lstrip("#")
@@ -214,7 +228,7 @@ class StoragePathField(serializers.PrimaryKeyRelatedField):
         return StoragePath.objects.all()
 
 
-class DocumentSerializer(DynamicFieldsModelSerializer):
+class DocumentSerializer(DynamicFieldsModelSerializer, OwnedObjectSerializer):
 
     correspondent = CorrespondentField(allow_null=True)
     tags = TagsField(many=True)
@@ -265,6 +279,7 @@ class DocumentSerializer(DynamicFieldsModelSerializer):
             "archive_serial_number",
             "original_file_name",
             "archived_file_name",
+            "owner",
         )
 
 
@@ -274,7 +289,7 @@ class SavedViewFilterRuleSerializer(serializers.ModelSerializer):
         fields = ["rule_type", "value"]
 
 
-class SavedViewSerializer(serializers.ModelSerializer):
+class SavedViewSerializer(OwnedObjectSerializer):
 
     filter_rules = SavedViewFilterRuleSerializer(many=True)
 
@@ -289,6 +304,7 @@ class SavedViewSerializer(serializers.ModelSerializer):
             "sort_field",
             "sort_reverse",
             "filter_rules",
+            "owner",
         ]
 
     def update(self, instance, validated_data):
@@ -562,7 +578,7 @@ class BulkDownloadSerializer(DocumentListSerializer):
         }[compression]
 
 
-class StoragePathSerializer(MatchingModelSerializer):
+class StoragePathSerializer(MatchingModelSerializer, OwnedObjectSerializer):
     class Meta:
         model = StoragePath
         fields = (
@@ -574,6 +590,7 @@ class StoragePathSerializer(MatchingModelSerializer):
             "matching_algorithm",
             "is_insensitive",
             "document_count",
+            "owner",
         )
 
     def validate_path(self, path):
index 2a8881376a2c99f9e2020cf59b356466c4a4032b..8b01f0be1c2644bc10492144e707ccac96ec3c30 100644 (file)
@@ -42,6 +42,7 @@ from rest_framework.exceptions import NotFound
 from rest_framework.filters import OrderingFilter
 from rest_framework.filters import SearchFilter
 from rest_framework.generics import GenericAPIView
+from rest_framework.mixins import CreateModelMixin
 from rest_framework.mixins import DestroyModelMixin
 from rest_framework.mixins import ListModelMixin
 from rest_framework.mixins import RetrieveModelMixin
@@ -137,7 +138,17 @@ class IndexView(TemplateView):
         return context
 
 
-class CorrespondentViewSet(ModelViewSet):
+class PassUserMixin(CreateModelMixin):
+    """
+    Pass a user object to serializer
+    """
+
+    def get_serializer(self, *args, **kwargs):
+        kwargs.setdefault("user", self.request.user)
+        return super().get_serializer(*args, **kwargs)
+
+
+class CorrespondentViewSet(ModelViewSet, PassUserMixin):
     model = Correspondent
 
     queryset = Correspondent.objects.annotate(
@@ -163,7 +174,7 @@ class CorrespondentViewSet(ModelViewSet):
     )
 
 
-class TagViewSet(ModelViewSet):
+class TagViewSet(ModelViewSet, PassUserMixin):
     model = Tag
 
     queryset = Tag.objects.annotate(document_count=Count("documents")).order_by(
@@ -183,7 +194,7 @@ class TagViewSet(ModelViewSet):
     ordering_fields = ("name", "matching_algorithm", "match", "document_count")
 
 
-class DocumentTypeViewSet(ModelViewSet):
+class DocumentTypeViewSet(ModelViewSet, PassUserMixin):
     model = DocumentType
 
     queryset = DocumentType.objects.annotate(
@@ -204,6 +215,7 @@ class DocumentViewSet(
     DestroyModelMixin,
     ListModelMixin,
     GenericViewSet,
+    PassUserMixin,
 ):
     model = Document
     queryset = Document.objects.all()
@@ -551,7 +563,7 @@ class LogViewSet(ViewSet):
         return Response(self.log_files)
 
 
-class SavedViewViewSet(ModelViewSet):
+class SavedViewViewSet(ModelViewSet, PassUserMixin):
     model = SavedView
 
     queryset = SavedView.objects.all()
@@ -824,7 +836,7 @@ class RemoteVersionView(GenericAPIView):
         )
 
 
-class StoragePathViewSet(ModelViewSet):
+class StoragePathViewSet(ModelViewSet, PassUserMixin):
     model = StoragePath
 
     queryset = StoragePath.objects.annotate(document_count=Count("documents")).order_by(