from datetime import datetime
from decimal import Decimal
from typing import TYPE_CHECKING
+from typing import Literal
import magic
from celery import states
except KeyError:
pass
+ def _get_perms(self, obj, codename: str, target: Literal["users", "groups"]):
+ """
+ Get the given permissions from context or from django-guardian.
+
+ :param codename: The permission codename, e.g. 'view' or 'change'
+ :param target: 'users' or 'groups'
+ """
+ key = f"{target}_{codename}_perms"
+ cached = self.context.get(key, {}).get(obj.pk)
+ if cached is not None:
+ return list(cached)
+
+ # Permission not found in the context, get it from guardian
+ if target == "users":
+ return list(
+ get_users_with_perms(
+ obj,
+ only_with_perms_in=[f"{codename}_{obj.__class__.__name__.lower()}"],
+ with_group_users=False,
+ ).values_list("id", flat=True),
+ )
+ else: # groups
+ return list(
+ get_groups_with_only_permission(
+ obj,
+ codename=f"{codename}_{obj.__class__.__name__.lower()}",
+ ).values_list("id", flat=True),
+ )
+
@extend_schema_field(
field={
"type": "object",
},
)
def get_permissions(self, obj) -> dict:
- view_codename = f"view_{obj.__class__.__name__.lower()}"
- change_codename = f"change_{obj.__class__.__name__.lower()}"
-
return {
"view": {
- "users": get_users_with_perms(
- obj,
- only_with_perms_in=[view_codename],
- with_group_users=False,
- ).values_list("id", flat=True),
- "groups": get_groups_with_only_permission(
- obj,
- codename=view_codename,
- ).values_list("id", flat=True),
+ "users": self._get_perms(obj, "view", "users"),
+ "groups": self._get_perms(obj, "view", "groups"),
},
"change": {
- "users": get_users_with_perms(
- obj,
- only_with_perms_in=[change_codename],
- with_group_users=False,
- ).values_list("id", flat=True),
- "groups": get_groups_with_only_permission(
- obj,
- codename=change_codename,
- ).values_list("id", flat=True),
+ "users": self._get_perms(obj, "change", "users"),
+ "groups": self._get_perms(obj, "change", "groups"),
},
}
+import json
import tempfile
from datetime import timedelta
from pathlib import Path
from django.conf import settings
+from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
+from django.db import connection
from django.test import TestCase
from django.test import override_settings
+from django.test.utils import CaptureQueriesContext
from django.utils import timezone
+from guardian.shortcuts import assign_perm
from rest_framework import status
from documents.models import Document
from documents.models import ShareLink
+from documents.models import Tag
from documents.tests.utils import DirectoriesMixin
from paperless.models import ApplicationConfiguration
response.render()
self.assertEqual(response.request["PATH_INFO"], "/accounts/login/")
self.assertContains(response, b"Share link has expired")
+
+ def test_list_with_full_permissions(self):
+ """
+ GIVEN:
+ - Tags with different permissions
+ WHEN:
+ - Request to get tag list with full permissions is made
+ THEN:
+ - Tag list is returned with the right permission information
+ """
+ user2 = User.objects.create(username="user2")
+ user3 = User.objects.create(username="user3")
+ group1 = Group.objects.create(name="group1")
+ group2 = Group.objects.create(name="group2")
+ group3 = Group.objects.create(name="group3")
+ t1 = Tag.objects.create(name="invoice", pk=1)
+ assign_perm("view_tag", self.user, t1)
+ assign_perm("view_tag", user2, t1)
+ assign_perm("view_tag", user3, t1)
+ assign_perm("view_tag", group1, t1)
+ assign_perm("view_tag", group2, t1)
+ assign_perm("view_tag", group3, t1)
+ assign_perm("change_tag", self.user, t1)
+ assign_perm("change_tag", user2, t1)
+ assign_perm("change_tag", group1, t1)
+ assign_perm("change_tag", group2, t1)
+
+ Tag.objects.create(name="bank statement", pk=2)
+ d1 = Document.objects.create(
+ title="Invoice 1",
+ content="This is the invoice of a very expensive item",
+ checksum="A",
+ )
+ d1.tags.add(t1)
+ d2 = Document.objects.create(
+ title="Invoice 2",
+ content="Internet invoice, I should pay it to continue contributing",
+ checksum="B",
+ )
+ d2.tags.add(t1)
+
+ view_permissions = Permission.objects.filter(
+ codename__contains="view_tag",
+ )
+ self.user.user_permissions.add(*view_permissions)
+ self.user.save()
+
+ self.client.force_login(self.user)
+ response = self.client.get("/api/tags/?page=1&full_perms=true")
+ results = json.loads(response.content)["results"]
+ for tag in results:
+ if tag["name"] == "invoice":
+ assert tag["permissions"] == {
+ "view": {
+ "users": [self.user.pk, user2.pk, user3.pk],
+ "groups": [group1.pk, group2.pk, group3.pk],
+ },
+ "change": {
+ "users": [self.user.pk, user2.pk],
+ "groups": [group1.pk, group2.pk],
+ },
+ }
+ elif tag["name"] == "bank statement":
+ assert tag["permissions"] == {
+ "view": {"users": [], "groups": []},
+ "change": {"users": [], "groups": []},
+ }
+ else:
+ assert False, f"Unexpected tag found: {tag['name']}"
+
+ def test_list_no_n_plus_1_queries(self):
+ """
+ GIVEN:
+ - Tags with different permissions
+ WHEN:
+ - Request to get tag list with full permissions is made
+ THEN:
+ - Permissions are not queries in database tag by tag,
+ i.e. there are no N+1 queries
+ """
+ view_permissions = Permission.objects.filter(
+ codename__contains="view_tag",
+ )
+ self.user.user_permissions.add(*view_permissions)
+ self.user.save()
+ self.client.force_login(self.user)
+
+ # Start by a small list, and count the number of SQL queries
+ for i in range(2):
+ Tag.objects.create(name=f"tag_{i}")
+
+ with CaptureQueriesContext(connection) as ctx_small:
+ response_small = self.client.get("/api/tags/?full_perms=true")
+ assert response_small.status_code == 200
+ num_queries_small = len(ctx_small.captured_queries)
+
+ # Complete the list, and count the number of SQL queries again
+ for i in range(2, 50):
+ Tag.objects.create(name=f"tag_{i}")
+
+ with CaptureQueriesContext(connection) as ctx_large:
+ response_large = self.client.get("/api/tags/?full_perms=true")
+ assert response_large.status_code == 200
+ num_queries_large = len(ctx_large.captured_queries)
+
+ # A few additional queries are allowed, but not a linear explosion
+ assert num_queries_large <= num_queries_small + 5, (
+ f"Possible N+1 queries detected: {num_queries_small} queries for 2 tags, "
+ f"but {num_queries_large} queries for 50 tags"
+ )
import re
import tempfile
import zipfile
+from collections import defaultdict
from datetime import datetime
from pathlib import Path
from time import mktime
+from typing import Literal
from unicodedata import normalize
from urllib.parse import quote
from urllib.parse import urlparse
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
+from django.contrib.contenttypes.models import ContentType
from django.db import connections
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.recorder import MigrationRecorder
from drf_spectacular.utils import extend_schema
from drf_spectacular.utils import extend_schema_view
from drf_spectacular.utils import inline_serializer
+from guardian.utils import get_group_obj_perms_model
+from guardian.utils import get_user_obj_perms_model
from langdetect import detect
from packaging import version as packaging_version
from redis import Redis
return super().get_serializer(*args, **kwargs)
-class PermissionsAwareDocumentCountMixin(PassUserMixin):
+class BulkPermissionMixin:
+ """
+ Prefetch Django-Guardian permissions for a list before serialization, to avoid N+1 queries.
+ """
+
+ def get_permission_codenames(self):
+ model_name = self.queryset.model.__name__.lower()
+ return {
+ "view": f"view_{model_name}",
+ "change": f"change_{model_name}",
+ }
+
+ def _get_object_perms(
+ self,
+ objects: list,
+ perm_codenames: list[str],
+ actor: Literal["users", "groups"],
+ ) -> dict[int, dict[str, list[int]]]:
+ """
+ Collect object-level permissions for either users or groups.
+ """
+ model = self.queryset.model
+ obj_perm_model = (
+ get_user_obj_perms_model(model)
+ if actor == "users"
+ else get_group_obj_perms_model(model)
+ )
+ id_field = "user_id" if actor == "users" else "group_id"
+ ctype = ContentType.objects.get_for_model(model)
+ object_pks = [obj.pk for obj in objects]
+
+ perms_qs = obj_perm_model.objects.filter(
+ content_type=ctype,
+ object_pk__in=object_pks,
+ permission__codename__in=perm_codenames,
+ ).values_list("object_pk", id_field, "permission__codename")
+
+ perms: dict[int, dict[str, list[int]]] = defaultdict(lambda: defaultdict(list))
+ for object_pk, actor_id, codename in perms_qs:
+ perms[int(object_pk)][codename].append(actor_id)
+
+ # Ensure that all objects have all codenames, even if empty
+ for pk in object_pks:
+ for codename in perm_codenames:
+ perms[pk][codename]
+
+ return perms
+
+ def get_serializer_context(self):
+ """
+ Get all permissions of the current list of objects at once and pass them to the serializer.
+ This avoid fetching permissions object by object in database.
+ """
+ context = super().get_serializer_context()
+ try:
+ full_perms = get_boolean(
+ str(self.request.query_params.get("full_perms", "false")),
+ )
+ except ValueError:
+ full_perms = False
+
+ if not full_perms:
+ return context
+
+ # Check which objects are being paginated
+ page = getattr(self, "paginator", None)
+ if page and hasattr(page, "page"):
+ queryset = page.page.object_list
+ elif hasattr(self, "page"):
+ queryset = self.page
+ else:
+ queryset = self.filter_queryset(self.get_queryset())
+
+ codenames = self.get_permission_codenames()
+ perm_names = [codenames["view"], codenames["change"]]
+ user_perms = self._get_object_perms(queryset, perm_names, actor="users")
+ group_perms = self._get_object_perms(queryset, perm_names, actor="groups")
+
+ context["users_view_perms"] = {
+ pk: user_perms[pk][codenames["view"]] for pk in user_perms
+ }
+ context["users_change_perms"] = {
+ pk: user_perms[pk][codenames["change"]] for pk in user_perms
+ }
+ context["groups_view_perms"] = {
+ pk: group_perms[pk][codenames["view"]] for pk in group_perms
+ }
+ context["groups_change_perms"] = {
+ pk: group_perms[pk][codenames["change"]] for pk in group_perms
+ }
+
+ return context
+
+
+class PermissionsAwareDocumentCountMixin(BulkPermissionMixin, PassUserMixin):
"""
Mixin to add document count to queryset, permissions-aware if needed
"""