]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Refactor: Reduce number of SQL queries when serializing List[Document] (#7505)
authorYichi Yang <helium6072@gmail.com>
Sun, 25 Aug 2024 04:20:24 +0000 (12:20 +0800)
committerGitHub <noreply@github.com>
Sun, 25 Aug 2024 04:20:24 +0000 (21:20 -0700)
src/documents/serialisers.py

index 0c0813aa4fae9dfe8f52ac72155dc83e25ef1468..747d744b65cacfb2068560793634f9022f6b0968 100644 (file)
@@ -2,6 +2,7 @@ import datetime
 import math
 import re
 import zoneinfo
+from collections.abc import Iterable
 from decimal import Decimal
 
 import magic
@@ -232,22 +233,45 @@ class OwnedObjectSerializer(
             )
         )
 
-    def get_is_shared_by_requester(self, obj: Document):
-        ctype = ContentType.objects.get_for_model(obj)
+    @staticmethod
+    def get_shared_object_pks(objects: Iterable):
+        """
+        Return the primary keys of the subset of objects that are shared.
+        """
+        try:
+            first_obj = next(iter(objects))
+        except StopIteration:
+            return set()
+
+        ctype = ContentType.objects.get_for_model(first_obj)
+        object_pks = list(obj.pk for obj in objects)
+        pk_type = type(first_obj.pk)
+
+        def get_pks_for_permission_type(model):
+            return map(
+                pk_type,  # coerce the pk to be the same type of the provided objects
+                model.objects.filter(
+                    content_type=ctype,
+                    object_pk__in=object_pks,
+                )
+                .values_list("object_pk", flat=True)
+                .distinct(),
+            )
+
         UserObjectPermission = get_user_obj_perms_model()
         GroupObjectPermission = get_group_obj_perms_model()
-        return obj.owner == self.user and (
-            UserObjectPermission.objects.filter(
-                content_type=ctype,
-                object_pk=obj.pk,
-            ).count()
-            > 0
-            or GroupObjectPermission.objects.filter(
-                content_type=ctype,
-                object_pk=obj.pk,
-            ).count()
-            > 0
-        )
+        user_permission_pks = get_pks_for_permission_type(UserObjectPermission)
+        group_permission_pks = get_pks_for_permission_type(GroupObjectPermission)
+
+        return set(user_permission_pks) | set(group_permission_pks)
+
+    def get_is_shared_by_requester(self, obj: Document):
+        # First check the context to see if `shared_object_pks` is set by the parent.
+        shared_object_pks = self.context.get("shared_object_pks")
+        # If not just check if the current object is shared.
+        if shared_object_pks is None:
+            shared_object_pks = self.get_shared_object_pks([obj])
+        return obj.owner == self.user and obj.id in shared_object_pks
 
     permissions = SerializerMethodField(read_only=True)
     user_can_change = SerializerMethodField(read_only=True)
@@ -303,6 +327,14 @@ class OwnedObjectSerializer(
         return super().update(instance, validated_data)
 
 
+class OwnedObjectListSerializer(serializers.ListSerializer):
+    def to_representation(self, documents):
+        self.child.context["shared_object_pks"] = self.child.get_shared_object_pks(
+            documents,
+        )
+        return super().to_representation(documents)
+
+
 class CorrespondentSerializer(MatchingModelSerializer, OwnedObjectSerializer):
     last_correspondence = serializers.DateTimeField(read_only=True, required=False)
 
@@ -863,35 +895,69 @@ class DocumentSerializer(
             "custom_fields",
             "remove_inbox_tags",
         )
+        list_serializer_class = OwnedObjectListSerializer
+
+
+class SearchResultListSerializer(serializers.ListSerializer):
+    def to_representation(self, hits):
+        document_ids = [hit["id"] for hit in hits]
+        # Fetch all Document objects in the list in one SQL query.
+        documents = self.child.fetch_documents(document_ids)
+        self.child.context["documents"] = documents
+        # Also check if they are shared with other users / groups.
+        self.child.context["shared_object_pks"] = self.child.get_shared_object_pks(
+            documents.values(),
+        )
+
+        return super().to_representation(hits)
 
 
 class SearchResultSerializer(DocumentSerializer):
-    def to_representation(self, instance):
-        doc = (
-            Document.objects.select_related(
+    @staticmethod
+    def fetch_documents(ids):
+        """
+        Return a dict that maps given document IDs to Document objects.
+        """
+        return {
+            document.id: document
+            for document in Document.objects.select_related(
                 "correspondent",
                 "storage_path",
                 "document_type",
                 "owner",
             )
             .prefetch_related("tags", "custom_fields", "notes")
-            .get(id=instance["id"])
-        )
+            .filter(id__in=ids)
+        }
+
+    def to_representation(self, hit):
+        # Again we first check if the parent has already fetched the documents.
+        documents = self.context.get("documents")
+        # Otherwise we fetch this document.
+        if documents is None:  # pragma: no cover
+            # In practice we only serialize **lists** of whoosh.searching.Hit.
+            # I'm keeping this check for completeness but marking it no cover for now.
+            documents = self.fetch_documents([hit["id"]])
+        document = documents[hit["id"]]
+
         notes = ",".join(
-            [str(c.note) for c in doc.notes.all()],
+            [str(c.note) for c in document.notes.all()],
         )
-        r = super().to_representation(doc)
+        r = super().to_representation(document)
         r["__search_hit__"] = {
-            "score": instance.score,
-            "highlights": instance.highlights("content", text=doc.content),
+            "score": hit.score,
+            "highlights": hit.highlights("content", text=document.content),
             "note_highlights": (
-                instance.highlights("notes", text=notes) if doc else None
+                hit.highlights("notes", text=notes) if document else None
             ),
-            "rank": instance.rank,
+            "rank": hit.rank,
         }
 
         return r
 
+    class Meta(DocumentSerializer.Meta):
+        list_serializer_class = SearchResultListSerializer
+
 
 class SavedViewFilterRuleSerializer(serializers.ModelSerializer):
     class Meta: