]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Merge filters
authorshamoon <4887959+shamoon@users.noreply.github.com>
Tue, 8 Apr 2025 23:38:22 +0000 (16:38 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Tue, 8 Apr 2025 23:38:22 +0000 (16:38 -0700)
src/documents/filters.py [deleted file]
src/paperless/filters.py
src/paperless/views.py
src/paperless_mail/views.py

diff --git a/src/documents/filters.py b/src/documents/filters.py
deleted file mode 100644 (file)
index a1c9917..0000000
+++ /dev/null
@@ -1,950 +0,0 @@
-from __future__ import annotations
-
-import functools
-import inspect
-import json
-import operator
-from contextlib import contextmanager
-from typing import TYPE_CHECKING
-
-from django.contrib.contenttypes.models import ContentType
-from django.db.models import Case
-from django.db.models import CharField
-from django.db.models import Count
-from django.db.models import Exists
-from django.db.models import IntegerField
-from django.db.models import OuterRef
-from django.db.models import Q
-from django.db.models import Subquery
-from django.db.models import Sum
-from django.db.models import Value
-from django.db.models import When
-from django.db.models.functions import Cast
-from django.utils.translation import gettext_lazy as _
-from django_filters.rest_framework import BooleanFilter
-from django_filters.rest_framework import Filter
-from django_filters.rest_framework import FilterSet
-from drf_spectacular.utils import extend_schema_field
-from guardian.utils import get_group_obj_perms_model
-from guardian.utils import get_user_obj_perms_model
-from rest_framework import serializers
-from rest_framework.filters import OrderingFilter
-from rest_framework_guardian.filters import ObjectPermissionsFilter
-
-from paperless.models import Correspondent
-from paperless.models import CustomField
-from paperless.models import CustomFieldInstance
-from paperless.models import Document
-from paperless.models import DocumentType
-from paperless.models import PaperlessTask
-from paperless.models import ShareLink
-from paperless.models import StoragePath
-from paperless.models import Tag
-
-if TYPE_CHECKING:
-    from collections.abc import Callable
-
-CHAR_KWARGS = ["istartswith", "iendswith", "icontains", "iexact"]
-ID_KWARGS = ["in", "exact"]
-INT_KWARGS = ["exact", "gt", "gte", "lt", "lte", "isnull"]
-DATE_KWARGS = [
-    "year",
-    "month",
-    "day",
-    "date__gt",
-    "date__gte",
-    "gt",
-    "gte",
-    "date__lt",
-    "date__lte",
-    "lt",
-    "lte",
-]
-
-CUSTOM_FIELD_QUERY_MAX_DEPTH = 10
-CUSTOM_FIELD_QUERY_MAX_ATOMS = 20
-
-
-class CorrespondentFilterSet(FilterSet):
-    class Meta:
-        model = Correspondent
-        fields = {
-            "id": ID_KWARGS,
-            "name": CHAR_KWARGS,
-        }
-
-
-class TagFilterSet(FilterSet):
-    class Meta:
-        model = Tag
-        fields = {
-            "id": ID_KWARGS,
-            "name": CHAR_KWARGS,
-        }
-
-
-class DocumentTypeFilterSet(FilterSet):
-    class Meta:
-        model = DocumentType
-        fields = {
-            "id": ID_KWARGS,
-            "name": CHAR_KWARGS,
-        }
-
-
-class StoragePathFilterSet(FilterSet):
-    class Meta:
-        model = StoragePath
-        fields = {
-            "id": ID_KWARGS,
-            "name": CHAR_KWARGS,
-            "path": CHAR_KWARGS,
-        }
-
-
-class ObjectFilter(Filter):
-    def __init__(self, *, exclude=False, in_list=False, field_name=""):
-        super().__init__()
-        self.exclude = exclude
-        self.in_list = in_list
-        self.field_name = field_name
-
-    def filter(self, qs, value):
-        if not value:
-            return qs
-
-        try:
-            object_ids = [int(x) for x in value.split(",")]
-        except ValueError:
-            return qs
-
-        if self.in_list:
-            qs = qs.filter(**{f"{self.field_name}__id__in": object_ids}).distinct()
-        else:
-            for obj_id in object_ids:
-                if self.exclude:
-                    qs = qs.exclude(**{f"{self.field_name}__id": obj_id})
-                else:
-                    qs = qs.filter(**{f"{self.field_name}__id": obj_id})
-
-        return qs
-
-
-@extend_schema_field(serializers.BooleanField)
-class InboxFilter(Filter):
-    def filter(self, qs, value):
-        if value == "true":
-            return qs.filter(tags__is_inbox_tag=True)
-        elif value == "false":
-            return qs.exclude(tags__is_inbox_tag=True)
-        else:
-            return qs
-
-
-@extend_schema_field(serializers.CharField)
-class TitleContentFilter(Filter):
-    def filter(self, qs, value):
-        if value:
-            return qs.filter(Q(title__icontains=value) | Q(content__icontains=value))
-        else:
-            return qs
-
-
-@extend_schema_field(serializers.BooleanField)
-class SharedByUser(Filter):
-    def filter(self, qs, value):
-        ctype = ContentType.objects.get_for_model(self.model)
-        UserObjectPermission = get_user_obj_perms_model()
-        GroupObjectPermission = get_group_obj_perms_model()
-        # see https://github.com/paperless-ngx/paperless-ngx/issues/5392, we limit subqueries
-        # to 1 because Postgres doesn't like returning > 1 row, but all we care about is > 0
-        return (
-            qs.filter(
-                owner_id=value,
-            )
-            .annotate(
-                num_shared_users=Count(
-                    UserObjectPermission.objects.filter(
-                        content_type=ctype,
-                        object_pk=Cast(OuterRef("pk"), CharField()),
-                    ).values("user_id")[:1],
-                ),
-            )
-            .annotate(
-                num_shared_groups=Count(
-                    GroupObjectPermission.objects.filter(
-                        content_type=ctype,
-                        object_pk=Cast(OuterRef("pk"), CharField()),
-                    ).values("group_id")[:1],
-                ),
-            )
-            .filter(
-                Q(num_shared_users__gt=0) | Q(num_shared_groups__gt=0),
-            )
-            if value is not None
-            else qs
-        )
-
-
-class CustomFieldFilterSet(FilterSet):
-    class Meta:
-        model = CustomField
-        fields = {
-            "id": ID_KWARGS,
-            "name": CHAR_KWARGS,
-        }
-
-
-@extend_schema_field(serializers.CharField)
-class CustomFieldsFilter(Filter):
-    def filter(self, qs, value):
-        if value:
-            fields_with_matching_selects = CustomField.objects.filter(
-                extra_data__icontains=value,
-            )
-            option_ids = []
-            if fields_with_matching_selects.count() > 0:
-                for field in fields_with_matching_selects:
-                    options = field.extra_data.get("select_options", [])
-                    for _, option in enumerate(options):
-                        if option.get("label").lower().find(value.lower()) != -1:
-                            option_ids.extend([option.get("id")])
-            return (
-                qs.filter(custom_fields__field__name__icontains=value)
-                | qs.filter(custom_fields__value_text__icontains=value)
-                | qs.filter(custom_fields__value_bool__icontains=value)
-                | qs.filter(custom_fields__value_int__icontains=value)
-                | qs.filter(custom_fields__value_float__icontains=value)
-                | qs.filter(custom_fields__value_date__icontains=value)
-                | qs.filter(custom_fields__value_url__icontains=value)
-                | qs.filter(custom_fields__value_monetary__icontains=value)
-                | qs.filter(custom_fields__value_document_ids__icontains=value)
-                | qs.filter(custom_fields__value_select__in=option_ids)
-            )
-        else:
-            return qs
-
-
-class MimeTypeFilter(Filter):
-    def filter(self, qs, value):
-        if value:
-            return qs.filter(mime_type__icontains=value)
-        else:
-            return qs
-
-
-class SelectField(serializers.CharField):
-    def __init__(self, custom_field: CustomField):
-        self._options = custom_field.extra_data["select_options"]
-        super().__init__(max_length=16)
-
-    def to_internal_value(self, data):
-        # If the supplied value is the option label instead of the ID
-        try:
-            data = next(
-                option.get("id")
-                for option in self._options
-                if option.get("label") == data
-            )
-        except StopIteration:
-            pass
-        return super().to_internal_value(data)
-
-
-def handle_validation_prefix(func: Callable):
-    """
-    Catch ValidationErrors raised by the wrapped function
-    and add a prefix to the exception detail to track what causes the exception,
-    similar to nested serializers.
-    """
-
-    def wrapper(*args, validation_prefix=None, **kwargs):
-        try:
-            return func(*args, **kwargs)
-        except serializers.ValidationError as e:
-            raise serializers.ValidationError({validation_prefix: e.detail})
-
-    # Update the signature to include the validation_prefix argument
-    old_sig = inspect.signature(func)
-    new_param = inspect.Parameter("validation_prefix", inspect.Parameter.KEYWORD_ONLY)
-    new_sig = old_sig.replace(parameters=[*old_sig.parameters.values(), new_param])
-
-    # Apply functools.wraps and manually set the new signature
-    functools.update_wrapper(wrapper, func)
-    wrapper.__signature__ = new_sig
-
-    return wrapper
-
-
-class CustomFieldQueryParser:
-    EXPR_BY_CATEGORY = {
-        "basic": ["exact", "in", "isnull", "exists"],
-        "string": [
-            "icontains",
-            "istartswith",
-            "iendswith",
-        ],
-        "arithmetic": [
-            "gt",
-            "gte",
-            "lt",
-            "lte",
-            "range",
-        ],
-        "containment": ["contains"],
-    }
-
-    SUPPORTED_EXPR_CATEGORIES = {
-        CustomField.FieldDataType.STRING: ("basic", "string"),
-        CustomField.FieldDataType.URL: ("basic", "string"),
-        CustomField.FieldDataType.DATE: ("basic", "arithmetic"),
-        CustomField.FieldDataType.BOOL: ("basic",),
-        CustomField.FieldDataType.INT: ("basic", "arithmetic"),
-        CustomField.FieldDataType.FLOAT: ("basic", "arithmetic"),
-        CustomField.FieldDataType.MONETARY: ("basic", "string", "arithmetic"),
-        CustomField.FieldDataType.DOCUMENTLINK: ("basic", "containment"),
-        CustomField.FieldDataType.SELECT: ("basic",),
-    }
-
-    DATE_COMPONENTS = [
-        "year",
-        "iso_year",
-        "month",
-        "day",
-        "week",
-        "week_day",
-        "iso_week_day",
-        "quarter",
-    ]
-
-    def __init__(
-        self,
-        validation_prefix,
-        max_query_depth=10,
-        max_atom_count=20,
-    ) -> None:
-        """
-        A helper class that parses the query string into a `django.db.models.Q` for filtering
-        documents based on custom field values.
-
-        The syntax of the query expression is illustrated with the below pseudo code rules:
-        1. parse([`custom_field`, "exists", true]):
-            matches documents with Q(custom_fields__field=`custom_field`)
-        2. parse([`custom_field`, "exists", false]):
-            matches documents with ~Q(custom_fields__field=`custom_field`)
-        3. parse([`custom_field`, `op`, `value`]):
-            matches documents with
-            Q(custom_fields__field=`custom_field`, custom_fields__value_`type`__`op`= `value`)
-        4. parse(["AND", [`q0`, `q1`, ..., `qn`]])
-            -> parse(`q0`) & parse(`q1`) & ... & parse(`qn`)
-        5. parse(["OR", [`q0`, `q1`, ..., `qn`]])
-            -> parse(`q0`) | parse(`q1`) | ... | parse(`qn`)
-        6. parse(["NOT", `q`])
-            -> ~parse(`q`)
-
-        Args:
-            validation_prefix: Used to generate the ValidationError message.
-            max_query_depth: Limits the maximum nesting depth of queries.
-            max_atom_count: Limits the maximum number of atoms (i.e., rule 1, 2, 3) in the query.
-
-        `max_query_depth` and `max_atom_count` can be set to guard against generating arbitrarily
-        complex SQL queries.
-        """
-        self._custom_fields: dict[int | str, CustomField] = {}
-        self._validation_prefix = validation_prefix
-        # Dummy ModelSerializer used to convert a Django models.Field to serializers.Field.
-        self._model_serializer = serializers.ModelSerializer()
-        # Used for sanity check
-        self._max_query_depth = max_query_depth
-        self._max_atom_count = max_atom_count
-        self._current_depth = 0
-        self._atom_count = 0
-        # The set of annotations that we need to apply to the queryset
-        self._annotations = {}
-
-    def parse(self, query: str) -> tuple[Q, dict[str, Count]]:
-        """
-        Parses the query string into a `django.db.models.Q`
-        and a set of annotations to be applied to the queryset.
-        """
-        try:
-            expr = json.loads(query)
-        except json.JSONDecodeError:
-            raise serializers.ValidationError(
-                {self._validation_prefix: [_("Value must be valid JSON.")]},
-            )
-        return (
-            self._parse_expr(expr, validation_prefix=self._validation_prefix),
-            self._annotations,
-        )
-
-    @handle_validation_prefix
-    def _parse_expr(self, expr) -> Q:
-        """
-        Applies rule (1, 2, 3) or (4, 5, 6) based on the length of the expr.
-        """
-        with self._track_query_depth():
-            if isinstance(expr, list | tuple):
-                if len(expr) == 2:
-                    return self._parse_logical_expr(*expr)
-                elif len(expr) == 3:
-                    return self._parse_atom(*expr)
-            raise serializers.ValidationError(
-                [_("Invalid custom field query expression")],
-            )
-
-    @handle_validation_prefix
-    def _parse_expr_list(self, exprs) -> list[Q]:
-        """
-        Handles [`q0`, `q1`, ..., `qn`] in rule 4 & 5.
-        """
-        if not isinstance(exprs, list | tuple) or not exprs:
-            raise serializers.ValidationError(
-                [_("Invalid expression list. Must be nonempty.")],
-            )
-        return [
-            self._parse_expr(expr, validation_prefix=i) for i, expr in enumerate(exprs)
-        ]
-
-    def _parse_logical_expr(self, op, args) -> Q:
-        """
-        Handles rule 4, 5, 6.
-        """
-        op_lower = op.lower()
-
-        if op_lower == "not":
-            return ~self._parse_expr(args, validation_prefix=1)
-
-        if op_lower == "and":
-            op_func = operator.and_
-        elif op_lower == "or":
-            op_func = operator.or_
-        else:
-            raise serializers.ValidationError(
-                {"0": [_("Invalid logical operator {op!r}").format(op=op)]},
-            )
-
-        qs = self._parse_expr_list(args, validation_prefix="1")
-        return functools.reduce(op_func, qs)
-
-    def _parse_atom(self, id_or_name, op, value) -> Q:
-        """
-        Handles rule 1, 2, 3.
-        """
-        # Guard against queries with too many conditions.
-        self._atom_count += 1
-        if self._atom_count > self._max_atom_count:
-            raise serializers.ValidationError(
-                [_("Maximum number of query conditions exceeded.")],
-            )
-
-        custom_field = self._get_custom_field(id_or_name, validation_prefix="0")
-        op = self._validate_atom_op(custom_field, op, validation_prefix="1")
-        value = self._validate_atom_value(
-            custom_field,
-            op,
-            value,
-            validation_prefix="2",
-        )
-
-        # Needed because not all DB backends support Array __contains
-        if (
-            custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK
-            and op == "contains"
-        ):
-            return self._parse_atom_doc_link_contains(custom_field, value)
-
-        value_field_name = CustomFieldInstance.get_value_field_name(
-            custom_field.data_type,
-        )
-        if (
-            custom_field.data_type == CustomField.FieldDataType.MONETARY
-            and op in self.EXPR_BY_CATEGORY["arithmetic"]
-        ):
-            value_field_name = "value_monetary_amount"
-        has_field = Q(custom_fields__field=custom_field)
-
-        # We need to use an annotation here because different atoms
-        # might be referring to different instances of custom fields.
-        annotation_name = f"_custom_field_filter_{len(self._annotations)}"
-
-        # Our special exists operator.
-        if op == "exists":
-            annotation = Count("custom_fields", filter=has_field)
-            # A Document should have > 0 match if it has this field, or 0 if doesn't.
-            query_op = "gt" if value else "exact"
-            query = Q(**{f"{annotation_name}__{query_op}": 0})
-        else:
-            # Check if 1) custom field name matches, and 2) value satisfies condition
-            field_filter = has_field & Q(
-                **{f"custom_fields__{value_field_name}__{op}": value},
-            )
-            # Annotate how many matching custom fields each document has
-            annotation = Count("custom_fields", filter=field_filter)
-            # Filter document by count
-            query = Q(**{f"{annotation_name}__gt": 0})
-
-        self._annotations[annotation_name] = annotation
-        return query
-
-    @handle_validation_prefix
-    def _get_custom_field(self, id_or_name):
-        """Get the CustomField instance by id or name."""
-        if id_or_name in self._custom_fields:
-            return self._custom_fields[id_or_name]
-
-        kwargs = (
-            {"id": id_or_name} if isinstance(id_or_name, int) else {"name": id_or_name}
-        )
-        try:
-            custom_field = CustomField.objects.get(**kwargs)
-        except CustomField.DoesNotExist:
-            raise serializers.ValidationError(
-                [_("{name!r} is not a valid custom field.").format(name=id_or_name)],
-            )
-        self._custom_fields[custom_field.id] = custom_field
-        self._custom_fields[custom_field.name] = custom_field
-        return custom_field
-
-    @staticmethod
-    def _split_op(full_op):
-        *prefix, op = str(full_op).rsplit("__", maxsplit=1)
-        prefix = prefix[0] if prefix else None
-        return prefix, op
-
-    @handle_validation_prefix
-    def _validate_atom_op(self, custom_field, raw_op):
-        """Check if the `op` is compatible with the type of the custom field."""
-        prefix, op = self._split_op(raw_op)
-
-        # Check if the operator is supported for the current data_type.
-        supported = False
-        for category in self.SUPPORTED_EXPR_CATEGORIES[custom_field.data_type]:
-            if op in self.EXPR_BY_CATEGORY[category]:
-                supported = True
-                break
-
-        # Check prefix
-        if prefix is not None:
-            if (
-                prefix in self.DATE_COMPONENTS
-                and custom_field.data_type == CustomField.FieldDataType.DATE
-            ):
-                pass  # ok - e.g., "year__exact" for date field
-            else:
-                supported = False  # anything else is invalid
-
-        if not supported:
-            raise serializers.ValidationError(
-                [
-                    _("{data_type} does not support query expr {expr!r}.").format(
-                        data_type=custom_field.data_type,
-                        expr=raw_op,
-                    ),
-                ],
-            )
-
-        return raw_op
-
-    def _get_serializer_field(self, custom_field, full_op):
-        """Return a serializers.Field for value validation."""
-        prefix, op = self._split_op(full_op)
-        field = None
-
-        if op in ("isnull", "exists"):
-            # `isnull` takes either True or False regardless of the data_type.
-            field = serializers.BooleanField()
-        elif (
-            custom_field.data_type == CustomField.FieldDataType.DATE
-            and prefix in self.DATE_COMPONENTS
-        ):
-            # DateField admits queries in the form of `year__exact`, etc. These take integers.
-            field = serializers.IntegerField()
-        elif custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
-            # We can be more specific here and make sure the value is a list.
-            field = serializers.ListField(child=serializers.IntegerField())
-        elif custom_field.data_type == CustomField.FieldDataType.SELECT:
-            # We use this custom field to permit SELECT option names.
-            field = SelectField(custom_field)
-        elif custom_field.data_type == CustomField.FieldDataType.URL:
-            # For URL fields we don't need to be strict about validation (e.g., for istartswith).
-            field = serializers.CharField()
-        else:
-            # The general case: inferred from the corresponding field in CustomFieldInstance.
-            value_field_name = CustomFieldInstance.get_value_field_name(
-                custom_field.data_type,
-            )
-            model_field = CustomFieldInstance._meta.get_field(value_field_name)
-            field_name = model_field.deconstruct()[0]
-            field_class, field_kwargs = self._model_serializer.build_standard_field(
-                field_name,
-                model_field,
-            )
-            field = field_class(**field_kwargs)
-            field.allow_null = False
-
-            # Need to set allow_blank manually because of the inconsistency in CustomFieldInstance validation.
-            # See https://github.com/paperless-ngx/paperless-ngx/issues/7361.
-            if isinstance(field, serializers.CharField):
-                field.allow_blank = True
-
-        if op == "in":
-            # `in` takes a list of values.
-            field = serializers.ListField(child=field, allow_empty=False)
-        elif op == "range":
-            # `range` takes a list of values, i.e., [start, end].
-            field = serializers.ListField(
-                child=field,
-                min_length=2,
-                max_length=2,
-            )
-
-        return field
-
-    @handle_validation_prefix
-    def _validate_atom_value(self, custom_field, op, value):
-        """Check if `value` is valid for the custom field and `op`. Returns the validated value."""
-        serializer_field = self._get_serializer_field(custom_field, op)
-        return serializer_field.run_validation(value)
-
-    def _parse_atom_doc_link_contains(self, custom_field, value) -> Q:
-        """
-        Handles document link `contains` in a way that is supported by all DB backends.
-        """
-
-        # If the value is an empty set,
-        # this is trivially true for any document with not null document links.
-        if not value:
-            return Q(
-                custom_fields__field=custom_field,
-                custom_fields__value_document_ids__isnull=False,
-            )
-
-        # First we look up reverse links from the requested documents.
-        links = CustomFieldInstance.objects.filter(
-            document_id__in=value,
-            field__data_type=CustomField.FieldDataType.DOCUMENTLINK,
-        )
-
-        # Check if any of the requested IDs are missing.
-        missing_ids = set(value) - set(link.document_id for link in links)
-        if missing_ids:
-            # The result should be an empty set in this case.
-            return Q(id__in=[])
-
-        # Take the intersection of the reverse links - this should be what we are looking for.
-        document_ids_we_want = functools.reduce(
-            operator.and_,
-            (set(link.value_document_ids) for link in links),
-        )
-
-        return Q(id__in=document_ids_we_want)
-
-    @contextmanager
-    def _track_query_depth(self):
-        # guard against queries that are too deeply nested
-        self._current_depth += 1
-        if self._current_depth > self._max_query_depth:
-            raise serializers.ValidationError([_("Maximum nesting depth exceeded.")])
-        try:
-            yield
-        finally:
-            self._current_depth -= 1
-
-
-@extend_schema_field(serializers.CharField)
-class CustomFieldQueryFilter(Filter):
-    def __init__(self, validation_prefix):
-        """
-        A filter that filters documents based on custom field name and value.
-
-        Args:
-            validation_prefix: Used to generate the ValidationError message.
-        """
-        super().__init__()
-        self._validation_prefix = validation_prefix
-
-    def filter(self, qs, value):
-        if not value:
-            return qs
-
-        parser = CustomFieldQueryParser(
-            self._validation_prefix,
-            max_query_depth=CUSTOM_FIELD_QUERY_MAX_DEPTH,
-            max_atom_count=CUSTOM_FIELD_QUERY_MAX_ATOMS,
-        )
-        q, annotations = parser.parse(value)
-
-        return qs.annotate(**annotations).filter(q)
-
-
-class DocumentFilterSet(FilterSet):
-    is_tagged = BooleanFilter(
-        label="Is tagged",
-        field_name="tags",
-        lookup_expr="isnull",
-        exclude=True,
-    )
-
-    tags__id__all = ObjectFilter(field_name="tags")
-
-    tags__id__none = ObjectFilter(field_name="tags", exclude=True)
-
-    tags__id__in = ObjectFilter(field_name="tags", in_list=True)
-
-    correspondent__id__none = ObjectFilter(field_name="correspondent", exclude=True)
-
-    document_type__id__none = ObjectFilter(field_name="document_type", exclude=True)
-
-    storage_path__id__none = ObjectFilter(field_name="storage_path", exclude=True)
-
-    is_in_inbox = InboxFilter()
-
-    title_content = TitleContentFilter()
-
-    owner__id__none = ObjectFilter(field_name="owner", exclude=True)
-
-    custom_fields__icontains = CustomFieldsFilter()
-
-    custom_fields__id__all = ObjectFilter(field_name="custom_fields__field")
-
-    custom_fields__id__none = ObjectFilter(
-        field_name="custom_fields__field",
-        exclude=True,
-    )
-
-    custom_fields__id__in = ObjectFilter(
-        field_name="custom_fields__field",
-        in_list=True,
-    )
-
-    has_custom_fields = BooleanFilter(
-        label="Has custom field",
-        field_name="custom_fields",
-        lookup_expr="isnull",
-        exclude=True,
-    )
-
-    custom_field_query = CustomFieldQueryFilter("custom_field_query")
-
-    shared_by__id = SharedByUser()
-
-    mime_type = MimeTypeFilter()
-
-    class Meta:
-        model = Document
-        fields = {
-            "id": ID_KWARGS,
-            "title": CHAR_KWARGS,
-            "content": CHAR_KWARGS,
-            "archive_serial_number": INT_KWARGS,
-            "created": DATE_KWARGS,
-            "added": DATE_KWARGS,
-            "modified": DATE_KWARGS,
-            "original_filename": CHAR_KWARGS,
-            "checksum": CHAR_KWARGS,
-            "correspondent": ["isnull"],
-            "correspondent__id": ID_KWARGS,
-            "correspondent__name": CHAR_KWARGS,
-            "tags__id": ID_KWARGS,
-            "tags__name": CHAR_KWARGS,
-            "document_type": ["isnull"],
-            "document_type__id": ID_KWARGS,
-            "document_type__name": CHAR_KWARGS,
-            "storage_path": ["isnull"],
-            "storage_path__id": ID_KWARGS,
-            "storage_path__name": CHAR_KWARGS,
-            "owner": ["isnull"],
-            "owner__id": ID_KWARGS,
-            "custom_fields": ["icontains"],
-        }
-
-
-class ShareLinkFilterSet(FilterSet):
-    class Meta:
-        model = ShareLink
-        fields = {
-            "created": DATE_KWARGS,
-            "expiration": DATE_KWARGS,
-        }
-
-
-class PaperlessTaskFilterSet(FilterSet):
-    acknowledged = BooleanFilter(
-        label="Acknowledged",
-        field_name="acknowledged",
-    )
-
-    class Meta:
-        model = PaperlessTask
-        fields = {
-            "type": ["exact"],
-            "task_name": ["exact"],
-            "status": ["exact"],
-        }
-
-
-class ObjectOwnedOrGrantedPermissionsFilter(ObjectPermissionsFilter):
-    """
-    A filter backend that limits results to those where the requesting user
-    has read object level permissions, owns the objects, or objects without
-    an owner (for backwards compat)
-    """
-
-    def filter_queryset(self, request, queryset, view):
-        objects_with_perms = super().filter_queryset(request, queryset, view)
-        objects_owned = queryset.filter(owner=request.user)
-        objects_unowned = queryset.filter(owner__isnull=True)
-        return objects_with_perms | objects_owned | objects_unowned
-
-
-class ObjectOwnedPermissionsFilter(ObjectPermissionsFilter):
-    """
-    A filter backend that limits results to those where the requesting user
-    owns the objects or objects without an owner (for backwards compat)
-    """
-
-    def filter_queryset(self, request, queryset, view):
-        if request.user.is_superuser:
-            return queryset
-        objects_owned = queryset.filter(owner=request.user)
-        objects_unowned = queryset.filter(owner__isnull=True)
-        return objects_owned | objects_unowned
-
-
-class DocumentsOrderingFilter(OrderingFilter):
-    field_name = "ordering"
-    prefix = "custom_field_"
-
-    def filter_queryset(self, request, queryset, view):
-        param = request.query_params.get("ordering")
-        if param and self.prefix in param:
-            custom_field_id = int(param.split(self.prefix)[1])
-            try:
-                field = CustomField.objects.get(pk=custom_field_id)
-            except CustomField.DoesNotExist:
-                raise serializers.ValidationError(
-                    {self.prefix + str(custom_field_id): [_("Custom field not found")]},
-                )
-
-            annotation = None
-            match field.data_type:
-                case CustomField.FieldDataType.STRING:
-                    annotation = Subquery(
-                        CustomFieldInstance.objects.filter(
-                            document_id=OuterRef("id"),
-                            field_id=custom_field_id,
-                        ).values("value_text")[:1],
-                    )
-                case CustomField.FieldDataType.INT:
-                    annotation = Subquery(
-                        CustomFieldInstance.objects.filter(
-                            document_id=OuterRef("id"),
-                            field_id=custom_field_id,
-                        ).values("value_int")[:1],
-                    )
-                case CustomField.FieldDataType.FLOAT:
-                    annotation = Subquery(
-                        CustomFieldInstance.objects.filter(
-                            document_id=OuterRef("id"),
-                            field_id=custom_field_id,
-                        ).values("value_float")[:1],
-                    )
-                case CustomField.FieldDataType.DATE:
-                    annotation = Subquery(
-                        CustomFieldInstance.objects.filter(
-                            document_id=OuterRef("id"),
-                            field_id=custom_field_id,
-                        ).values("value_date")[:1],
-                    )
-                case CustomField.FieldDataType.MONETARY:
-                    annotation = Subquery(
-                        CustomFieldInstance.objects.filter(
-                            document_id=OuterRef("id"),
-                            field_id=custom_field_id,
-                        ).values("value_monetary_amount")[:1],
-                    )
-                case CustomField.FieldDataType.SELECT:
-                    # Select options are a little more complicated since the value is the id of the option, not
-                    # the label. Additionally, to support sqlite we can't use StringAgg, so we need to create a
-                    # case statement for each option, setting the value to the index of the option in a list
-                    # sorted by label, and then summing the results to give a single value for the annotation
-
-                    select_options = sorted(
-                        field.extra_data.get("select_options", []),
-                        key=lambda x: x.get("label"),
-                    )
-                    whens = [
-                        When(
-                            custom_fields__field_id=custom_field_id,
-                            custom_fields__value_select=option.get("id"),
-                            then=Value(idx, output_field=IntegerField()),
-                        )
-                        for idx, option in enumerate(select_options)
-                    ]
-                    whens.append(
-                        When(
-                            custom_fields__field_id=custom_field_id,
-                            custom_fields__value_select__isnull=True,
-                            then=Value(
-                                len(select_options),
-                                output_field=IntegerField(),
-                            ),
-                        ),
-                    )
-                    annotation = Sum(
-                        Case(
-                            *whens,
-                            default=Value(0),
-                            output_field=IntegerField(),
-                        ),
-                    )
-                case CustomField.FieldDataType.DOCUMENTLINK:
-                    annotation = Subquery(
-                        CustomFieldInstance.objects.filter(
-                            document_id=OuterRef("id"),
-                            field_id=custom_field_id,
-                        ).values("value_document_ids")[:1],
-                    )
-                case CustomField.FieldDataType.URL:
-                    annotation = Subquery(
-                        CustomFieldInstance.objects.filter(
-                            document_id=OuterRef("id"),
-                            field_id=custom_field_id,
-                        ).values("value_url")[:1],
-                    )
-                case CustomField.FieldDataType.BOOL:
-                    annotation = Subquery(
-                        CustomFieldInstance.objects.filter(
-                            document_id=OuterRef("id"),
-                            field_id=custom_field_id,
-                        ).values("value_bool")[:1],
-                    )
-
-            if not annotation:
-                # Only happens if a new data type is added and not handled here
-                raise ValueError("Invalid custom field data type")
-
-            queryset = (
-                queryset.annotate(
-                    # We need to annotate the queryset with the custom field value
-                    custom_field_value=annotation,
-                    # We also need to annotate the queryset with a boolean for sorting whether the field exists
-                    has_field=Exists(
-                        CustomFieldInstance.objects.filter(
-                            document_id=OuterRef("id"),
-                            field_id=custom_field_id,
-                        ),
-                    ),
-                )
-                .order_by(
-                    "-has_field",
-                    param.replace(
-                        self.prefix + str(custom_field_id),
-                        "custom_field_value",
-                    ),
-                )
-                .distinct()
-            )
-
-        return super().filter_queryset(request, queryset, view)
index a3c09d50fed7df17c2269a81c5af841577679f21..d5f25d6420e925208d648f62c5c4406e56554800 100644 (file)
@@ -1,8 +1,70 @@
+from __future__ import annotations
+
+import functools
+import inspect
+import json
+import operator
+from contextlib import contextmanager
+from typing import TYPE_CHECKING
+
 from django.contrib.auth.models import Group
 from django.contrib.auth.models import User
+from django.contrib.contenttypes.models import ContentType
+from django.db.models import Case
+from django.db.models import CharField
+from django.db.models import Count
+from django.db.models import Exists
+from django.db.models import IntegerField
+from django.db.models import OuterRef
+from django.db.models import Q
+from django.db.models import Subquery
+from django.db.models import Sum
+from django.db.models import Value
+from django.db.models import When
+from django.db.models.functions import Cast
+from django.utils.translation import gettext_lazy as _
+from django_filters.rest_framework import BooleanFilter
+from django_filters.rest_framework import Filter
 from django_filters.rest_framework import FilterSet
+from drf_spectacular.utils import extend_schema_field
+from guardian.utils import get_group_obj_perms_model
+from guardian.utils import get_user_obj_perms_model
+from rest_framework import serializers
+from rest_framework.filters import OrderingFilter
+from rest_framework_guardian.filters import ObjectPermissionsFilter
+
+from paperless.models import Correspondent
+from paperless.models import CustomField
+from paperless.models import CustomFieldInstance
+from paperless.models import Document
+from paperless.models import DocumentType
+from paperless.models import PaperlessTask
+from paperless.models import ShareLink
+from paperless.models import StoragePath
+from paperless.models import Tag
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+CHAR_KWARGS = ["istartswith", "iendswith", "icontains", "iexact"]
+ID_KWARGS = ["in", "exact"]
+INT_KWARGS = ["exact", "gt", "gte", "lt", "lte", "isnull"]
+DATE_KWARGS = [
+    "year",
+    "month",
+    "day",
+    "date__gt",
+    "date__gte",
+    "gt",
+    "gte",
+    "date__lt",
+    "date__lte",
+    "lt",
+    "lte",
+]
 
-from documents.filters import CHAR_KWARGS
+CUSTOM_FIELD_QUERY_MAX_DEPTH = 10
+CUSTOM_FIELD_QUERY_MAX_ATOMS = 20
 
 
 class UserFilterSet(FilterSet):
@@ -15,3 +77,888 @@ class GroupFilterSet(FilterSet):
     class Meta:
         model = Group
         fields = {"name": CHAR_KWARGS}
+
+
+class CorrespondentFilterSet(FilterSet):
+    class Meta:
+        model = Correspondent
+        fields = {
+            "id": ID_KWARGS,
+            "name": CHAR_KWARGS,
+        }
+
+
+class TagFilterSet(FilterSet):
+    class Meta:
+        model = Tag
+        fields = {
+            "id": ID_KWARGS,
+            "name": CHAR_KWARGS,
+        }
+
+
+class DocumentTypeFilterSet(FilterSet):
+    class Meta:
+        model = DocumentType
+        fields = {
+            "id": ID_KWARGS,
+            "name": CHAR_KWARGS,
+        }
+
+
+class StoragePathFilterSet(FilterSet):
+    class Meta:
+        model = StoragePath
+        fields = {
+            "id": ID_KWARGS,
+            "name": CHAR_KWARGS,
+            "path": CHAR_KWARGS,
+        }
+
+
+class ObjectFilter(Filter):
+    def __init__(self, *, exclude=False, in_list=False, field_name=""):
+        super().__init__()
+        self.exclude = exclude
+        self.in_list = in_list
+        self.field_name = field_name
+
+    def filter(self, qs, value):
+        if not value:
+            return qs
+
+        try:
+            object_ids = [int(x) for x in value.split(",")]
+        except ValueError:
+            return qs
+
+        if self.in_list:
+            qs = qs.filter(**{f"{self.field_name}__id__in": object_ids}).distinct()
+        else:
+            for obj_id in object_ids:
+                if self.exclude:
+                    qs = qs.exclude(**{f"{self.field_name}__id": obj_id})
+                else:
+                    qs = qs.filter(**{f"{self.field_name}__id": obj_id})
+
+        return qs
+
+
+@extend_schema_field(serializers.BooleanField)
+class InboxFilter(Filter):
+    def filter(self, qs, value):
+        if value == "true":
+            return qs.filter(tags__is_inbox_tag=True)
+        elif value == "false":
+            return qs.exclude(tags__is_inbox_tag=True)
+        else:
+            return qs
+
+
+@extend_schema_field(serializers.CharField)
+class TitleContentFilter(Filter):
+    def filter(self, qs, value):
+        if value:
+            return qs.filter(Q(title__icontains=value) | Q(content__icontains=value))
+        else:
+            return qs
+
+
+@extend_schema_field(serializers.BooleanField)
+class SharedByUser(Filter):
+    def filter(self, qs, value):
+        ctype = ContentType.objects.get_for_model(self.model)
+        UserObjectPermission = get_user_obj_perms_model()
+        GroupObjectPermission = get_group_obj_perms_model()
+        # see https://github.com/paperless-ngx/paperless-ngx/issues/5392, we limit subqueries
+        # to 1 because Postgres doesn't like returning > 1 row, but all we care about is > 0
+        return (
+            qs.filter(
+                owner_id=value,
+            )
+            .annotate(
+                num_shared_users=Count(
+                    UserObjectPermission.objects.filter(
+                        content_type=ctype,
+                        object_pk=Cast(OuterRef("pk"), CharField()),
+                    ).values("user_id")[:1],
+                ),
+            )
+            .annotate(
+                num_shared_groups=Count(
+                    GroupObjectPermission.objects.filter(
+                        content_type=ctype,
+                        object_pk=Cast(OuterRef("pk"), CharField()),
+                    ).values("group_id")[:1],
+                ),
+            )
+            .filter(
+                Q(num_shared_users__gt=0) | Q(num_shared_groups__gt=0),
+            )
+            if value is not None
+            else qs
+        )
+
+
+class CustomFieldFilterSet(FilterSet):
+    class Meta:
+        model = CustomField
+        fields = {
+            "id": ID_KWARGS,
+            "name": CHAR_KWARGS,
+        }
+
+
+@extend_schema_field(serializers.CharField)
+class CustomFieldsFilter(Filter):
+    def filter(self, qs, value):
+        if value:
+            fields_with_matching_selects = CustomField.objects.filter(
+                extra_data__icontains=value,
+            )
+            option_ids = []
+            if fields_with_matching_selects.count() > 0:
+                for field in fields_with_matching_selects:
+                    options = field.extra_data.get("select_options", [])
+                    for _, option in enumerate(options):
+                        if option.get("label").lower().find(value.lower()) != -1:
+                            option_ids.extend([option.get("id")])
+            return (
+                qs.filter(custom_fields__field__name__icontains=value)
+                | qs.filter(custom_fields__value_text__icontains=value)
+                | qs.filter(custom_fields__value_bool__icontains=value)
+                | qs.filter(custom_fields__value_int__icontains=value)
+                | qs.filter(custom_fields__value_float__icontains=value)
+                | qs.filter(custom_fields__value_date__icontains=value)
+                | qs.filter(custom_fields__value_url__icontains=value)
+                | qs.filter(custom_fields__value_monetary__icontains=value)
+                | qs.filter(custom_fields__value_document_ids__icontains=value)
+                | qs.filter(custom_fields__value_select__in=option_ids)
+            )
+        else:
+            return qs
+
+
+class MimeTypeFilter(Filter):
+    def filter(self, qs, value):
+        if value:
+            return qs.filter(mime_type__icontains=value)
+        else:
+            return qs
+
+
+class SelectField(serializers.CharField):
+    def __init__(self, custom_field: CustomField):
+        self._options = custom_field.extra_data["select_options"]
+        super().__init__(max_length=16)
+
+    def to_internal_value(self, data):
+        # If the supplied value is the option label instead of the ID
+        try:
+            data = next(
+                option.get("id")
+                for option in self._options
+                if option.get("label") == data
+            )
+        except StopIteration:
+            pass
+        return super().to_internal_value(data)
+
+
+def handle_validation_prefix(func: Callable):
+    """
+    Catch ValidationErrors raised by the wrapped function
+    and add a prefix to the exception detail to track what causes the exception,
+    similar to nested serializers.
+    """
+
+    def wrapper(*args, validation_prefix=None, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except serializers.ValidationError as e:
+            raise serializers.ValidationError({validation_prefix: e.detail})
+
+    # Update the signature to include the validation_prefix argument
+    old_sig = inspect.signature(func)
+    new_param = inspect.Parameter("validation_prefix", inspect.Parameter.KEYWORD_ONLY)
+    new_sig = old_sig.replace(parameters=[*old_sig.parameters.values(), new_param])
+
+    # Apply functools.wraps and manually set the new signature
+    functools.update_wrapper(wrapper, func)
+    wrapper.__signature__ = new_sig
+
+    return wrapper
+
+
+class CustomFieldQueryParser:
+    EXPR_BY_CATEGORY = {
+        "basic": ["exact", "in", "isnull", "exists"],
+        "string": [
+            "icontains",
+            "istartswith",
+            "iendswith",
+        ],
+        "arithmetic": [
+            "gt",
+            "gte",
+            "lt",
+            "lte",
+            "range",
+        ],
+        "containment": ["contains"],
+    }
+
+    SUPPORTED_EXPR_CATEGORIES = {
+        CustomField.FieldDataType.STRING: ("basic", "string"),
+        CustomField.FieldDataType.URL: ("basic", "string"),
+        CustomField.FieldDataType.DATE: ("basic", "arithmetic"),
+        CustomField.FieldDataType.BOOL: ("basic",),
+        CustomField.FieldDataType.INT: ("basic", "arithmetic"),
+        CustomField.FieldDataType.FLOAT: ("basic", "arithmetic"),
+        CustomField.FieldDataType.MONETARY: ("basic", "string", "arithmetic"),
+        CustomField.FieldDataType.DOCUMENTLINK: ("basic", "containment"),
+        CustomField.FieldDataType.SELECT: ("basic",),
+    }
+
+    DATE_COMPONENTS = [
+        "year",
+        "iso_year",
+        "month",
+        "day",
+        "week",
+        "week_day",
+        "iso_week_day",
+        "quarter",
+    ]
+
+    def __init__(
+        self,
+        validation_prefix,
+        max_query_depth=10,
+        max_atom_count=20,
+    ) -> None:
+        """
+        A helper class that parses the query string into a `django.db.models.Q` for filtering
+        documents based on custom field values.
+
+        The syntax of the query expression is illustrated with the below pseudo code rules:
+        1. parse([`custom_field`, "exists", true]):
+            matches documents with Q(custom_fields__field=`custom_field`)
+        2. parse([`custom_field`, "exists", false]):
+            matches documents with ~Q(custom_fields__field=`custom_field`)
+        3. parse([`custom_field`, `op`, `value`]):
+            matches documents with
+            Q(custom_fields__field=`custom_field`, custom_fields__value_`type`__`op`= `value`)
+        4. parse(["AND", [`q0`, `q1`, ..., `qn`]])
+            -> parse(`q0`) & parse(`q1`) & ... & parse(`qn`)
+        5. parse(["OR", [`q0`, `q1`, ..., `qn`]])
+            -> parse(`q0`) | parse(`q1`) | ... | parse(`qn`)
+        6. parse(["NOT", `q`])
+            -> ~parse(`q`)
+
+        Args:
+            validation_prefix: Used to generate the ValidationError message.
+            max_query_depth: Limits the maximum nesting depth of queries.
+            max_atom_count: Limits the maximum number of atoms (i.e., rule 1, 2, 3) in the query.
+
+        `max_query_depth` and `max_atom_count` can be set to guard against generating arbitrarily
+        complex SQL queries.
+        """
+        self._custom_fields: dict[int | str, CustomField] = {}
+        self._validation_prefix = validation_prefix
+        # Dummy ModelSerializer used to convert a Django models.Field to serializers.Field.
+        self._model_serializer = serializers.ModelSerializer()
+        # Used for sanity check
+        self._max_query_depth = max_query_depth
+        self._max_atom_count = max_atom_count
+        self._current_depth = 0
+        self._atom_count = 0
+        # The set of annotations that we need to apply to the queryset
+        self._annotations = {}
+
+    def parse(self, query: str) -> tuple[Q, dict[str, Count]]:
+        """
+        Parses the query string into a `django.db.models.Q`
+        and a set of annotations to be applied to the queryset.
+        """
+        try:
+            expr = json.loads(query)
+        except json.JSONDecodeError:
+            raise serializers.ValidationError(
+                {self._validation_prefix: [_("Value must be valid JSON.")]},
+            )
+        return (
+            self._parse_expr(expr, validation_prefix=self._validation_prefix),
+            self._annotations,
+        )
+
+    @handle_validation_prefix
+    def _parse_expr(self, expr) -> Q:
+        """
+        Applies rule (1, 2, 3) or (4, 5, 6) based on the length of the expr.
+        """
+        with self._track_query_depth():
+            if isinstance(expr, list | tuple):
+                if len(expr) == 2:
+                    return self._parse_logical_expr(*expr)
+                elif len(expr) == 3:
+                    return self._parse_atom(*expr)
+            raise serializers.ValidationError(
+                [_("Invalid custom field query expression")],
+            )
+
+    @handle_validation_prefix
+    def _parse_expr_list(self, exprs) -> list[Q]:
+        """
+        Handles [`q0`, `q1`, ..., `qn`] in rule 4 & 5.
+        """
+        if not isinstance(exprs, list | tuple) or not exprs:
+            raise serializers.ValidationError(
+                [_("Invalid expression list. Must be nonempty.")],
+            )
+        return [
+            self._parse_expr(expr, validation_prefix=i) for i, expr in enumerate(exprs)
+        ]
+
+    def _parse_logical_expr(self, op, args) -> Q:
+        """
+        Handles rule 4, 5, 6.
+        """
+        op_lower = op.lower()
+
+        if op_lower == "not":
+            return ~self._parse_expr(args, validation_prefix=1)
+
+        if op_lower == "and":
+            op_func = operator.and_
+        elif op_lower == "or":
+            op_func = operator.or_
+        else:
+            raise serializers.ValidationError(
+                {"0": [_("Invalid logical operator {op!r}").format(op=op)]},
+            )
+
+        qs = self._parse_expr_list(args, validation_prefix="1")
+        return functools.reduce(op_func, qs)
+
+    def _parse_atom(self, id_or_name, op, value) -> Q:
+        """
+        Handles rule 1, 2, 3.
+        """
+        # Guard against queries with too many conditions.
+        self._atom_count += 1
+        if self._atom_count > self._max_atom_count:
+            raise serializers.ValidationError(
+                [_("Maximum number of query conditions exceeded.")],
+            )
+
+        custom_field = self._get_custom_field(id_or_name, validation_prefix="0")
+        op = self._validate_atom_op(custom_field, op, validation_prefix="1")
+        value = self._validate_atom_value(
+            custom_field,
+            op,
+            value,
+            validation_prefix="2",
+        )
+
+        # Needed because not all DB backends support Array __contains
+        if (
+            custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK
+            and op == "contains"
+        ):
+            return self._parse_atom_doc_link_contains(custom_field, value)
+
+        value_field_name = CustomFieldInstance.get_value_field_name(
+            custom_field.data_type,
+        )
+        if (
+            custom_field.data_type == CustomField.FieldDataType.MONETARY
+            and op in self.EXPR_BY_CATEGORY["arithmetic"]
+        ):
+            value_field_name = "value_monetary_amount"
+        has_field = Q(custom_fields__field=custom_field)
+
+        # We need to use an annotation here because different atoms
+        # might be referring to different instances of custom fields.
+        annotation_name = f"_custom_field_filter_{len(self._annotations)}"
+
+        # Our special exists operator.
+        if op == "exists":
+            annotation = Count("custom_fields", filter=has_field)
+            # A Document should have > 0 match if it has this field, or 0 if doesn't.
+            query_op = "gt" if value else "exact"
+            query = Q(**{f"{annotation_name}__{query_op}": 0})
+        else:
+            # Check if 1) custom field name matches, and 2) value satisfies condition
+            field_filter = has_field & Q(
+                **{f"custom_fields__{value_field_name}__{op}": value},
+            )
+            # Annotate how many matching custom fields each document has
+            annotation = Count("custom_fields", filter=field_filter)
+            # Filter document by count
+            query = Q(**{f"{annotation_name}__gt": 0})
+
+        self._annotations[annotation_name] = annotation
+        return query
+
+    @handle_validation_prefix
+    def _get_custom_field(self, id_or_name):
+        """Get the CustomField instance by id or name."""
+        if id_or_name in self._custom_fields:
+            return self._custom_fields[id_or_name]
+
+        kwargs = (
+            {"id": id_or_name} if isinstance(id_or_name, int) else {"name": id_or_name}
+        )
+        try:
+            custom_field = CustomField.objects.get(**kwargs)
+        except CustomField.DoesNotExist:
+            raise serializers.ValidationError(
+                [_("{name!r} is not a valid custom field.").format(name=id_or_name)],
+            )
+        self._custom_fields[custom_field.id] = custom_field
+        self._custom_fields[custom_field.name] = custom_field
+        return custom_field
+
+    @staticmethod
+    def _split_op(full_op):
+        *prefix, op = str(full_op).rsplit("__", maxsplit=1)
+        prefix = prefix[0] if prefix else None
+        return prefix, op
+
+    @handle_validation_prefix
+    def _validate_atom_op(self, custom_field, raw_op):
+        """Check if the `op` is compatible with the type of the custom field."""
+        prefix, op = self._split_op(raw_op)
+
+        # Check if the operator is supported for the current data_type.
+        supported = False
+        for category in self.SUPPORTED_EXPR_CATEGORIES[custom_field.data_type]:
+            if op in self.EXPR_BY_CATEGORY[category]:
+                supported = True
+                break
+
+        # Check prefix
+        if prefix is not None:
+            if (
+                prefix in self.DATE_COMPONENTS
+                and custom_field.data_type == CustomField.FieldDataType.DATE
+            ):
+                pass  # ok - e.g., "year__exact" for date field
+            else:
+                supported = False  # anything else is invalid
+
+        if not supported:
+            raise serializers.ValidationError(
+                [
+                    _("{data_type} does not support query expr {expr!r}.").format(
+                        data_type=custom_field.data_type,
+                        expr=raw_op,
+                    ),
+                ],
+            )
+
+        return raw_op
+
+    def _get_serializer_field(self, custom_field, full_op):
+        """Return a serializers.Field for value validation."""
+        prefix, op = self._split_op(full_op)
+        field = None
+
+        if op in ("isnull", "exists"):
+            # `isnull` takes either True or False regardless of the data_type.
+            field = serializers.BooleanField()
+        elif (
+            custom_field.data_type == CustomField.FieldDataType.DATE
+            and prefix in self.DATE_COMPONENTS
+        ):
+            # DateField admits queries in the form of `year__exact`, etc. These take integers.
+            field = serializers.IntegerField()
+        elif custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
+            # We can be more specific here and make sure the value is a list.
+            field = serializers.ListField(child=serializers.IntegerField())
+        elif custom_field.data_type == CustomField.FieldDataType.SELECT:
+            # We use this custom field to permit SELECT option names.
+            field = SelectField(custom_field)
+        elif custom_field.data_type == CustomField.FieldDataType.URL:
+            # For URL fields we don't need to be strict about validation (e.g., for istartswith).
+            field = serializers.CharField()
+        else:
+            # The general case: inferred from the corresponding field in CustomFieldInstance.
+            value_field_name = CustomFieldInstance.get_value_field_name(
+                custom_field.data_type,
+            )
+            model_field = CustomFieldInstance._meta.get_field(value_field_name)
+            field_name = model_field.deconstruct()[0]
+            field_class, field_kwargs = self._model_serializer.build_standard_field(
+                field_name,
+                model_field,
+            )
+            field = field_class(**field_kwargs)
+            field.allow_null = False
+
+            # Need to set allow_blank manually because of the inconsistency in CustomFieldInstance validation.
+            # See https://github.com/paperless-ngx/paperless-ngx/issues/7361.
+            if isinstance(field, serializers.CharField):
+                field.allow_blank = True
+
+        if op == "in":
+            # `in` takes a list of values.
+            field = serializers.ListField(child=field, allow_empty=False)
+        elif op == "range":
+            # `range` takes a list of values, i.e., [start, end].
+            field = serializers.ListField(
+                child=field,
+                min_length=2,
+                max_length=2,
+            )
+
+        return field
+
+    @handle_validation_prefix
+    def _validate_atom_value(self, custom_field, op, value):
+        """Check if `value` is valid for the custom field and `op`. Returns the validated value."""
+        serializer_field = self._get_serializer_field(custom_field, op)
+        return serializer_field.run_validation(value)
+
+    def _parse_atom_doc_link_contains(self, custom_field, value) -> Q:
+        """
+        Handles document link `contains` in a way that is supported by all DB backends.
+        """
+
+        # If the value is an empty set,
+        # this is trivially true for any document with not null document links.
+        if not value:
+            return Q(
+                custom_fields__field=custom_field,
+                custom_fields__value_document_ids__isnull=False,
+            )
+
+        # First we look up reverse links from the requested documents.
+        links = CustomFieldInstance.objects.filter(
+            document_id__in=value,
+            field__data_type=CustomField.FieldDataType.DOCUMENTLINK,
+        )
+
+        # Check if any of the requested IDs are missing.
+        missing_ids = set(value) - set(link.document_id for link in links)
+        if missing_ids:
+            # The result should be an empty set in this case.
+            return Q(id__in=[])
+
+        # Take the intersection of the reverse links - this should be what we are looking for.
+        document_ids_we_want = functools.reduce(
+            operator.and_,
+            (set(link.value_document_ids) for link in links),
+        )
+
+        return Q(id__in=document_ids_we_want)
+
+    @contextmanager
+    def _track_query_depth(self):
+        # guard against queries that are too deeply nested
+        self._current_depth += 1
+        if self._current_depth > self._max_query_depth:
+            raise serializers.ValidationError([_("Maximum nesting depth exceeded.")])
+        try:
+            yield
+        finally:
+            self._current_depth -= 1
+
+
+@extend_schema_field(serializers.CharField)
+class CustomFieldQueryFilter(Filter):
+    def __init__(self, validation_prefix):
+        """
+        A filter that filters documents based on custom field name and value.
+
+        Args:
+            validation_prefix: Used to generate the ValidationError message.
+        """
+        super().__init__()
+        self._validation_prefix = validation_prefix
+
+    def filter(self, qs, value):
+        if not value:
+            return qs
+
+        parser = CustomFieldQueryParser(
+            self._validation_prefix,
+            max_query_depth=CUSTOM_FIELD_QUERY_MAX_DEPTH,
+            max_atom_count=CUSTOM_FIELD_QUERY_MAX_ATOMS,
+        )
+        q, annotations = parser.parse(value)
+
+        return qs.annotate(**annotations).filter(q)
+
+
+class DocumentFilterSet(FilterSet):
+    is_tagged = BooleanFilter(
+        label="Is tagged",
+        field_name="tags",
+        lookup_expr="isnull",
+        exclude=True,
+    )
+
+    tags__id__all = ObjectFilter(field_name="tags")
+
+    tags__id__none = ObjectFilter(field_name="tags", exclude=True)
+
+    tags__id__in = ObjectFilter(field_name="tags", in_list=True)
+
+    correspondent__id__none = ObjectFilter(field_name="correspondent", exclude=True)
+
+    document_type__id__none = ObjectFilter(field_name="document_type", exclude=True)
+
+    storage_path__id__none = ObjectFilter(field_name="storage_path", exclude=True)
+
+    is_in_inbox = InboxFilter()
+
+    title_content = TitleContentFilter()
+
+    owner__id__none = ObjectFilter(field_name="owner", exclude=True)
+
+    custom_fields__icontains = CustomFieldsFilter()
+
+    custom_fields__id__all = ObjectFilter(field_name="custom_fields__field")
+
+    custom_fields__id__none = ObjectFilter(
+        field_name="custom_fields__field",
+        exclude=True,
+    )
+
+    custom_fields__id__in = ObjectFilter(
+        field_name="custom_fields__field",
+        in_list=True,
+    )
+
+    has_custom_fields = BooleanFilter(
+        label="Has custom field",
+        field_name="custom_fields",
+        lookup_expr="isnull",
+        exclude=True,
+    )
+
+    custom_field_query = CustomFieldQueryFilter("custom_field_query")
+
+    shared_by__id = SharedByUser()
+
+    mime_type = MimeTypeFilter()
+
+    class Meta:
+        model = Document
+        fields = {
+            "id": ID_KWARGS,
+            "title": CHAR_KWARGS,
+            "content": CHAR_KWARGS,
+            "archive_serial_number": INT_KWARGS,
+            "created": DATE_KWARGS,
+            "added": DATE_KWARGS,
+            "modified": DATE_KWARGS,
+            "original_filename": CHAR_KWARGS,
+            "checksum": CHAR_KWARGS,
+            "correspondent": ["isnull"],
+            "correspondent__id": ID_KWARGS,
+            "correspondent__name": CHAR_KWARGS,
+            "tags__id": ID_KWARGS,
+            "tags__name": CHAR_KWARGS,
+            "document_type": ["isnull"],
+            "document_type__id": ID_KWARGS,
+            "document_type__name": CHAR_KWARGS,
+            "storage_path": ["isnull"],
+            "storage_path__id": ID_KWARGS,
+            "storage_path__name": CHAR_KWARGS,
+            "owner": ["isnull"],
+            "owner__id": ID_KWARGS,
+            "custom_fields": ["icontains"],
+        }
+
+
+class ShareLinkFilterSet(FilterSet):
+    class Meta:
+        model = ShareLink
+        fields = {
+            "created": DATE_KWARGS,
+            "expiration": DATE_KWARGS,
+        }
+
+
+class PaperlessTaskFilterSet(FilterSet):
+    acknowledged = BooleanFilter(
+        label="Acknowledged",
+        field_name="acknowledged",
+    )
+
+    class Meta:
+        model = PaperlessTask
+        fields = {
+            "type": ["exact"],
+            "task_name": ["exact"],
+            "status": ["exact"],
+        }
+
+
+class ObjectOwnedOrGrantedPermissionsFilter(ObjectPermissionsFilter):
+    """
+    A filter backend that limits results to those where the requesting user
+    has read object level permissions, owns the objects, or objects without
+    an owner (for backwards compat)
+    """
+
+    def filter_queryset(self, request, queryset, view):
+        objects_with_perms = super().filter_queryset(request, queryset, view)
+        objects_owned = queryset.filter(owner=request.user)
+        objects_unowned = queryset.filter(owner__isnull=True)
+        return objects_with_perms | objects_owned | objects_unowned
+
+
+class ObjectOwnedPermissionsFilter(ObjectPermissionsFilter):
+    """
+    A filter backend that limits results to those where the requesting user
+    owns the objects or objects without an owner (for backwards compat)
+    """
+
+    def filter_queryset(self, request, queryset, view):
+        if request.user.is_superuser:
+            return queryset
+        objects_owned = queryset.filter(owner=request.user)
+        objects_unowned = queryset.filter(owner__isnull=True)
+        return objects_owned | objects_unowned
+
+
+class DocumentsOrderingFilter(OrderingFilter):
+    field_name = "ordering"
+    prefix = "custom_field_"
+
+    def filter_queryset(self, request, queryset, view):
+        param = request.query_params.get("ordering")
+        if param and self.prefix in param:
+            custom_field_id = int(param.split(self.prefix)[1])
+            try:
+                field = CustomField.objects.get(pk=custom_field_id)
+            except CustomField.DoesNotExist:
+                raise serializers.ValidationError(
+                    {self.prefix + str(custom_field_id): [_("Custom field not found")]},
+                )
+
+            annotation = None
+            match field.data_type:
+                case CustomField.FieldDataType.STRING:
+                    annotation = Subquery(
+                        CustomFieldInstance.objects.filter(
+                            document_id=OuterRef("id"),
+                            field_id=custom_field_id,
+                        ).values("value_text")[:1],
+                    )
+                case CustomField.FieldDataType.INT:
+                    annotation = Subquery(
+                        CustomFieldInstance.objects.filter(
+                            document_id=OuterRef("id"),
+                            field_id=custom_field_id,
+                        ).values("value_int")[:1],
+                    )
+                case CustomField.FieldDataType.FLOAT:
+                    annotation = Subquery(
+                        CustomFieldInstance.objects.filter(
+                            document_id=OuterRef("id"),
+                            field_id=custom_field_id,
+                        ).values("value_float")[:1],
+                    )
+                case CustomField.FieldDataType.DATE:
+                    annotation = Subquery(
+                        CustomFieldInstance.objects.filter(
+                            document_id=OuterRef("id"),
+                            field_id=custom_field_id,
+                        ).values("value_date")[:1],
+                    )
+                case CustomField.FieldDataType.MONETARY:
+                    annotation = Subquery(
+                        CustomFieldInstance.objects.filter(
+                            document_id=OuterRef("id"),
+                            field_id=custom_field_id,
+                        ).values("value_monetary_amount")[:1],
+                    )
+                case CustomField.FieldDataType.SELECT:
+                    # Select options are a little more complicated since the value is the id of the option, not
+                    # the label. Additionally, to support sqlite we can't use StringAgg, so we need to create a
+                    # case statement for each option, setting the value to the index of the option in a list
+                    # sorted by label, and then summing the results to give a single value for the annotation
+
+                    select_options = sorted(
+                        field.extra_data.get("select_options", []),
+                        key=lambda x: x.get("label"),
+                    )
+                    whens = [
+                        When(
+                            custom_fields__field_id=custom_field_id,
+                            custom_fields__value_select=option.get("id"),
+                            then=Value(idx, output_field=IntegerField()),
+                        )
+                        for idx, option in enumerate(select_options)
+                    ]
+                    whens.append(
+                        When(
+                            custom_fields__field_id=custom_field_id,
+                            custom_fields__value_select__isnull=True,
+                            then=Value(
+                                len(select_options),
+                                output_field=IntegerField(),
+                            ),
+                        ),
+                    )
+                    annotation = Sum(
+                        Case(
+                            *whens,
+                            default=Value(0),
+                            output_field=IntegerField(),
+                        ),
+                    )
+                case CustomField.FieldDataType.DOCUMENTLINK:
+                    annotation = Subquery(
+                        CustomFieldInstance.objects.filter(
+                            document_id=OuterRef("id"),
+                            field_id=custom_field_id,
+                        ).values("value_document_ids")[:1],
+                    )
+                case CustomField.FieldDataType.URL:
+                    annotation = Subquery(
+                        CustomFieldInstance.objects.filter(
+                            document_id=OuterRef("id"),
+                            field_id=custom_field_id,
+                        ).values("value_url")[:1],
+                    )
+                case CustomField.FieldDataType.BOOL:
+                    annotation = Subquery(
+                        CustomFieldInstance.objects.filter(
+                            document_id=OuterRef("id"),
+                            field_id=custom_field_id,
+                        ).values("value_bool")[:1],
+                    )
+
+            if not annotation:
+                # Only happens if a new data type is added and not handled here
+                raise ValueError("Invalid custom field data type")
+
+            queryset = (
+                queryset.annotate(
+                    # We need to annotate the queryset with the custom field value
+                    custom_field_value=annotation,
+                    # We also need to annotate the queryset with a boolean for sorting whether the field exists
+                    has_field=Exists(
+                        CustomFieldInstance.objects.filter(
+                            document_id=OuterRef("id"),
+                            field_id=custom_field_id,
+                        ),
+                    ),
+                )
+                .order_by(
+                    "-has_field",
+                    param.replace(
+                        self.prefix + str(custom_field_id),
+                        "custom_field_value",
+                    ),
+                )
+                .distinct()
+            )
+
+        return super().filter_queryset(request, queryset, view)
index 93c26aceda7d4b9c7ed1629dfda55ccee1388c1a..d960225f20a84495f4d65082996c697beed7ee01 100644 (file)
@@ -89,17 +89,6 @@ from rest_framework.viewsets import ModelViewSet
 from rest_framework.viewsets import ReadOnlyModelViewSet
 from rest_framework.viewsets import ViewSet
 
-from documents.filters import CorrespondentFilterSet
-from documents.filters import CustomFieldFilterSet
-from documents.filters import DocumentFilterSet
-from documents.filters import DocumentsOrderingFilter
-from documents.filters import DocumentTypeFilterSet
-from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
-from documents.filters import ObjectOwnedPermissionsFilter
-from documents.filters import PaperlessTaskFilterSet
-from documents.filters import ShareLinkFilterSet
-from documents.filters import StoragePathFilterSet
-from documents.filters import TagFilterSet
 from documents.schema import generate_object_with_permissions_schema
 from documents.signals import document_updated
 from documents.templating.filepath import validate_filepath_template_and_render
@@ -129,7 +118,18 @@ from paperless.data_models import ConsumableDocument
 from paperless.data_models import DocumentMetadataOverrides
 from paperless.data_models import DocumentSource
 from paperless.db import GnuPG
+from paperless.filters import CorrespondentFilterSet
+from paperless.filters import CustomFieldFilterSet
+from paperless.filters import DocumentFilterSet
+from paperless.filters import DocumentsOrderingFilter
+from paperless.filters import DocumentTypeFilterSet
 from paperless.filters import GroupFilterSet
+from paperless.filters import ObjectOwnedOrGrantedPermissionsFilter
+from paperless.filters import ObjectOwnedPermissionsFilter
+from paperless.filters import PaperlessTaskFilterSet
+from paperless.filters import ShareLinkFilterSet
+from paperless.filters import StoragePathFilterSet
+from paperless.filters import TagFilterSet
 from paperless.filters import UserFilterSet
 from paperless.index import DelayedQuery
 from paperless.mail import send_email
index 62a25c60c87fd3ae1b1868e57eff91bb2adbf9a7..202e3b34714d78117c45b98e8e2b5b78617369f1 100644 (file)
@@ -17,7 +17,7 @@ from rest_framework.permissions import IsAuthenticated
 from rest_framework.response import Response
 from rest_framework.viewsets import ModelViewSet
 
-from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
+from paperless.filters import ObjectOwnedOrGrantedPermissionsFilter
 from paperless.permissions import PaperlessObjectPermissions
 from paperless.views import PassUserMixin
 from paperless.views import StandardPagination