git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Performance: Cache django-guardian permissions when counting documents (#10657)
author Antoine Mérino <antoine.merino.dev@gmail.com>
Tue, 30 Sep 2025 16:48:44 +0000 (18:48 +0200)
committer GitHub <noreply@github.com>
Tue, 30 Sep 2025 16:48:44 +0000 (09:48 -0700)
Fixes N+1 queries in tag, correspondent, storage path, custom field,
and document type list views.
Reduces SQL queries from 160 to 9.

src/documents/serialisers.py
src/documents/tests/test_views.py
src/documents/views.py

index 1608a0e4e12383fe07b4f06541b529c12d9cdb3f..ce01920744e2c491facbab27c8b0f1d9d5d427f3 100644 (file)
@@ -6,6 +6,7 @@ import re
 from datetime import datetime
 from decimal import Decimal
 from typing import TYPE_CHECKING
+from typing import Literal
 
 import magic
 from celery import states
@@ -252,6 +253,35 @@ class OwnedObjectSerializer(
             except KeyError:
                 pass
 
+    def _get_perms(self, obj, codename: str, target: Literal["users", "groups"]):
+        """
+        Get the given permissions from context or from django-guardian.
+
+        :param codename: The permission codename, e.g. 'view' or 'change'
+        :param target: 'users' or 'groups'
+        """
+        key = f"{target}_{codename}_perms"
+        cached = self.context.get(key, {}).get(obj.pk)
+        if cached is not None:
+            return list(cached)
+
+        # Permission not found in the context, get it from guardian
+        if target == "users":
+            return list(
+                get_users_with_perms(
+                    obj,
+                    only_with_perms_in=[f"{codename}_{obj.__class__.__name__.lower()}"],
+                    with_group_users=False,
+                ).values_list("id", flat=True),
+            )
+        else:  # groups
+            return list(
+                get_groups_with_only_permission(
+                    obj,
+                    codename=f"{codename}_{obj.__class__.__name__.lower()}",
+                ).values_list("id", flat=True),
+            )
+
     @extend_schema_field(
         field={
             "type": "object",
@@ -286,31 +316,14 @@ class OwnedObjectSerializer(
         },
     )
     def get_permissions(self, obj) -> dict:
-        view_codename = f"view_{obj.__class__.__name__.lower()}"
-        change_codename = f"change_{obj.__class__.__name__.lower()}"
-
         return {
             "view": {
-                "users": get_users_with_perms(
-                    obj,
-                    only_with_perms_in=[view_codename],
-                    with_group_users=False,
-                ).values_list("id", flat=True),
-                "groups": get_groups_with_only_permission(
-                    obj,
-                    codename=view_codename,
-                ).values_list("id", flat=True),
+                "users": self._get_perms(obj, "view", "users"),
+                "groups": self._get_perms(obj, "view", "groups"),
             },
             "change": {
-                "users": get_users_with_perms(
-                    obj,
-                    only_with_perms_in=[change_codename],
-                    with_group_users=False,
-                ).values_list("id", flat=True),
-                "groups": get_groups_with_only_permission(
-                    obj,
-                    codename=change_codename,
-                ).values_list("id", flat=True),
+                "users": self._get_perms(obj, "change", "users"),
+                "groups": self._get_perms(obj, "change", "groups"),
             },
         }
 
index 4c987e3af361dbdf2c4e7a2cc2ea6aabb544f558..57562c02cf4f5f22b7d4c3d5937049243078cf17 100644 (file)
@@ -1,17 +1,23 @@
+import json
 import tempfile
 from datetime import timedelta
 from pathlib import Path
 
 from django.conf import settings
+from django.contrib.auth.models import Group
 from django.contrib.auth.models import Permission
 from django.contrib.auth.models import User
+from django.db import connection
 from django.test import TestCase
 from django.test import override_settings
+from django.test.utils import CaptureQueriesContext
 from django.utils import timezone
+from guardian.shortcuts import assign_perm
 from rest_framework import status
 
 from documents.models import Document
 from documents.models import ShareLink
+from documents.models import Tag
 from documents.tests.utils import DirectoriesMixin
 from paperless.models import ApplicationConfiguration
 
@@ -154,3 +160,113 @@ class TestViews(DirectoriesMixin, TestCase):
         response.render()
         self.assertEqual(response.request["PATH_INFO"], "/accounts/login/")
         self.assertContains(response, b"Share link has expired")
+
+    def test_list_with_full_permissions(self):
+        """
+        GIVEN:
+            - Tags with different permissions
+        WHEN:
+            - Request to get tag list with full permissions is made
+        THEN:
+            - Tag list is returned with the right permission information
+        """
+        user2 = User.objects.create(username="user2")
+        user3 = User.objects.create(username="user3")
+        group1 = Group.objects.create(name="group1")
+        group2 = Group.objects.create(name="group2")
+        group3 = Group.objects.create(name="group3")
+        t1 = Tag.objects.create(name="invoice", pk=1)
+        assign_perm("view_tag", self.user, t1)
+        assign_perm("view_tag", user2, t1)
+        assign_perm("view_tag", user3, t1)
+        assign_perm("view_tag", group1, t1)
+        assign_perm("view_tag", group2, t1)
+        assign_perm("view_tag", group3, t1)
+        assign_perm("change_tag", self.user, t1)
+        assign_perm("change_tag", user2, t1)
+        assign_perm("change_tag", group1, t1)
+        assign_perm("change_tag", group2, t1)
+
+        Tag.objects.create(name="bank statement", pk=2)
+        d1 = Document.objects.create(
+            title="Invoice 1",
+            content="This is the invoice of a very expensive item",
+            checksum="A",
+        )
+        d1.tags.add(t1)
+        d2 = Document.objects.create(
+            title="Invoice 2",
+            content="Internet invoice, I should pay it to continue contributing",
+            checksum="B",
+        )
+        d2.tags.add(t1)
+
+        view_permissions = Permission.objects.filter(
+            codename__contains="view_tag",
+        )
+        self.user.user_permissions.add(*view_permissions)
+        self.user.save()
+
+        self.client.force_login(self.user)
+        response = self.client.get("/api/tags/?page=1&full_perms=true")
+        results = json.loads(response.content)["results"]
+        for tag in results:
+            if tag["name"] == "invoice":
+                assert tag["permissions"] == {
+                    "view": {
+                        "users": [self.user.pk, user2.pk, user3.pk],
+                        "groups": [group1.pk, group2.pk, group3.pk],
+                    },
+                    "change": {
+                        "users": [self.user.pk, user2.pk],
+                        "groups": [group1.pk, group2.pk],
+                    },
+                }
+            elif tag["name"] == "bank statement":
+                assert tag["permissions"] == {
+                    "view": {"users": [], "groups": []},
+                    "change": {"users": [], "groups": []},
+                }
+            else:
+                assert False, f"Unexpected tag found: {tag['name']}"
+
+    def test_list_no_n_plus_1_queries(self):
+        """
+        GIVEN:
+            - Tags with different permissions
+        WHEN:
+            - Request to get tag list with full permissions is made
+        THEN:
+            - Permissions are not queried from the database tag by tag,
+             i.e. there are no N+1 queries
+        """
+        view_permissions = Permission.objects.filter(
+            codename__contains="view_tag",
+        )
+        self.user.user_permissions.add(*view_permissions)
+        self.user.save()
+        self.client.force_login(self.user)
+
+        # Start with a small list, and count the number of SQL queries
+        for i in range(2):
+            Tag.objects.create(name=f"tag_{i}")
+
+        with CaptureQueriesContext(connection) as ctx_small:
+            response_small = self.client.get("/api/tags/?full_perms=true")
+            assert response_small.status_code == 200
+        num_queries_small = len(ctx_small.captured_queries)
+
+        # Complete the list, and count the number of SQL queries again
+        for i in range(2, 50):
+            Tag.objects.create(name=f"tag_{i}")
+
+        with CaptureQueriesContext(connection) as ctx_large:
+            response_large = self.client.get("/api/tags/?full_perms=true")
+            assert response_large.status_code == 200
+        num_queries_large = len(ctx_large.captured_queries)
+
+        # A few additional queries are allowed, but not a linear explosion
+        assert num_queries_large <= num_queries_small + 5, (
+            f"Possible N+1 queries detected: {num_queries_small} queries for 2 tags, "
+            f"but {num_queries_large} queries for 50 tags"
+        )
index 86eab92e384c205b8cf428a74f35be1c2307f054..bce7428cd04ed620253962426c4aa9fcbaf9efd6 100644 (file)
@@ -5,9 +5,11 @@ import platform
 import re
 import tempfile
 import zipfile
+from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
 from time import mktime
+from typing import Literal
 from unicodedata import normalize
 from urllib.parse import quote
 from urllib.parse import urlparse
@@ -19,6 +21,7 @@ from celery import states
 from django.conf import settings
 from django.contrib.auth.models import Group
 from django.contrib.auth.models import User
+from django.contrib.contenttypes.models import ContentType
 from django.db import connections
 from django.db.migrations.loader import MigrationLoader
 from django.db.migrations.recorder import MigrationRecorder
@@ -56,6 +59,8 @@ from drf_spectacular.utils import OpenApiParameter
 from drf_spectacular.utils import extend_schema
 from drf_spectacular.utils import extend_schema_view
 from drf_spectacular.utils import inline_serializer
+from guardian.utils import get_group_obj_perms_model
+from guardian.utils import get_user_obj_perms_model
 from langdetect import detect
 from packaging import version as packaging_version
 from redis import Redis
@@ -254,7 +259,101 @@ class PassUserMixin(GenericAPIView):
         return super().get_serializer(*args, **kwargs)
 
 
-class PermissionsAwareDocumentCountMixin(PassUserMixin):
+class BulkPermissionMixin:
+    """
+    Prefetch Django-Guardian permissions for a list before serialization, to avoid N+1 queries.
+    """
+
+    def get_permission_codenames(self):
+        model_name = self.queryset.model.__name__.lower()
+        return {
+            "view": f"view_{model_name}",
+            "change": f"change_{model_name}",
+        }
+
+    def _get_object_perms(
+        self,
+        objects: list,
+        perm_codenames: list[str],
+        actor: Literal["users", "groups"],
+    ) -> dict[int, dict[str, list[int]]]:
+        """
+        Collect object-level permissions for either users or groups.
+        """
+        model = self.queryset.model
+        obj_perm_model = (
+            get_user_obj_perms_model(model)
+            if actor == "users"
+            else get_group_obj_perms_model(model)
+        )
+        id_field = "user_id" if actor == "users" else "group_id"
+        ctype = ContentType.objects.get_for_model(model)
+        object_pks = [obj.pk for obj in objects]
+
+        perms_qs = obj_perm_model.objects.filter(
+            content_type=ctype,
+            object_pk__in=object_pks,
+            permission__codename__in=perm_codenames,
+        ).values_list("object_pk", id_field, "permission__codename")
+
+        perms: dict[int, dict[str, list[int]]] = defaultdict(lambda: defaultdict(list))
+        for object_pk, actor_id, codename in perms_qs:
+            perms[int(object_pk)][codename].append(actor_id)
+
+        # Ensure that all objects have all codenames, even if empty
+        for pk in object_pks:
+            for codename in perm_codenames:
+                perms[pk][codename]
+
+        return perms
+
+    def get_serializer_context(self):
+        """
+        Get all permissions of the current list of objects at once and pass them to the serializer.
+        This avoids fetching permissions object by object from the database.
+        """
+        context = super().get_serializer_context()
+        try:
+            full_perms = get_boolean(
+                str(self.request.query_params.get("full_perms", "false")),
+            )
+        except ValueError:
+            full_perms = False
+
+        if not full_perms:
+            return context
+
+        # Check which objects are being paginated
+        page = getattr(self, "paginator", None)
+        if page and hasattr(page, "page"):
+            queryset = page.page.object_list
+        elif hasattr(self, "page"):
+            queryset = self.page
+        else:
+            queryset = self.filter_queryset(self.get_queryset())
+
+        codenames = self.get_permission_codenames()
+        perm_names = [codenames["view"], codenames["change"]]
+        user_perms = self._get_object_perms(queryset, perm_names, actor="users")
+        group_perms = self._get_object_perms(queryset, perm_names, actor="groups")
+
+        context["users_view_perms"] = {
+            pk: user_perms[pk][codenames["view"]] for pk in user_perms
+        }
+        context["users_change_perms"] = {
+            pk: user_perms[pk][codenames["change"]] for pk in user_perms
+        }
+        context["groups_view_perms"] = {
+            pk: group_perms[pk][codenames["view"]] for pk in group_perms
+        }
+        context["groups_change_perms"] = {
+            pk: group_perms[pk][codenames["change"]] for pk in group_perms
+        }
+
+        return context
+
+
+class PermissionsAwareDocumentCountMixin(BulkPermissionMixin, PassUserMixin):
     """
     Mixin to add document count to queryset, permissions-aware if needed
     """