+++ /dev/null
-from __future__ import annotations
-
-import functools
-import inspect
-import json
-import operator
-from contextlib import contextmanager
-from typing import TYPE_CHECKING
-
-from django.contrib.contenttypes.models import ContentType
-from django.db.models import Case
-from django.db.models import CharField
-from django.db.models import Count
-from django.db.models import Exists
-from django.db.models import IntegerField
-from django.db.models import OuterRef
-from django.db.models import Q
-from django.db.models import Subquery
-from django.db.models import Sum
-from django.db.models import Value
-from django.db.models import When
-from django.db.models.functions import Cast
-from django.utils.translation import gettext_lazy as _
-from django_filters.rest_framework import BooleanFilter
-from django_filters.rest_framework import Filter
-from django_filters.rest_framework import FilterSet
-from drf_spectacular.utils import extend_schema_field
-from guardian.utils import get_group_obj_perms_model
-from guardian.utils import get_user_obj_perms_model
-from rest_framework import serializers
-from rest_framework.filters import OrderingFilter
-from rest_framework_guardian.filters import ObjectPermissionsFilter
-
-from paperless.models import Correspondent
-from paperless.models import CustomField
-from paperless.models import CustomFieldInstance
-from paperless.models import Document
-from paperless.models import DocumentType
-from paperless.models import PaperlessTask
-from paperless.models import ShareLink
-from paperless.models import StoragePath
-from paperless.models import Tag
-
-if TYPE_CHECKING:
- from collections.abc import Callable
-
-CHAR_KWARGS = ["istartswith", "iendswith", "icontains", "iexact"]
-ID_KWARGS = ["in", "exact"]
-INT_KWARGS = ["exact", "gt", "gte", "lt", "lte", "isnull"]
-DATE_KWARGS = [
- "year",
- "month",
- "day",
- "date__gt",
- "date__gte",
- "gt",
- "gte",
- "date__lt",
- "date__lte",
- "lt",
- "lte",
-]
-
-CUSTOM_FIELD_QUERY_MAX_DEPTH = 10
-CUSTOM_FIELD_QUERY_MAX_ATOMS = 20
-
-
-class CorrespondentFilterSet(FilterSet):
- class Meta:
- model = Correspondent
- fields = {
- "id": ID_KWARGS,
- "name": CHAR_KWARGS,
- }
-
-
-class TagFilterSet(FilterSet):
- class Meta:
- model = Tag
- fields = {
- "id": ID_KWARGS,
- "name": CHAR_KWARGS,
- }
-
-
-class DocumentTypeFilterSet(FilterSet):
- class Meta:
- model = DocumentType
- fields = {
- "id": ID_KWARGS,
- "name": CHAR_KWARGS,
- }
-
-
-class StoragePathFilterSet(FilterSet):
- class Meta:
- model = StoragePath
- fields = {
- "id": ID_KWARGS,
- "name": CHAR_KWARGS,
- "path": CHAR_KWARGS,
- }
-
-
-class ObjectFilter(Filter):
- def __init__(self, *, exclude=False, in_list=False, field_name=""):
- super().__init__()
- self.exclude = exclude
- self.in_list = in_list
- self.field_name = field_name
-
- def filter(self, qs, value):
- if not value:
- return qs
-
- try:
- object_ids = [int(x) for x in value.split(",")]
- except ValueError:
- return qs
-
- if self.in_list:
- qs = qs.filter(**{f"{self.field_name}__id__in": object_ids}).distinct()
- else:
- for obj_id in object_ids:
- if self.exclude:
- qs = qs.exclude(**{f"{self.field_name}__id": obj_id})
- else:
- qs = qs.filter(**{f"{self.field_name}__id": obj_id})
-
- return qs
-
-
-@extend_schema_field(serializers.BooleanField)
-class InboxFilter(Filter):
- def filter(self, qs, value):
- if value == "true":
- return qs.filter(tags__is_inbox_tag=True)
- elif value == "false":
- return qs.exclude(tags__is_inbox_tag=True)
- else:
- return qs
-
-
-@extend_schema_field(serializers.CharField)
-class TitleContentFilter(Filter):
- def filter(self, qs, value):
- if value:
- return qs.filter(Q(title__icontains=value) | Q(content__icontains=value))
- else:
- return qs
-
-
-@extend_schema_field(serializers.BooleanField)
-class SharedByUser(Filter):
- def filter(self, qs, value):
- ctype = ContentType.objects.get_for_model(self.model)
- UserObjectPermission = get_user_obj_perms_model()
- GroupObjectPermission = get_group_obj_perms_model()
- # see https://github.com/paperless-ngx/paperless-ngx/issues/5392, we limit subqueries
- # to 1 because Postgres doesn't like returning > 1 row, but all we care about is > 0
- return (
- qs.filter(
- owner_id=value,
- )
- .annotate(
- num_shared_users=Count(
- UserObjectPermission.objects.filter(
- content_type=ctype,
- object_pk=Cast(OuterRef("pk"), CharField()),
- ).values("user_id")[:1],
- ),
- )
- .annotate(
- num_shared_groups=Count(
- GroupObjectPermission.objects.filter(
- content_type=ctype,
- object_pk=Cast(OuterRef("pk"), CharField()),
- ).values("group_id")[:1],
- ),
- )
- .filter(
- Q(num_shared_users__gt=0) | Q(num_shared_groups__gt=0),
- )
- if value is not None
- else qs
- )
-
-
-class CustomFieldFilterSet(FilterSet):
- class Meta:
- model = CustomField
- fields = {
- "id": ID_KWARGS,
- "name": CHAR_KWARGS,
- }
-
-
-@extend_schema_field(serializers.CharField)
-class CustomFieldsFilter(Filter):
- def filter(self, qs, value):
- if value:
- fields_with_matching_selects = CustomField.objects.filter(
- extra_data__icontains=value,
- )
- option_ids = []
- if fields_with_matching_selects.count() > 0:
- for field in fields_with_matching_selects:
- options = field.extra_data.get("select_options", [])
- for _, option in enumerate(options):
- if option.get("label").lower().find(value.lower()) != -1:
- option_ids.extend([option.get("id")])
- return (
- qs.filter(custom_fields__field__name__icontains=value)
- | qs.filter(custom_fields__value_text__icontains=value)
- | qs.filter(custom_fields__value_bool__icontains=value)
- | qs.filter(custom_fields__value_int__icontains=value)
- | qs.filter(custom_fields__value_float__icontains=value)
- | qs.filter(custom_fields__value_date__icontains=value)
- | qs.filter(custom_fields__value_url__icontains=value)
- | qs.filter(custom_fields__value_monetary__icontains=value)
- | qs.filter(custom_fields__value_document_ids__icontains=value)
- | qs.filter(custom_fields__value_select__in=option_ids)
- )
- else:
- return qs
-
-
-class MimeTypeFilter(Filter):
- def filter(self, qs, value):
- if value:
- return qs.filter(mime_type__icontains=value)
- else:
- return qs
-
-
-class SelectField(serializers.CharField):
- def __init__(self, custom_field: CustomField):
- self._options = custom_field.extra_data["select_options"]
- super().__init__(max_length=16)
-
- def to_internal_value(self, data):
- # If the supplied value is the option label instead of the ID
- try:
- data = next(
- option.get("id")
- for option in self._options
- if option.get("label") == data
- )
- except StopIteration:
- pass
- return super().to_internal_value(data)
-
-
-def handle_validation_prefix(func: Callable):
- """
- Catch ValidationErrors raised by the wrapped function
- and add a prefix to the exception detail to track what causes the exception,
- similar to nested serializers.
- """
-
- def wrapper(*args, validation_prefix=None, **kwargs):
- try:
- return func(*args, **kwargs)
- except serializers.ValidationError as e:
- raise serializers.ValidationError({validation_prefix: e.detail})
-
- # Update the signature to include the validation_prefix argument
- old_sig = inspect.signature(func)
- new_param = inspect.Parameter("validation_prefix", inspect.Parameter.KEYWORD_ONLY)
- new_sig = old_sig.replace(parameters=[*old_sig.parameters.values(), new_param])
-
- # Apply functools.wraps and manually set the new signature
- functools.update_wrapper(wrapper, func)
- wrapper.__signature__ = new_sig
-
- return wrapper
-
-
-class CustomFieldQueryParser:
- EXPR_BY_CATEGORY = {
- "basic": ["exact", "in", "isnull", "exists"],
- "string": [
- "icontains",
- "istartswith",
- "iendswith",
- ],
- "arithmetic": [
- "gt",
- "gte",
- "lt",
- "lte",
- "range",
- ],
- "containment": ["contains"],
- }
-
- SUPPORTED_EXPR_CATEGORIES = {
- CustomField.FieldDataType.STRING: ("basic", "string"),
- CustomField.FieldDataType.URL: ("basic", "string"),
- CustomField.FieldDataType.DATE: ("basic", "arithmetic"),
- CustomField.FieldDataType.BOOL: ("basic",),
- CustomField.FieldDataType.INT: ("basic", "arithmetic"),
- CustomField.FieldDataType.FLOAT: ("basic", "arithmetic"),
- CustomField.FieldDataType.MONETARY: ("basic", "string", "arithmetic"),
- CustomField.FieldDataType.DOCUMENTLINK: ("basic", "containment"),
- CustomField.FieldDataType.SELECT: ("basic",),
- }
-
- DATE_COMPONENTS = [
- "year",
- "iso_year",
- "month",
- "day",
- "week",
- "week_day",
- "iso_week_day",
- "quarter",
- ]
-
- def __init__(
- self,
- validation_prefix,
- max_query_depth=10,
- max_atom_count=20,
- ) -> None:
- """
- A helper class that parses the query string into a `django.db.models.Q` for filtering
- documents based on custom field values.
-
- The syntax of the query expression is illustrated with the below pseudo code rules:
- 1. parse([`custom_field`, "exists", true]):
- matches documents with Q(custom_fields__field=`custom_field`)
- 2. parse([`custom_field`, "exists", false]):
- matches documents with ~Q(custom_fields__field=`custom_field`)
- 3. parse([`custom_field`, `op`, `value`]):
- matches documents with
- Q(custom_fields__field=`custom_field`, custom_fields__value_`type`__`op`= `value`)
- 4. parse(["AND", [`q0`, `q1`, ..., `qn`]])
- -> parse(`q0`) & parse(`q1`) & ... & parse(`qn`)
- 5. parse(["OR", [`q0`, `q1`, ..., `qn`]])
- -> parse(`q0`) | parse(`q1`) | ... | parse(`qn`)
- 6. parse(["NOT", `q`])
- -> ~parse(`q`)
-
- Args:
- validation_prefix: Used to generate the ValidationError message.
- max_query_depth: Limits the maximum nesting depth of queries.
- max_atom_count: Limits the maximum number of atoms (i.e., rule 1, 2, 3) in the query.
-
- `max_query_depth` and `max_atom_count` can be set to guard against generating arbitrarily
- complex SQL queries.
- """
- self._custom_fields: dict[int | str, CustomField] = {}
- self._validation_prefix = validation_prefix
- # Dummy ModelSerializer used to convert a Django models.Field to serializers.Field.
- self._model_serializer = serializers.ModelSerializer()
- # Used for sanity check
- self._max_query_depth = max_query_depth
- self._max_atom_count = max_atom_count
- self._current_depth = 0
- self._atom_count = 0
- # The set of annotations that we need to apply to the queryset
- self._annotations = {}
-
- def parse(self, query: str) -> tuple[Q, dict[str, Count]]:
- """
- Parses the query string into a `django.db.models.Q`
- and a set of annotations to be applied to the queryset.
- """
- try:
- expr = json.loads(query)
- except json.JSONDecodeError:
- raise serializers.ValidationError(
- {self._validation_prefix: [_("Value must be valid JSON.")]},
- )
- return (
- self._parse_expr(expr, validation_prefix=self._validation_prefix),
- self._annotations,
- )
-
- @handle_validation_prefix
- def _parse_expr(self, expr) -> Q:
- """
- Applies rule (1, 2, 3) or (4, 5, 6) based on the length of the expr.
- """
- with self._track_query_depth():
- if isinstance(expr, list | tuple):
- if len(expr) == 2:
- return self._parse_logical_expr(*expr)
- elif len(expr) == 3:
- return self._parse_atom(*expr)
- raise serializers.ValidationError(
- [_("Invalid custom field query expression")],
- )
-
- @handle_validation_prefix
- def _parse_expr_list(self, exprs) -> list[Q]:
- """
- Handles [`q0`, `q1`, ..., `qn`] in rule 4 & 5.
- """
- if not isinstance(exprs, list | tuple) or not exprs:
- raise serializers.ValidationError(
- [_("Invalid expression list. Must be nonempty.")],
- )
- return [
- self._parse_expr(expr, validation_prefix=i) for i, expr in enumerate(exprs)
- ]
-
- def _parse_logical_expr(self, op, args) -> Q:
- """
- Handles rule 4, 5, 6.
- """
- op_lower = op.lower()
-
- if op_lower == "not":
- return ~self._parse_expr(args, validation_prefix=1)
-
- if op_lower == "and":
- op_func = operator.and_
- elif op_lower == "or":
- op_func = operator.or_
- else:
- raise serializers.ValidationError(
- {"0": [_("Invalid logical operator {op!r}").format(op=op)]},
- )
-
- qs = self._parse_expr_list(args, validation_prefix="1")
- return functools.reduce(op_func, qs)
-
- def _parse_atom(self, id_or_name, op, value) -> Q:
- """
- Handles rule 1, 2, 3.
- """
- # Guard against queries with too many conditions.
- self._atom_count += 1
- if self._atom_count > self._max_atom_count:
- raise serializers.ValidationError(
- [_("Maximum number of query conditions exceeded.")],
- )
-
- custom_field = self._get_custom_field(id_or_name, validation_prefix="0")
- op = self._validate_atom_op(custom_field, op, validation_prefix="1")
- value = self._validate_atom_value(
- custom_field,
- op,
- value,
- validation_prefix="2",
- )
-
- # Needed because not all DB backends support Array __contains
- if (
- custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK
- and op == "contains"
- ):
- return self._parse_atom_doc_link_contains(custom_field, value)
-
- value_field_name = CustomFieldInstance.get_value_field_name(
- custom_field.data_type,
- )
- if (
- custom_field.data_type == CustomField.FieldDataType.MONETARY
- and op in self.EXPR_BY_CATEGORY["arithmetic"]
- ):
- value_field_name = "value_monetary_amount"
- has_field = Q(custom_fields__field=custom_field)
-
- # We need to use an annotation here because different atoms
- # might be referring to different instances of custom fields.
- annotation_name = f"_custom_field_filter_{len(self._annotations)}"
-
- # Our special exists operator.
- if op == "exists":
- annotation = Count("custom_fields", filter=has_field)
- # A Document should have > 0 match if it has this field, or 0 if doesn't.
- query_op = "gt" if value else "exact"
- query = Q(**{f"{annotation_name}__{query_op}": 0})
- else:
- # Check if 1) custom field name matches, and 2) value satisfies condition
- field_filter = has_field & Q(
- **{f"custom_fields__{value_field_name}__{op}": value},
- )
- # Annotate how many matching custom fields each document has
- annotation = Count("custom_fields", filter=field_filter)
- # Filter document by count
- query = Q(**{f"{annotation_name}__gt": 0})
-
- self._annotations[annotation_name] = annotation
- return query
-
- @handle_validation_prefix
- def _get_custom_field(self, id_or_name):
- """Get the CustomField instance by id or name."""
- if id_or_name in self._custom_fields:
- return self._custom_fields[id_or_name]
-
- kwargs = (
- {"id": id_or_name} if isinstance(id_or_name, int) else {"name": id_or_name}
- )
- try:
- custom_field = CustomField.objects.get(**kwargs)
- except CustomField.DoesNotExist:
- raise serializers.ValidationError(
- [_("{name!r} is not a valid custom field.").format(name=id_or_name)],
- )
- self._custom_fields[custom_field.id] = custom_field
- self._custom_fields[custom_field.name] = custom_field
- return custom_field
-
- @staticmethod
- def _split_op(full_op):
- *prefix, op = str(full_op).rsplit("__", maxsplit=1)
- prefix = prefix[0] if prefix else None
- return prefix, op
-
- @handle_validation_prefix
- def _validate_atom_op(self, custom_field, raw_op):
- """Check if the `op` is compatible with the type of the custom field."""
- prefix, op = self._split_op(raw_op)
-
- # Check if the operator is supported for the current data_type.
- supported = False
- for category in self.SUPPORTED_EXPR_CATEGORIES[custom_field.data_type]:
- if op in self.EXPR_BY_CATEGORY[category]:
- supported = True
- break
-
- # Check prefix
- if prefix is not None:
- if (
- prefix in self.DATE_COMPONENTS
- and custom_field.data_type == CustomField.FieldDataType.DATE
- ):
- pass # ok - e.g., "year__exact" for date field
- else:
- supported = False # anything else is invalid
-
- if not supported:
- raise serializers.ValidationError(
- [
- _("{data_type} does not support query expr {expr!r}.").format(
- data_type=custom_field.data_type,
- expr=raw_op,
- ),
- ],
- )
-
- return raw_op
-
- def _get_serializer_field(self, custom_field, full_op):
- """Return a serializers.Field for value validation."""
- prefix, op = self._split_op(full_op)
- field = None
-
- if op in ("isnull", "exists"):
- # `isnull` takes either True or False regardless of the data_type.
- field = serializers.BooleanField()
- elif (
- custom_field.data_type == CustomField.FieldDataType.DATE
- and prefix in self.DATE_COMPONENTS
- ):
- # DateField admits queries in the form of `year__exact`, etc. These take integers.
- field = serializers.IntegerField()
- elif custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
- # We can be more specific here and make sure the value is a list.
- field = serializers.ListField(child=serializers.IntegerField())
- elif custom_field.data_type == CustomField.FieldDataType.SELECT:
- # We use this custom field to permit SELECT option names.
- field = SelectField(custom_field)
- elif custom_field.data_type == CustomField.FieldDataType.URL:
- # For URL fields we don't need to be strict about validation (e.g., for istartswith).
- field = serializers.CharField()
- else:
- # The general case: inferred from the corresponding field in CustomFieldInstance.
- value_field_name = CustomFieldInstance.get_value_field_name(
- custom_field.data_type,
- )
- model_field = CustomFieldInstance._meta.get_field(value_field_name)
- field_name = model_field.deconstruct()[0]
- field_class, field_kwargs = self._model_serializer.build_standard_field(
- field_name,
- model_field,
- )
- field = field_class(**field_kwargs)
- field.allow_null = False
-
- # Need to set allow_blank manually because of the inconsistency in CustomFieldInstance validation.
- # See https://github.com/paperless-ngx/paperless-ngx/issues/7361.
- if isinstance(field, serializers.CharField):
- field.allow_blank = True
-
- if op == "in":
- # `in` takes a list of values.
- field = serializers.ListField(child=field, allow_empty=False)
- elif op == "range":
- # `range` takes a list of values, i.e., [start, end].
- field = serializers.ListField(
- child=field,
- min_length=2,
- max_length=2,
- )
-
- return field
-
- @handle_validation_prefix
- def _validate_atom_value(self, custom_field, op, value):
- """Check if `value` is valid for the custom field and `op`. Returns the validated value."""
- serializer_field = self._get_serializer_field(custom_field, op)
- return serializer_field.run_validation(value)
-
- def _parse_atom_doc_link_contains(self, custom_field, value) -> Q:
- """
- Handles document link `contains` in a way that is supported by all DB backends.
- """
-
- # If the value is an empty set,
- # this is trivially true for any document with not null document links.
- if not value:
- return Q(
- custom_fields__field=custom_field,
- custom_fields__value_document_ids__isnull=False,
- )
-
- # First we look up reverse links from the requested documents.
- links = CustomFieldInstance.objects.filter(
- document_id__in=value,
- field__data_type=CustomField.FieldDataType.DOCUMENTLINK,
- )
-
- # Check if any of the requested IDs are missing.
- missing_ids = set(value) - set(link.document_id for link in links)
- if missing_ids:
- # The result should be an empty set in this case.
- return Q(id__in=[])
-
- # Take the intersection of the reverse links - this should be what we are looking for.
- document_ids_we_want = functools.reduce(
- operator.and_,
- (set(link.value_document_ids) for link in links),
- )
-
- return Q(id__in=document_ids_we_want)
-
- @contextmanager
- def _track_query_depth(self):
- # guard against queries that are too deeply nested
- self._current_depth += 1
- if self._current_depth > self._max_query_depth:
- raise serializers.ValidationError([_("Maximum nesting depth exceeded.")])
- try:
- yield
- finally:
- self._current_depth -= 1
-
-
-@extend_schema_field(serializers.CharField)
-class CustomFieldQueryFilter(Filter):
- def __init__(self, validation_prefix):
- """
- A filter that filters documents based on custom field name and value.
-
- Args:
- validation_prefix: Used to generate the ValidationError message.
- """
- super().__init__()
- self._validation_prefix = validation_prefix
-
- def filter(self, qs, value):
- if not value:
- return qs
-
- parser = CustomFieldQueryParser(
- self._validation_prefix,
- max_query_depth=CUSTOM_FIELD_QUERY_MAX_DEPTH,
- max_atom_count=CUSTOM_FIELD_QUERY_MAX_ATOMS,
- )
- q, annotations = parser.parse(value)
-
- return qs.annotate(**annotations).filter(q)
-
-
-class DocumentFilterSet(FilterSet):
- is_tagged = BooleanFilter(
- label="Is tagged",
- field_name="tags",
- lookup_expr="isnull",
- exclude=True,
- )
-
- tags__id__all = ObjectFilter(field_name="tags")
-
- tags__id__none = ObjectFilter(field_name="tags", exclude=True)
-
- tags__id__in = ObjectFilter(field_name="tags", in_list=True)
-
- correspondent__id__none = ObjectFilter(field_name="correspondent", exclude=True)
-
- document_type__id__none = ObjectFilter(field_name="document_type", exclude=True)
-
- storage_path__id__none = ObjectFilter(field_name="storage_path", exclude=True)
-
- is_in_inbox = InboxFilter()
-
- title_content = TitleContentFilter()
-
- owner__id__none = ObjectFilter(field_name="owner", exclude=True)
-
- custom_fields__icontains = CustomFieldsFilter()
-
- custom_fields__id__all = ObjectFilter(field_name="custom_fields__field")
-
- custom_fields__id__none = ObjectFilter(
- field_name="custom_fields__field",
- exclude=True,
- )
-
- custom_fields__id__in = ObjectFilter(
- field_name="custom_fields__field",
- in_list=True,
- )
-
- has_custom_fields = BooleanFilter(
- label="Has custom field",
- field_name="custom_fields",
- lookup_expr="isnull",
- exclude=True,
- )
-
- custom_field_query = CustomFieldQueryFilter("custom_field_query")
-
- shared_by__id = SharedByUser()
-
- mime_type = MimeTypeFilter()
-
- class Meta:
- model = Document
- fields = {
- "id": ID_KWARGS,
- "title": CHAR_KWARGS,
- "content": CHAR_KWARGS,
- "archive_serial_number": INT_KWARGS,
- "created": DATE_KWARGS,
- "added": DATE_KWARGS,
- "modified": DATE_KWARGS,
- "original_filename": CHAR_KWARGS,
- "checksum": CHAR_KWARGS,
- "correspondent": ["isnull"],
- "correspondent__id": ID_KWARGS,
- "correspondent__name": CHAR_KWARGS,
- "tags__id": ID_KWARGS,
- "tags__name": CHAR_KWARGS,
- "document_type": ["isnull"],
- "document_type__id": ID_KWARGS,
- "document_type__name": CHAR_KWARGS,
- "storage_path": ["isnull"],
- "storage_path__id": ID_KWARGS,
- "storage_path__name": CHAR_KWARGS,
- "owner": ["isnull"],
- "owner__id": ID_KWARGS,
- "custom_fields": ["icontains"],
- }
-
-
-class ShareLinkFilterSet(FilterSet):
- class Meta:
- model = ShareLink
- fields = {
- "created": DATE_KWARGS,
- "expiration": DATE_KWARGS,
- }
-
-
-class PaperlessTaskFilterSet(FilterSet):
- acknowledged = BooleanFilter(
- label="Acknowledged",
- field_name="acknowledged",
- )
-
- class Meta:
- model = PaperlessTask
- fields = {
- "type": ["exact"],
- "task_name": ["exact"],
- "status": ["exact"],
- }
-
-
-class ObjectOwnedOrGrantedPermissionsFilter(ObjectPermissionsFilter):
- """
- A filter backend that limits results to those where the requesting user
- has read object level permissions, owns the objects, or objects without
- an owner (for backwards compat)
- """
-
- def filter_queryset(self, request, queryset, view):
- objects_with_perms = super().filter_queryset(request, queryset, view)
- objects_owned = queryset.filter(owner=request.user)
- objects_unowned = queryset.filter(owner__isnull=True)
- return objects_with_perms | objects_owned | objects_unowned
-
-
-class ObjectOwnedPermissionsFilter(ObjectPermissionsFilter):
- """
- A filter backend that limits results to those where the requesting user
- owns the objects or objects without an owner (for backwards compat)
- """
-
- def filter_queryset(self, request, queryset, view):
- if request.user.is_superuser:
- return queryset
- objects_owned = queryset.filter(owner=request.user)
- objects_unowned = queryset.filter(owner__isnull=True)
- return objects_owned | objects_unowned
-
-
-class DocumentsOrderingFilter(OrderingFilter):
- field_name = "ordering"
- prefix = "custom_field_"
-
- def filter_queryset(self, request, queryset, view):
- param = request.query_params.get("ordering")
- if param and self.prefix in param:
- custom_field_id = int(param.split(self.prefix)[1])
- try:
- field = CustomField.objects.get(pk=custom_field_id)
- except CustomField.DoesNotExist:
- raise serializers.ValidationError(
- {self.prefix + str(custom_field_id): [_("Custom field not found")]},
- )
-
- annotation = None
- match field.data_type:
- case CustomField.FieldDataType.STRING:
- annotation = Subquery(
- CustomFieldInstance.objects.filter(
- document_id=OuterRef("id"),
- field_id=custom_field_id,
- ).values("value_text")[:1],
- )
- case CustomField.FieldDataType.INT:
- annotation = Subquery(
- CustomFieldInstance.objects.filter(
- document_id=OuterRef("id"),
- field_id=custom_field_id,
- ).values("value_int")[:1],
- )
- case CustomField.FieldDataType.FLOAT:
- annotation = Subquery(
- CustomFieldInstance.objects.filter(
- document_id=OuterRef("id"),
- field_id=custom_field_id,
- ).values("value_float")[:1],
- )
- case CustomField.FieldDataType.DATE:
- annotation = Subquery(
- CustomFieldInstance.objects.filter(
- document_id=OuterRef("id"),
- field_id=custom_field_id,
- ).values("value_date")[:1],
- )
- case CustomField.FieldDataType.MONETARY:
- annotation = Subquery(
- CustomFieldInstance.objects.filter(
- document_id=OuterRef("id"),
- field_id=custom_field_id,
- ).values("value_monetary_amount")[:1],
- )
- case CustomField.FieldDataType.SELECT:
- # Select options are a little more complicated since the value is the id of the option, not
- # the label. Additionally, to support sqlite we can't use StringAgg, so we need to create a
- # case statement for each option, setting the value to the index of the option in a list
- # sorted by label, and then summing the results to give a single value for the annotation
-
- select_options = sorted(
- field.extra_data.get("select_options", []),
- key=lambda x: x.get("label"),
- )
- whens = [
- When(
- custom_fields__field_id=custom_field_id,
- custom_fields__value_select=option.get("id"),
- then=Value(idx, output_field=IntegerField()),
- )
- for idx, option in enumerate(select_options)
- ]
- whens.append(
- When(
- custom_fields__field_id=custom_field_id,
- custom_fields__value_select__isnull=True,
- then=Value(
- len(select_options),
- output_field=IntegerField(),
- ),
- ),
- )
- annotation = Sum(
- Case(
- *whens,
- default=Value(0),
- output_field=IntegerField(),
- ),
- )
- case CustomField.FieldDataType.DOCUMENTLINK:
- annotation = Subquery(
- CustomFieldInstance.objects.filter(
- document_id=OuterRef("id"),
- field_id=custom_field_id,
- ).values("value_document_ids")[:1],
- )
- case CustomField.FieldDataType.URL:
- annotation = Subquery(
- CustomFieldInstance.objects.filter(
- document_id=OuterRef("id"),
- field_id=custom_field_id,
- ).values("value_url")[:1],
- )
- case CustomField.FieldDataType.BOOL:
- annotation = Subquery(
- CustomFieldInstance.objects.filter(
- document_id=OuterRef("id"),
- field_id=custom_field_id,
- ).values("value_bool")[:1],
- )
-
- if not annotation:
- # Only happens if a new data type is added and not handled here
- raise ValueError("Invalid custom field data type")
-
- queryset = (
- queryset.annotate(
- # We need to annotate the queryset with the custom field value
- custom_field_value=annotation,
- # We also need to annotate the queryset with a boolean for sorting whether the field exists
- has_field=Exists(
- CustomFieldInstance.objects.filter(
- document_id=OuterRef("id"),
- field_id=custom_field_id,
- ),
- ),
- )
- .order_by(
- "-has_field",
- param.replace(
- self.prefix + str(custom_field_id),
- "custom_field_value",
- ),
- )
- .distinct()
- )
-
- return super().filter_queryset(request, queryset, view)
+from __future__ import annotations
+
+import functools
+import inspect
+import json
+import operator
+from contextlib import contextmanager
+from typing import TYPE_CHECKING
+
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
+from django.contrib.contenttypes.models import ContentType
+from django.db.models import Case
+from django.db.models import CharField
+from django.db.models import Count
+from django.db.models import Exists
+from django.db.models import IntegerField
+from django.db.models import OuterRef
+from django.db.models import Q
+from django.db.models import Subquery
+from django.db.models import Sum
+from django.db.models import Value
+from django.db.models import When
+from django.db.models.functions import Cast
+from django.utils.translation import gettext_lazy as _
+from django_filters.rest_framework import BooleanFilter
+from django_filters.rest_framework import Filter
from django_filters.rest_framework import FilterSet
+from drf_spectacular.utils import extend_schema_field
+from guardian.utils import get_group_obj_perms_model
+from guardian.utils import get_user_obj_perms_model
+from rest_framework import serializers
+from rest_framework.filters import OrderingFilter
+from rest_framework_guardian.filters import ObjectPermissionsFilter
+
+from paperless.models import Correspondent
+from paperless.models import CustomField
+from paperless.models import CustomFieldInstance
+from paperless.models import Document
+from paperless.models import DocumentType
+from paperless.models import PaperlessTask
+from paperless.models import ShareLink
+from paperless.models import StoragePath
+from paperless.models import Tag
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+CHAR_KWARGS = ["istartswith", "iendswith", "icontains", "iexact"]
+ID_KWARGS = ["in", "exact"]
+INT_KWARGS = ["exact", "gt", "gte", "lt", "lte", "isnull"]
+DATE_KWARGS = [
+ "year",
+ "month",
+ "day",
+ "date__gt",
+ "date__gte",
+ "gt",
+ "gte",
+ "date__lt",
+ "date__lte",
+ "lt",
+ "lte",
+]
-from documents.filters import CHAR_KWARGS
+CUSTOM_FIELD_QUERY_MAX_DEPTH = 10
+CUSTOM_FIELD_QUERY_MAX_ATOMS = 20
class UserFilterSet(FilterSet):
class Meta:
model = Group
fields = {"name": CHAR_KWARGS}
+
+
+class CorrespondentFilterSet(FilterSet):
+ class Meta:
+ model = Correspondent
+ fields = {
+ "id": ID_KWARGS,
+ "name": CHAR_KWARGS,
+ }
+
+
+class TagFilterSet(FilterSet):
+ class Meta:
+ model = Tag
+ fields = {
+ "id": ID_KWARGS,
+ "name": CHAR_KWARGS,
+ }
+
+
+class DocumentTypeFilterSet(FilterSet):
+ class Meta:
+ model = DocumentType
+ fields = {
+ "id": ID_KWARGS,
+ "name": CHAR_KWARGS,
+ }
+
+
+class StoragePathFilterSet(FilterSet):
+ class Meta:
+ model = StoragePath
+ fields = {
+ "id": ID_KWARGS,
+ "name": CHAR_KWARGS,
+ "path": CHAR_KWARGS,
+ }
+
+
+class ObjectFilter(Filter):
+ def __init__(self, *, exclude=False, in_list=False, field_name=""):
+ super().__init__()
+ self.exclude = exclude
+ self.in_list = in_list
+ self.field_name = field_name
+
+ def filter(self, qs, value):
+ if not value:
+ return qs
+
+ try:
+ object_ids = [int(x) for x in value.split(",")]
+ except ValueError:
+ return qs
+
+ if self.in_list:
+ qs = qs.filter(**{f"{self.field_name}__id__in": object_ids}).distinct()
+ else:
+ for obj_id in object_ids:
+ if self.exclude:
+ qs = qs.exclude(**{f"{self.field_name}__id": obj_id})
+ else:
+ qs = qs.filter(**{f"{self.field_name}__id": obj_id})
+
+ return qs
+
+
+@extend_schema_field(serializers.BooleanField)
+class InboxFilter(Filter):
+ def filter(self, qs, value):
+ if value == "true":
+ return qs.filter(tags__is_inbox_tag=True)
+ elif value == "false":
+ return qs.exclude(tags__is_inbox_tag=True)
+ else:
+ return qs
+
+
+@extend_schema_field(serializers.CharField)
+class TitleContentFilter(Filter):
+ def filter(self, qs, value):
+ if value:
+ return qs.filter(Q(title__icontains=value) | Q(content__icontains=value))
+ else:
+ return qs
+
+
+@extend_schema_field(serializers.BooleanField)
+class SharedByUser(Filter):
+ def filter(self, qs, value):
+ ctype = ContentType.objects.get_for_model(self.model)
+ UserObjectPermission = get_user_obj_perms_model()
+ GroupObjectPermission = get_group_obj_perms_model()
+ # see https://github.com/paperless-ngx/paperless-ngx/issues/5392, we limit subqueries
+ # to 1 because Postgres doesn't like returning > 1 row, but all we care about is > 0
+ return (
+ qs.filter(
+ owner_id=value,
+ )
+ .annotate(
+ num_shared_users=Count(
+ UserObjectPermission.objects.filter(
+ content_type=ctype,
+ object_pk=Cast(OuterRef("pk"), CharField()),
+ ).values("user_id")[:1],
+ ),
+ )
+ .annotate(
+ num_shared_groups=Count(
+ GroupObjectPermission.objects.filter(
+ content_type=ctype,
+ object_pk=Cast(OuterRef("pk"), CharField()),
+ ).values("group_id")[:1],
+ ),
+ )
+ .filter(
+ Q(num_shared_users__gt=0) | Q(num_shared_groups__gt=0),
+ )
+ if value is not None
+ else qs
+ )
+
+
+class CustomFieldFilterSet(FilterSet):
+ class Meta:
+ model = CustomField
+ fields = {
+ "id": ID_KWARGS,
+ "name": CHAR_KWARGS,
+ }
+
+
+@extend_schema_field(serializers.CharField)
+class CustomFieldsFilter(Filter):
+ def filter(self, qs, value):
+ if value:
+ fields_with_matching_selects = CustomField.objects.filter(
+ extra_data__icontains=value,
+ )
+ option_ids = []
+ if fields_with_matching_selects.count() > 0:
+ for field in fields_with_matching_selects:
+ options = field.extra_data.get("select_options", [])
+ for _, option in enumerate(options):
+ if option.get("label").lower().find(value.lower()) != -1:
+ option_ids.extend([option.get("id")])
+ return (
+ qs.filter(custom_fields__field__name__icontains=value)
+ | qs.filter(custom_fields__value_text__icontains=value)
+ | qs.filter(custom_fields__value_bool__icontains=value)
+ | qs.filter(custom_fields__value_int__icontains=value)
+ | qs.filter(custom_fields__value_float__icontains=value)
+ | qs.filter(custom_fields__value_date__icontains=value)
+ | qs.filter(custom_fields__value_url__icontains=value)
+ | qs.filter(custom_fields__value_monetary__icontains=value)
+ | qs.filter(custom_fields__value_document_ids__icontains=value)
+ | qs.filter(custom_fields__value_select__in=option_ids)
+ )
+ else:
+ return qs
+
+
+class MimeTypeFilter(Filter):
+ def filter(self, qs, value):
+ if value:
+ return qs.filter(mime_type__icontains=value)
+ else:
+ return qs
+
+
+class SelectField(serializers.CharField):
+ def __init__(self, custom_field: CustomField):
+ self._options = custom_field.extra_data["select_options"]
+ super().__init__(max_length=16)
+
+ def to_internal_value(self, data):
+ # If the supplied value is the option label instead of the ID
+ try:
+ data = next(
+ option.get("id")
+ for option in self._options
+ if option.get("label") == data
+ )
+ except StopIteration:
+ pass
+ return super().to_internal_value(data)
+
+
+def handle_validation_prefix(func: Callable):
+ """
+ Catch ValidationErrors raised by the wrapped function
+ and add a prefix to the exception detail to track what causes the exception,
+ similar to nested serializers.
+ """
+
+ def wrapper(*args, validation_prefix=None, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except serializers.ValidationError as e:
+ raise serializers.ValidationError({validation_prefix: e.detail})
+
+ # Update the signature to include the validation_prefix argument
+ old_sig = inspect.signature(func)
+ new_param = inspect.Parameter("validation_prefix", inspect.Parameter.KEYWORD_ONLY)
+ new_sig = old_sig.replace(parameters=[*old_sig.parameters.values(), new_param])
+
+ # Apply functools.wraps and manually set the new signature
+ functools.update_wrapper(wrapper, func)
+ wrapper.__signature__ = new_sig
+
+ return wrapper
+
+
+class CustomFieldQueryParser:
+ EXPR_BY_CATEGORY = {
+ "basic": ["exact", "in", "isnull", "exists"],
+ "string": [
+ "icontains",
+ "istartswith",
+ "iendswith",
+ ],
+ "arithmetic": [
+ "gt",
+ "gte",
+ "lt",
+ "lte",
+ "range",
+ ],
+ "containment": ["contains"],
+ }
+
+ SUPPORTED_EXPR_CATEGORIES = {
+ CustomField.FieldDataType.STRING: ("basic", "string"),
+ CustomField.FieldDataType.URL: ("basic", "string"),
+ CustomField.FieldDataType.DATE: ("basic", "arithmetic"),
+ CustomField.FieldDataType.BOOL: ("basic",),
+ CustomField.FieldDataType.INT: ("basic", "arithmetic"),
+ CustomField.FieldDataType.FLOAT: ("basic", "arithmetic"),
+ CustomField.FieldDataType.MONETARY: ("basic", "string", "arithmetic"),
+ CustomField.FieldDataType.DOCUMENTLINK: ("basic", "containment"),
+ CustomField.FieldDataType.SELECT: ("basic",),
+ }
+
+ DATE_COMPONENTS = [
+ "year",
+ "iso_year",
+ "month",
+ "day",
+ "week",
+ "week_day",
+ "iso_week_day",
+ "quarter",
+ ]
+
+ def __init__(
+ self,
+ validation_prefix,
+ max_query_depth=10,
+ max_atom_count=20,
+ ) -> None:
+ """
+ A helper class that parses the query string into a `django.db.models.Q` for filtering
+ documents based on custom field values.
+
+ The syntax of the query expression is illustrated with the below pseudo code rules:
+ 1. parse([`custom_field`, "exists", true]):
+ matches documents with Q(custom_fields__field=`custom_field`)
+ 2. parse([`custom_field`, "exists", false]):
+ matches documents with ~Q(custom_fields__field=`custom_field`)
+ 3. parse([`custom_field`, `op`, `value`]):
+ matches documents with
+ Q(custom_fields__field=`custom_field`, custom_fields__value_`type`__`op`= `value`)
+ 4. parse(["AND", [`q0`, `q1`, ..., `qn`]])
+ -> parse(`q0`) & parse(`q1`) & ... & parse(`qn`)
+ 5. parse(["OR", [`q0`, `q1`, ..., `qn`]])
+ -> parse(`q0`) | parse(`q1`) | ... | parse(`qn`)
+ 6. parse(["NOT", `q`])
+ -> ~parse(`q`)
+
+ Args:
+ validation_prefix: Used to generate the ValidationError message.
+ max_query_depth: Limits the maximum nesting depth of queries.
+ max_atom_count: Limits the maximum number of atoms (i.e., rule 1, 2, 3) in the query.
+
+ `max_query_depth` and `max_atom_count` can be set to guard against generating arbitrarily
+ complex SQL queries.
+ """
+ self._custom_fields: dict[int | str, CustomField] = {}
+ self._validation_prefix = validation_prefix
+ # Dummy ModelSerializer used to convert a Django models.Field to serializers.Field.
+ self._model_serializer = serializers.ModelSerializer()
+ # Used for sanity check
+ self._max_query_depth = max_query_depth
+ self._max_atom_count = max_atom_count
+ self._current_depth = 0
+ self._atom_count = 0
+ # The set of annotations that we need to apply to the queryset
+ self._annotations = {}
+
+ def parse(self, query: str) -> tuple[Q, dict[str, Count]]:
+ """
+ Parses the query string into a `django.db.models.Q`
+ and a set of annotations to be applied to the queryset.
+ """
+ try:
+ expr = json.loads(query)
+ except json.JSONDecodeError:
+ raise serializers.ValidationError(
+ {self._validation_prefix: [_("Value must be valid JSON.")]},
+ )
+ return (
+ self._parse_expr(expr, validation_prefix=self._validation_prefix),
+ self._annotations,
+ )
+
+ @handle_validation_prefix
+ def _parse_expr(self, expr) -> Q:
+ """
+ Applies rule (1, 2, 3) or (4, 5, 6) based on the length of the expr.
+ """
+ with self._track_query_depth():
+ if isinstance(expr, list | tuple):
+ if len(expr) == 2:
+ return self._parse_logical_expr(*expr)
+ elif len(expr) == 3:
+ return self._parse_atom(*expr)
+ raise serializers.ValidationError(
+ [_("Invalid custom field query expression")],
+ )
+
+ @handle_validation_prefix
+ def _parse_expr_list(self, exprs) -> list[Q]:
+ """
+ Handles [`q0`, `q1`, ..., `qn`] in rule 4 & 5.
+ """
+ if not isinstance(exprs, list | tuple) or not exprs:
+ raise serializers.ValidationError(
+ [_("Invalid expression list. Must be nonempty.")],
+ )
+ return [
+ self._parse_expr(expr, validation_prefix=i) for i, expr in enumerate(exprs)
+ ]
+
+ def _parse_logical_expr(self, op, args) -> Q:
+ """
+ Handles rule 4, 5, 6.
+ """
+ op_lower = op.lower()
+
+ if op_lower == "not":
+ return ~self._parse_expr(args, validation_prefix=1)
+
+ if op_lower == "and":
+ op_func = operator.and_
+ elif op_lower == "or":
+ op_func = operator.or_
+ else:
+ raise serializers.ValidationError(
+ {"0": [_("Invalid logical operator {op!r}").format(op=op)]},
+ )
+
+ qs = self._parse_expr_list(args, validation_prefix="1")
+ return functools.reduce(op_func, qs)
+
+ def _parse_atom(self, id_or_name, op, value) -> Q:
+ """
+ Handles rule 1, 2, 3.
+ """
+ # Guard against queries with too many conditions.
+ self._atom_count += 1
+ if self._atom_count > self._max_atom_count:
+ raise serializers.ValidationError(
+ [_("Maximum number of query conditions exceeded.")],
+ )
+
+ custom_field = self._get_custom_field(id_or_name, validation_prefix="0")
+ op = self._validate_atom_op(custom_field, op, validation_prefix="1")
+ value = self._validate_atom_value(
+ custom_field,
+ op,
+ value,
+ validation_prefix="2",
+ )
+
+ # Needed because not all DB backends support Array __contains
+ if (
+ custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK
+ and op == "contains"
+ ):
+ return self._parse_atom_doc_link_contains(custom_field, value)
+
+ value_field_name = CustomFieldInstance.get_value_field_name(
+ custom_field.data_type,
+ )
+ if (
+ custom_field.data_type == CustomField.FieldDataType.MONETARY
+ and op in self.EXPR_BY_CATEGORY["arithmetic"]
+ ):
+ value_field_name = "value_monetary_amount"
+ has_field = Q(custom_fields__field=custom_field)
+
+ # We need to use an annotation here because different atoms
+ # might be referring to different instances of custom fields.
+ annotation_name = f"_custom_field_filter_{len(self._annotations)}"
+
+ # Our special exists operator.
+ if op == "exists":
+ annotation = Count("custom_fields", filter=has_field)
+ # A Document should have > 0 match if it has this field, or 0 if doesn't.
+ query_op = "gt" if value else "exact"
+ query = Q(**{f"{annotation_name}__{query_op}": 0})
+ else:
+ # Check if 1) custom field name matches, and 2) value satisfies condition
+ field_filter = has_field & Q(
+ **{f"custom_fields__{value_field_name}__{op}": value},
+ )
+ # Annotate how many matching custom fields each document has
+ annotation = Count("custom_fields", filter=field_filter)
+ # Filter document by count
+ query = Q(**{f"{annotation_name}__gt": 0})
+
+ self._annotations[annotation_name] = annotation
+ return query
+
+ @handle_validation_prefix
+ def _get_custom_field(self, id_or_name):
+ """Get the CustomField instance by id or name."""
+ if id_or_name in self._custom_fields:
+ return self._custom_fields[id_or_name]
+
+ kwargs = (
+ {"id": id_or_name} if isinstance(id_or_name, int) else {"name": id_or_name}
+ )
+ try:
+ custom_field = CustomField.objects.get(**kwargs)
+ except CustomField.DoesNotExist:
+ raise serializers.ValidationError(
+ [_("{name!r} is not a valid custom field.").format(name=id_or_name)],
+ )
+ self._custom_fields[custom_field.id] = custom_field
+ self._custom_fields[custom_field.name] = custom_field
+ return custom_field
+
+ @staticmethod
+ def _split_op(full_op):
+ *prefix, op = str(full_op).rsplit("__", maxsplit=1)
+ prefix = prefix[0] if prefix else None
+ return prefix, op
+
+ @handle_validation_prefix
+ def _validate_atom_op(self, custom_field, raw_op):
+ """Check if the `op` is compatible with the type of the custom field."""
+ prefix, op = self._split_op(raw_op)
+
+ # Check if the operator is supported for the current data_type.
+ supported = False
+ for category in self.SUPPORTED_EXPR_CATEGORIES[custom_field.data_type]:
+ if op in self.EXPR_BY_CATEGORY[category]:
+ supported = True
+ break
+
+ # Check prefix
+ if prefix is not None:
+ if (
+ prefix in self.DATE_COMPONENTS
+ and custom_field.data_type == CustomField.FieldDataType.DATE
+ ):
+ pass # ok - e.g., "year__exact" for date field
+ else:
+ supported = False # anything else is invalid
+
+ if not supported:
+ raise serializers.ValidationError(
+ [
+ _("{data_type} does not support query expr {expr!r}.").format(
+ data_type=custom_field.data_type,
+ expr=raw_op,
+ ),
+ ],
+ )
+
+ return raw_op
+
+ def _get_serializer_field(self, custom_field, full_op):
+ """Return a serializers.Field for value validation."""
+ prefix, op = self._split_op(full_op)
+ field = None
+
+ if op in ("isnull", "exists"):
+ # `isnull` takes either True or False regardless of the data_type.
+ field = serializers.BooleanField()
+ elif (
+ custom_field.data_type == CustomField.FieldDataType.DATE
+ and prefix in self.DATE_COMPONENTS
+ ):
+ # DateField admits queries in the form of `year__exact`, etc. These take integers.
+ field = serializers.IntegerField()
+ elif custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
+ # We can be more specific here and make sure the value is a list.
+ field = serializers.ListField(child=serializers.IntegerField())
+ elif custom_field.data_type == CustomField.FieldDataType.SELECT:
+ # We use this custom field to permit SELECT option names.
+ field = SelectField(custom_field)
+ elif custom_field.data_type == CustomField.FieldDataType.URL:
+ # For URL fields we don't need to be strict about validation (e.g., for istartswith).
+ field = serializers.CharField()
+ else:
+ # The general case: inferred from the corresponding field in CustomFieldInstance.
+ value_field_name = CustomFieldInstance.get_value_field_name(
+ custom_field.data_type,
+ )
+ model_field = CustomFieldInstance._meta.get_field(value_field_name)
+ field_name = model_field.deconstruct()[0]
+ field_class, field_kwargs = self._model_serializer.build_standard_field(
+ field_name,
+ model_field,
+ )
+ field = field_class(**field_kwargs)
+ field.allow_null = False
+
+ # Need to set allow_blank manually because of the inconsistency in CustomFieldInstance validation.
+ # See https://github.com/paperless-ngx/paperless-ngx/issues/7361.
+ if isinstance(field, serializers.CharField):
+ field.allow_blank = True
+
+ if op == "in":
+ # `in` takes a list of values.
+ field = serializers.ListField(child=field, allow_empty=False)
+ elif op == "range":
+ # `range` takes a list of values, i.e., [start, end].
+ field = serializers.ListField(
+ child=field,
+ min_length=2,
+ max_length=2,
+ )
+
+ return field
+
+ @handle_validation_prefix
+ def _validate_atom_value(self, custom_field, op, value):
+ """Check if `value` is valid for the custom field and `op`. Returns the validated value."""
+ serializer_field = self._get_serializer_field(custom_field, op)
+ return serializer_field.run_validation(value)
+
+ def _parse_atom_doc_link_contains(self, custom_field, value) -> Q:
+ """
+ Handles document link `contains` in a way that is supported by all DB backends.
+ """
+
+ # If the value is an empty set,
+ # this is trivially true for any document with not null document links.
+ if not value:
+ return Q(
+ custom_fields__field=custom_field,
+ custom_fields__value_document_ids__isnull=False,
+ )
+
+ # First we look up reverse links from the requested documents.
+ links = CustomFieldInstance.objects.filter(
+ document_id__in=value,
+ field__data_type=CustomField.FieldDataType.DOCUMENTLINK,
+ )
+
+ # Check if any of the requested IDs are missing.
+ missing_ids = set(value) - set(link.document_id for link in links)
+ if missing_ids:
+ # The result should be an empty set in this case.
+ return Q(id__in=[])
+
+ # Take the intersection of the reverse links - this should be what we are looking for.
+ document_ids_we_want = functools.reduce(
+ operator.and_,
+ (set(link.value_document_ids) for link in links),
+ )
+
+ return Q(id__in=document_ids_we_want)
+
+ @contextmanager
+ def _track_query_depth(self):
+ # guard against queries that are too deeply nested
+ self._current_depth += 1
+ if self._current_depth > self._max_query_depth:
+ raise serializers.ValidationError([_("Maximum nesting depth exceeded.")])
+ try:
+ yield
+ finally:
+ self._current_depth -= 1
+
+
+@extend_schema_field(serializers.CharField)
+class CustomFieldQueryFilter(Filter):
+ def __init__(self, validation_prefix):
+ """
+ A filter that filters documents based on custom field name and value.
+
+ Args:
+ validation_prefix: Used to generate the ValidationError message.
+ """
+ super().__init__()
+ self._validation_prefix = validation_prefix
+
+ def filter(self, qs, value):
+ if not value:
+ return qs
+
+ parser = CustomFieldQueryParser(
+ self._validation_prefix,
+ max_query_depth=CUSTOM_FIELD_QUERY_MAX_DEPTH,
+ max_atom_count=CUSTOM_FIELD_QUERY_MAX_ATOMS,
+ )
+ q, annotations = parser.parse(value)
+
+ return qs.annotate(**annotations).filter(q)
+
+
+class DocumentFilterSet(FilterSet):
+ is_tagged = BooleanFilter(
+ label="Is tagged",
+ field_name="tags",
+ lookup_expr="isnull",
+ exclude=True,
+ )
+
+ tags__id__all = ObjectFilter(field_name="tags")
+
+ tags__id__none = ObjectFilter(field_name="tags", exclude=True)
+
+ tags__id__in = ObjectFilter(field_name="tags", in_list=True)
+
+ correspondent__id__none = ObjectFilter(field_name="correspondent", exclude=True)
+
+ document_type__id__none = ObjectFilter(field_name="document_type", exclude=True)
+
+ storage_path__id__none = ObjectFilter(field_name="storage_path", exclude=True)
+
+ is_in_inbox = InboxFilter()
+
+ title_content = TitleContentFilter()
+
+ owner__id__none = ObjectFilter(field_name="owner", exclude=True)
+
+ custom_fields__icontains = CustomFieldsFilter()
+
+ custom_fields__id__all = ObjectFilter(field_name="custom_fields__field")
+
+ custom_fields__id__none = ObjectFilter(
+ field_name="custom_fields__field",
+ exclude=True,
+ )
+
+ custom_fields__id__in = ObjectFilter(
+ field_name="custom_fields__field",
+ in_list=True,
+ )
+
+ has_custom_fields = BooleanFilter(
+ label="Has custom field",
+ field_name="custom_fields",
+ lookup_expr="isnull",
+ exclude=True,
+ )
+
+ custom_field_query = CustomFieldQueryFilter("custom_field_query")
+
+ shared_by__id = SharedByUser()
+
+ mime_type = MimeTypeFilter()
+
+ class Meta:
+ model = Document
+ fields = {
+ "id": ID_KWARGS,
+ "title": CHAR_KWARGS,
+ "content": CHAR_KWARGS,
+ "archive_serial_number": INT_KWARGS,
+ "created": DATE_KWARGS,
+ "added": DATE_KWARGS,
+ "modified": DATE_KWARGS,
+ "original_filename": CHAR_KWARGS,
+ "checksum": CHAR_KWARGS,
+ "correspondent": ["isnull"],
+ "correspondent__id": ID_KWARGS,
+ "correspondent__name": CHAR_KWARGS,
+ "tags__id": ID_KWARGS,
+ "tags__name": CHAR_KWARGS,
+ "document_type": ["isnull"],
+ "document_type__id": ID_KWARGS,
+ "document_type__name": CHAR_KWARGS,
+ "storage_path": ["isnull"],
+ "storage_path__id": ID_KWARGS,
+ "storage_path__name": CHAR_KWARGS,
+ "owner": ["isnull"],
+ "owner__id": ID_KWARGS,
+ "custom_fields": ["icontains"],
+ }
+
+
+class ShareLinkFilterSet(FilterSet):
+ class Meta:
+ model = ShareLink
+ fields = {
+ "created": DATE_KWARGS,
+ "expiration": DATE_KWARGS,
+ }
+
+
+class PaperlessTaskFilterSet(FilterSet):
+ acknowledged = BooleanFilter(
+ label="Acknowledged",
+ field_name="acknowledged",
+ )
+
+ class Meta:
+ model = PaperlessTask
+ fields = {
+ "type": ["exact"],
+ "task_name": ["exact"],
+ "status": ["exact"],
+ }
+
+
+class ObjectOwnedOrGrantedPermissionsFilter(ObjectPermissionsFilter):
+ """
+ A filter backend that limits results to those where the requesting user
+ has read object level permissions, owns the objects, or objects without
+ an owner (for backwards compat)
+ """
+
+ def filter_queryset(self, request, queryset, view):
+ objects_with_perms = super().filter_queryset(request, queryset, view)
+ objects_owned = queryset.filter(owner=request.user)
+ objects_unowned = queryset.filter(owner__isnull=True)
+ return objects_with_perms | objects_owned | objects_unowned
+
+
+class ObjectOwnedPermissionsFilter(ObjectPermissionsFilter):
+ """
+ A filter backend that limits results to those where the requesting user
+ owns the objects or objects without an owner (for backwards compat)
+ """
+
+ def filter_queryset(self, request, queryset, view):
+ if request.user.is_superuser:
+ return queryset
+ objects_owned = queryset.filter(owner=request.user)
+ objects_unowned = queryset.filter(owner__isnull=True)
+ return objects_owned | objects_unowned
+
+
+class DocumentsOrderingFilter(OrderingFilter):
+ field_name = "ordering"
+ prefix = "custom_field_"
+
+ def filter_queryset(self, request, queryset, view):
+ param = request.query_params.get("ordering")
+ if param and self.prefix in param:
+ custom_field_id = int(param.split(self.prefix)[1])
+ try:
+ field = CustomField.objects.get(pk=custom_field_id)
+ except CustomField.DoesNotExist:
+ raise serializers.ValidationError(
+ {self.prefix + str(custom_field_id): [_("Custom field not found")]},
+ )
+
+ annotation = None
+ match field.data_type:
+ case CustomField.FieldDataType.STRING:
+ annotation = Subquery(
+ CustomFieldInstance.objects.filter(
+ document_id=OuterRef("id"),
+ field_id=custom_field_id,
+ ).values("value_text")[:1],
+ )
+ case CustomField.FieldDataType.INT:
+ annotation = Subquery(
+ CustomFieldInstance.objects.filter(
+ document_id=OuterRef("id"),
+ field_id=custom_field_id,
+ ).values("value_int")[:1],
+ )
+ case CustomField.FieldDataType.FLOAT:
+ annotation = Subquery(
+ CustomFieldInstance.objects.filter(
+ document_id=OuterRef("id"),
+ field_id=custom_field_id,
+ ).values("value_float")[:1],
+ )
+ case CustomField.FieldDataType.DATE:
+ annotation = Subquery(
+ CustomFieldInstance.objects.filter(
+ document_id=OuterRef("id"),
+ field_id=custom_field_id,
+ ).values("value_date")[:1],
+ )
+ case CustomField.FieldDataType.MONETARY:
+ annotation = Subquery(
+ CustomFieldInstance.objects.filter(
+ document_id=OuterRef("id"),
+ field_id=custom_field_id,
+ ).values("value_monetary_amount")[:1],
+ )
+ case CustomField.FieldDataType.SELECT:
+ # Select options are a little more complicated since the value is the id of the option, not
+ # the label. Additionally, to support sqlite we can't use StringAgg, so we need to create a
+ # case statement for each option, setting the value to the index of the option in a list
+ # sorted by label, and then summing the results to give a single value for the annotation
+
+ select_options = sorted(
+ field.extra_data.get("select_options", []),
+ key=lambda x: x.get("label"),
+ )
+ whens = [
+ When(
+ custom_fields__field_id=custom_field_id,
+ custom_fields__value_select=option.get("id"),
+ then=Value(idx, output_field=IntegerField()),
+ )
+ for idx, option in enumerate(select_options)
+ ]
+ whens.append(
+ When(
+ custom_fields__field_id=custom_field_id,
+ custom_fields__value_select__isnull=True,
+ then=Value(
+ len(select_options),
+ output_field=IntegerField(),
+ ),
+ ),
+ )
+ annotation = Sum(
+ Case(
+ *whens,
+ default=Value(0),
+ output_field=IntegerField(),
+ ),
+ )
+ case CustomField.FieldDataType.DOCUMENTLINK:
+ annotation = Subquery(
+ CustomFieldInstance.objects.filter(
+ document_id=OuterRef("id"),
+ field_id=custom_field_id,
+ ).values("value_document_ids")[:1],
+ )
+ case CustomField.FieldDataType.URL:
+ annotation = Subquery(
+ CustomFieldInstance.objects.filter(
+ document_id=OuterRef("id"),
+ field_id=custom_field_id,
+ ).values("value_url")[:1],
+ )
+ case CustomField.FieldDataType.BOOL:
+ annotation = Subquery(
+ CustomFieldInstance.objects.filter(
+ document_id=OuterRef("id"),
+ field_id=custom_field_id,
+ ).values("value_bool")[:1],
+ )
+
+ if not annotation:
+ # Only happens if a new data type is added and not handled here
+ raise ValueError("Invalid custom field data type")
+
+ queryset = (
+ queryset.annotate(
+ # We need to annotate the queryset with the custom field value
+ custom_field_value=annotation,
+ # We also need to annotate the queryset with a boolean for sorting whether the field exists
+ has_field=Exists(
+ CustomFieldInstance.objects.filter(
+ document_id=OuterRef("id"),
+ field_id=custom_field_id,
+ ),
+ ),
+ )
+ .order_by(
+ "-has_field",
+ param.replace(
+ self.prefix + str(custom_field_id),
+ "custom_field_value",
+ ),
+ )
+ .distinct()
+ )
+
+ return super().filter_queryset(request, queryset, view)
from rest_framework.viewsets import ReadOnlyModelViewSet
from rest_framework.viewsets import ViewSet
-from documents.filters import CorrespondentFilterSet
-from documents.filters import CustomFieldFilterSet
-from documents.filters import DocumentFilterSet
-from documents.filters import DocumentsOrderingFilter
-from documents.filters import DocumentTypeFilterSet
-from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
-from documents.filters import ObjectOwnedPermissionsFilter
-from documents.filters import PaperlessTaskFilterSet
-from documents.filters import ShareLinkFilterSet
-from documents.filters import StoragePathFilterSet
-from documents.filters import TagFilterSet
from documents.schema import generate_object_with_permissions_schema
from documents.signals import document_updated
from documents.templating.filepath import validate_filepath_template_and_render
from paperless.data_models import DocumentMetadataOverrides
from paperless.data_models import DocumentSource
from paperless.db import GnuPG
+from paperless.filters import CorrespondentFilterSet
+from paperless.filters import CustomFieldFilterSet
+from paperless.filters import DocumentFilterSet
+from paperless.filters import DocumentsOrderingFilter
+from paperless.filters import DocumentTypeFilterSet
from paperless.filters import GroupFilterSet
+from paperless.filters import ObjectOwnedOrGrantedPermissionsFilter
+from paperless.filters import ObjectOwnedPermissionsFilter
+from paperless.filters import PaperlessTaskFilterSet
+from paperless.filters import ShareLinkFilterSet
+from paperless.filters import StoragePathFilterSet
+from paperless.filters import TagFilterSet
from paperless.filters import UserFilterSet
from paperless.index import DelayedQuery
from paperless.mail import send_email
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
-from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
+from paperless.filters import ObjectOwnedOrGrantedPermissionsFilter
from paperless.permissions import PaperlessObjectPermissions
from paperless.views import PassUserMixin
from paperless.views import StandardPagination