+import functools
+import inspect
+import json
+import operator
+from contextlib import contextmanager
+from typing import Callable
+from typing import Union
+
from django.contrib.contenttypes.models import ContentType
from django.db.models import CharField
from django.db.models import Count
from django.db.models import OuterRef
from django.db.models import Q
from django.db.models.functions import Cast
+from django.utils.translation import gettext_lazy as _
from django_filters.rest_framework import BooleanFilter
from django_filters.rest_framework import Filter
from django_filters.rest_framework import FilterSet
from guardian.utils import get_group_obj_perms_model
from guardian.utils import get_user_obj_perms_model
+from rest_framework import serializers
from rest_framework_guardian.filters import ObjectPermissionsFilter
from documents.models import Correspondent
from documents.models import CustomField
+from documents.models import CustomFieldInstance
from documents.models import Document
from documents.models import DocumentType
from documents.models import Log
from documents.models import ShareLink
from documents.models import StoragePath
from documents.models import Tag
+from paperless import settings
CHAR_KWARGS = ["istartswith", "iendswith", "icontains", "iexact"]
ID_KWARGS = ["in", "exact"]
return qs
+class SelectField(serializers.IntegerField):
+ def __init__(self, custom_field: CustomField):
+ self._options = custom_field.extra_data["select_options"]
+        # Valid values are option indices, i.e., 0 .. len(options) - 1.
+        super().__init__(min_value=0, max_value=len(self._options) - 1)
+
+ def to_internal_value(self, data):
+ if not isinstance(data, int):
+ # If the supplied value is not an integer,
+ # we will try to map it to an option index.
+ try:
+ data = self._options.index(data)
+ except ValueError:
+ pass
+ return super().to_internal_value(data)
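+
+    # Illustrative usage, assuming a SELECT field whose options are ["A", "B", "C"]:
+    #     SelectField(custom_field).run_validation("B")  # -> 1 (name mapped to index)
+    #     SelectField(custom_field).run_validation(1)    # -> 1 (index accepted as-is)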
+
+
+def handle_validation_prefix(func: Callable):
+ """
+ Catch ValidationErrors raised by the wrapped function
+    and add a prefix to the exception detail to track where the exception originated,
+    similar to how nested serializers report errors.
+ """
+
+ def wrapper(*args, validation_prefix=None, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except serializers.ValidationError as e:
+ raise serializers.ValidationError({validation_prefix: e.detail})
+
+ # Update the signature to include the validation_prefix argument
+ old_sig = inspect.signature(func)
+ new_param = inspect.Parameter("validation_prefix", inspect.Parameter.KEYWORD_ONLY)
+ new_sig = old_sig.replace(parameters=[*old_sig.parameters.values(), new_param])
+
+ # Apply functools.wraps and manually set the new signature
+ functools.update_wrapper(wrapper, func)
+ wrapper.__signature__ = new_sig
+
+ return wrapper
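+
+
+# Illustrative example: a function wrapped with handle_validation_prefix gains a
+# keyword-only `validation_prefix` argument, e.g.
+#     self._parse_expr(expr, validation_prefix="1")
+# A ValidationError raised inside then surfaces as {"1": <original detail>},
+# nesting one level per decorated call up the stack.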
+
+
+class CustomFieldLookupParser:
+ EXPR_BY_CATEGORY = {
+ "basic": ["exact", "in", "isnull", "exists"],
+ "string": [
+ "iexact",
+ "contains",
+ "icontains",
+ "startswith",
+ "istartswith",
+ "endswith",
+ "iendswith",
+ "regex",
+ "iregex",
+ ],
+ "arithmetic": [
+ "gt",
+ "gte",
+ "lt",
+ "lte",
+ "range",
+ ],
+ "containment": ["contains"],
+ }
+
+    # These string lookup expressions are problematic, so they are disabled
+    # by default unless the user explicitly opts in.
+ STR_EXPR_DISABLED_BY_DEFAULT = [
+ # SQLite: is case-sensitive outside the ASCII range
+ "iexact",
+ # SQLite: behaves the same as icontains
+ "contains",
+ # SQLite: behaves the same as istartswith
+ "startswith",
+ # SQLite: behaves the same as iendswith
+ "endswith",
+ # Syntax depends on database backends, can be exploited for ReDoS
+ "regex",
+ # Syntax depends on database backends, can be exploited for ReDoS
+ "iregex",
+ ]
+
+ SUPPORTED_EXPR_CATEGORIES = {
+ CustomField.FieldDataType.STRING: ("basic", "string"),
+ CustomField.FieldDataType.URL: ("basic", "string"),
+ CustomField.FieldDataType.DATE: ("basic", "arithmetic"),
+ CustomField.FieldDataType.BOOL: ("basic",),
+ CustomField.FieldDataType.INT: ("basic", "arithmetic"),
+ CustomField.FieldDataType.FLOAT: ("basic", "arithmetic"),
+ CustomField.FieldDataType.MONETARY: ("basic", "string"),
+ CustomField.FieldDataType.DOCUMENTLINK: ("basic", "containment"),
+ CustomField.FieldDataType.SELECT: ("basic",),
+ }
+
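+    # Valid operator prefixes for DATE fields: e.g., "year__gte" compares only
+    # the year component of the date (see _validate_atom_op).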
+ DATE_COMPONENTS = [
+ "year",
+ "iso_year",
+ "month",
+ "day",
+ "week",
+ "week_day",
+ "iso_week_day",
+ "quarter",
+ ]
+
+ def __init__(
+ self,
+ validation_prefix,
+ max_query_depth=10,
+ max_atom_count=20,
+ ) -> None:
+ """
+ A helper class that parses the query string into a `django.db.models.Q` for filtering
+ documents based on custom field values.
+
+        The syntax of the query expression is illustrated by the pseudocode rules below:
+ 1. parse([`custom_field`, "exists", true]):
+ matches documents with Q(custom_fields__field=`custom_field`)
+ 2. parse([`custom_field`, "exists", false]):
+ matches documents with ~Q(custom_fields__field=`custom_field`)
+ 3. parse([`custom_field`, `op`, `value`]):
+ matches documents with
+            Q(custom_fields__field=`custom_field`, custom_fields__value_`type`__`op`=`value`)
+ 4. parse(["AND", [`q0`, `q1`, ..., `qn`]])
+ -> parse(`q0`) & parse(`q1`) & ... & parse(`qn`)
+ 5. parse(["OR", [`q0`, `q1`, ..., `qn`]])
+ -> parse(`q0`) | parse(`q1`) | ... | parse(`qn`)
+ 6. parse(["NOT", `q`])
+ -> ~parse(`q`)
+
+ Args:
+ validation_prefix: Used to generate the ValidationError message.
+ max_query_depth: Limits the maximum nesting depth of queries.
+ max_atom_count: Limits the maximum number of atoms (i.e., rule 1, 2, 3) in the query.
+
+ `max_query_depth` and `max_atom_count` can be set to guard against generating arbitrarily
+ complex SQL queries.
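+
+        Example (illustrative, assuming custom fields named "due" and "status" exist):
+            ["AND", [["due", "year__exact", 2024], ["status", "exact", "open"]]]
+            matches documents whose "due" date falls in 2024 and whose "status"
+            value is exactly "open".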
+ """
+ self._custom_fields: dict[Union[int, str], CustomField] = {}
+ self._validation_prefix = validation_prefix
+ # Dummy ModelSerializer used to convert a Django models.Field to serializers.Field.
+ self._model_serializer = serializers.ModelSerializer()
+        # Used for sanity checks
+ self._max_query_depth = max_query_depth
+ self._max_atom_count = max_atom_count
+ self._current_depth = 0
+ self._atom_count = 0
+ # The set of annotations that we need to apply to the queryset
+ self._annotations = {}
+
+ def parse(self, query: str) -> tuple[Q, dict[str, Count]]:
+ """
+ Parses the query string into a `django.db.models.Q`
+ and a set of annotations to be applied to the queryset.
+ """
+ try:
+ expr = json.loads(query)
+ except json.JSONDecodeError:
+ raise serializers.ValidationError(
+ {self._validation_prefix: [_("Value must be valid JSON.")]},
+ )
+ return (
+ self._parse_expr(expr, validation_prefix=self._validation_prefix),
+ self._annotations,
+ )
+
+ @handle_validation_prefix
+ def _parse_expr(self, expr) -> Q:
+ """
+        Applies rules 1-3 or 4-6 based on the length of `expr`.
+ """
+ with self._track_query_depth():
+ if isinstance(expr, (list, tuple)):
+ if len(expr) == 2:
+ return self._parse_logical_expr(*expr)
+ elif len(expr) == 3:
+ return self._parse_atom(*expr)
+ raise serializers.ValidationError(
+ [_("Invalid custom field lookup expression")],
+ )
+
+ @handle_validation_prefix
+ def _parse_expr_list(self, exprs) -> list[Q]:
+ """
+        Handles [`q0`, `q1`, ..., `qn`] in rules 4 & 5.
+ """
+ if not isinstance(exprs, (list, tuple)) or not exprs:
+ raise serializers.ValidationError(
+ [_("Invalid expression list. Must be nonempty.")],
+ )
+ return [
+ self._parse_expr(expr, validation_prefix=i) for i, expr in enumerate(exprs)
+ ]
+
+ def _parse_logical_expr(self, op, args) -> Q:
+ """
+        Handles rules 4, 5, and 6.
+ """
+        op_lower = str(op).lower()
+
+ if op_lower == "not":
+ return ~self._parse_expr(args, validation_prefix=1)
+
+ if op_lower == "and":
+ op_func = operator.and_
+ elif op_lower == "or":
+ op_func = operator.or_
+ else:
+ raise serializers.ValidationError(
+ {"0": [_("Invalid logical operator {op!r}").format(op=op)]},
+ )
+
+ qs = self._parse_expr_list(args, validation_prefix="1")
+ return functools.reduce(op_func, qs)
+
+ def _parse_atom(self, id_or_name, op, value) -> Q:
+ """
+        Handles rules 1, 2, and 3.
+ """
+ # Guard against queries with too many conditions.
+ self._atom_count += 1
+ if self._atom_count > self._max_atom_count:
+ raise serializers.ValidationError(
+ [
+ _(
+ "Maximum number of query conditions exceeded. You can raise "
+ "the limit by setting PAPERLESS_CUSTOM_FIELD_LOOKUP_MAX_ATOMS "
+ "in your configuration file.",
+ ),
+ ],
+ )
+
+ custom_field = self._get_custom_field(id_or_name, validation_prefix="0")
+ op = self._validate_atom_op(custom_field, op, validation_prefix="1")
+ value = self._validate_atom_value(
+ custom_field,
+ op,
+ value,
+ validation_prefix="2",
+ )
+
+ # Needed because not all DB backends support Array __contains
+ if (
+ custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK
+ and op == "contains"
+ ):
+ return self._parse_atom_doc_link_contains(custom_field, value)
+
+ value_field_name = CustomFieldInstance.get_value_field_name(
+ custom_field.data_type,
+ )
+ has_field = Q(custom_fields__field=custom_field)
+
+ # Our special exists operator.
+ if op == "exists":
+ field_filter = has_field if value else ~has_field
+ else:
+ field_filter = has_field & Q(
+ **{f"custom_fields__{value_field_name}__{op}": value},
+ )
+
+        # We need to use an annotation here because different atoms
+        # may refer to different custom field instances of the same document.
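+        # For example, ["AND", [["a", "exists", true], ["b", "exists", true]]]
+        # (hypothetical fields "a" and "b") needs two independent filtered Counts;
+        # putting both conditions into one .filter() would require a single
+        # CustomFieldInstance row to match both fields at once.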
+ annotation_name = f"_custom_field_filter_{len(self._annotations)}"
+ self._annotations[annotation_name] = Count("custom_fields", filter=field_filter)
+
+ return Q(**{f"{annotation_name}__gt": 0})
+
+ @handle_validation_prefix
+ def _get_custom_field(self, id_or_name):
+ """Get the CustomField instance by id or name."""
+ if id_or_name in self._custom_fields:
+ return self._custom_fields[id_or_name]
+
+ kwargs = (
+ {"id": id_or_name} if isinstance(id_or_name, int) else {"name": id_or_name}
+ )
+ try:
+ custom_field = CustomField.objects.get(**kwargs)
+ except CustomField.DoesNotExist:
+ raise serializers.ValidationError(
+ [_("{name!r} is not a valid custom field.").format(name=id_or_name)],
+ )
+ self._custom_fields[custom_field.id] = custom_field
+ self._custom_fields[custom_field.name] = custom_field
+ return custom_field
+
+ @staticmethod
+ def _split_op(full_op):
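+        # E.g., "year__exact" -> ("year", "exact"); "exact" -> (None, "exact").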
+ *prefix, op = str(full_op).rsplit("__", maxsplit=1)
+ prefix = prefix[0] if prefix else None
+ return prefix, op
+
+ @handle_validation_prefix
+ def _validate_atom_op(self, custom_field, raw_op):
+ """Check if the `op` is compatible with the type of the custom field."""
+ prefix, op = self._split_op(raw_op)
+
+ # Check if the operator is supported for the current data_type.
+ supported = False
+ for category in self.SUPPORTED_EXPR_CATEGORIES[custom_field.data_type]:
+ if (
+ category == "string"
+ and op in self.STR_EXPR_DISABLED_BY_DEFAULT
+ and op not in settings.CUSTOM_FIELD_LOOKUP_OPT_IN
+ ):
+ raise serializers.ValidationError(
+ [
+ _(
+ "{expr!r} is disabled by default because it does not "
+ "behave consistently across database backends, or can "
+ "cause security risks. If you understand the implications "
+ "you may enabled it by adding it to "
+ "`PAPERLESS_CUSTOM_FIELD_LOOKUP_OPT_IN`.",
+ ).format(expr=op),
+ ],
+ )
+ if op in self.EXPR_BY_CATEGORY[category]:
+ supported = True
+ break
+
+ # Check prefix
+ if prefix is not None:
+ if (
+ prefix in self.DATE_COMPONENTS
+ and custom_field.data_type == CustomField.FieldDataType.DATE
+ ):
+ pass # ok - e.g., "year__exact" for date field
+ else:
+ supported = False # anything else is invalid
+
+ if not supported:
+ raise serializers.ValidationError(
+ [
+ _("{data_type} does not support lookup expr {expr!r}.").format(
+ data_type=custom_field.data_type,
+ expr=raw_op,
+ ),
+ ],
+ )
+
+ return raw_op
+
+ def _get_serializer_field(self, custom_field, full_op):
+ """Return a serializers.Field for value validation."""
+ prefix, op = self._split_op(full_op)
+ field = None
+
+ if op in ("isnull", "exists"):
+            # `isnull` and `exists` take either True or False regardless of the data_type.
+ field = serializers.BooleanField()
+ elif (
+ custom_field.data_type == CustomField.FieldDataType.DATE
+ and prefix in self.DATE_COMPONENTS
+ ):
+ # DateField admits lookups in the form of `year__exact`, etc. These take integers.
+ field = serializers.IntegerField()
+ elif custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
+ # We can be more specific here and make sure the value is a list.
+ field = serializers.ListField(child=serializers.IntegerField())
+ elif custom_field.data_type == CustomField.FieldDataType.SELECT:
+            # Our custom SelectField also accepts option names in place of indices.
+ field = SelectField(custom_field)
+ elif custom_field.data_type == CustomField.FieldDataType.URL:
+ # For URL fields we don't need to be strict about validation (e.g., for istartswith).
+ field = serializers.CharField()
+ else:
+ # The general case: inferred from the corresponding field in CustomFieldInstance.
+ value_field_name = CustomFieldInstance.get_value_field_name(
+ custom_field.data_type,
+ )
+ model_field = CustomFieldInstance._meta.get_field(value_field_name)
+ field_name = model_field.deconstruct()[0]
+ field_class, field_kwargs = self._model_serializer.build_standard_field(
+ field_name,
+ model_field,
+ )
+ field = field_class(**field_kwargs)
+ field.allow_null = False
+
+ # Need to set allow_blank manually because of the inconsistency in CustomFieldInstance validation.
+ # See https://github.com/paperless-ngx/paperless-ngx/issues/7361.
+ if isinstance(field, serializers.CharField):
+ field.allow_blank = True
+
+ if op == "in":
+ # `in` takes a list of values.
+ field = serializers.ListField(child=field, allow_empty=False)
+ elif op == "range":
+ # `range` takes a list of values, i.e., [start, end].
+ field = serializers.ListField(
+ child=field,
+ min_length=2,
+ max_length=2,
+ )
+
+ return field
+
+ @handle_validation_prefix
+ def _validate_atom_value(self, custom_field, op, value):
+ """Check if `value` is valid for the custom field and `op`. Returns the validated value."""
+ serializer_field = self._get_serializer_field(custom_field, op)
+ return serializer_field.run_validation(value)
+
+ def _parse_atom_doc_link_contains(self, custom_field, value) -> Q:
+ """
+ Handles document link `contains` in a way that is supported by all DB backends.
+ """
+
+        # An empty set is a subset of any set, so this trivially matches any
+        # document whose document link field is not null.
+ if not value:
+ return Q(
+ custom_fields__field=custom_field,
+ custom_fields__value_document_ids__isnull=False,
+ )
+
+        # First we look up reverse links from the requested documents.
+ links = CustomFieldInstance.objects.filter(
+ document_id__in=value,
+ field__data_type=CustomField.FieldDataType.DOCUMENTLINK,
+ )
+
+ # Check if any of the requested IDs are missing.
+        missing_ids = set(value) - {link.document_id for link in links}
+ if missing_ids:
+ # The result should be an empty set in this case.
+ return Q(id__in=[])
+
+        # The intersection of the reverse links gives exactly the documents
+        # whose field value contains every requested ID.
+ document_ids_we_want = functools.reduce(
+ operator.and_,
+ (set(link.value_document_ids) for link in links),
+ )
+
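+        # Illustrative trace with hypothetical IDs: for value=[1, 2], if document 1
+        # is reverse-linked by documents {10, 11} and document 2 by {10}, only
+        # document 10 links to both, so the result is Q(id__in={10}).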
+ return Q(id__in=document_ids_we_want)
+
+ @contextmanager
+ def _track_query_depth(self):
+        # Guard against queries that are too deeply nested.
+ self._current_depth += 1
+ if self._current_depth > self._max_query_depth:
+ raise serializers.ValidationError(
+ [
+ _(
+ "Maximum nesting depth exceeded. You can raise the limit "
+ "by setting PAPERLESS_CUSTOM_FIELD_LOOKUP_MAX_DEPTH in "
+ "your configuration file.",
+ ),
+ ],
+ )
+ try:
+ yield
+ finally:
+ self._current_depth -= 1
+
+
+class CustomFieldLookupFilter(Filter):
+ def __init__(self, validation_prefix):
+ """
+ A filter that filters documents based on custom field name and value.
+
+ Args:
+ validation_prefix: Used to generate the ValidationError message.
+ """
+ super().__init__()
+ self._validation_prefix = validation_prefix
+
+ def filter(self, qs, value):
+ if not value:
+ return qs
+
+ parser = CustomFieldLookupParser(
+ self._validation_prefix,
+ max_query_depth=settings.CUSTOM_FIELD_LOOKUP_MAX_DEPTH,
+ max_atom_count=settings.CUSTOM_FIELD_LOOKUP_MAX_ATOMS,
+ )
+ q, annotations = parser.parse(value)
+
+ return qs.annotate(**annotations).filter(q)
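+
+
+# Illustrative request, assuming a custom field named "due" exists (the query
+# string would be URL-encoded in practice):
+#     GET /api/documents/?custom_field_lookup=["due", "year__exact", 2024]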
+
+
class DocumentFilterSet(FilterSet):
is_tagged = BooleanFilter(
label="Is tagged",
exclude=True,
)
+ custom_field_lookup = CustomFieldLookupFilter("custom_field_lookup")
+
shared_by__id = SharedByUser()
class Meta:
--- /dev/null
+import json
+import re
+from datetime import date
+from typing import Callable
+from unittest.mock import Mock
+from urllib.parse import quote
+
+import pytest
+from django.contrib.auth.models import User
+from rest_framework.test import APITestCase
+
+from documents.models import CustomField
+from documents.models import Document
+from documents.serialisers import DocumentSerializer
+from documents.tests.utils import DirectoriesMixin
+from paperless import settings
+
+
+class DocumentWrapper:
+ """
+ Allows Pythonic access to the custom fields associated with the wrapped document.
+ """
+
+ def __init__(self, document: Document) -> None:
+ self._document = document
+
+ def __contains__(self, custom_field: str) -> bool:
+ return self._document.custom_fields.filter(field__name=custom_field).exists()
+
+ def __getitem__(self, custom_field: str):
+ return self._document.custom_fields.get(field__name=custom_field).value
+
+
+def string_expr_opted_in(op):
+ return op in settings.CUSTOM_FIELD_LOOKUP_OPT_IN
+
+
+class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
+ def setUp(self):
+ super().setUp()
+
+ self.user = User.objects.create_superuser(username="temp_admin")
+ self.client.force_authenticate(user=self.user)
+
+ # Create one custom field per type. The fields are called f"{type}_field".
+ self.custom_fields = {}
+ for data_type in CustomField.FieldDataType.values:
+ name = data_type + "_field"
+ self.custom_fields[name] = CustomField.objects.create(
+ name=name,
+ data_type=data_type,
+ )
+
+ # Add some options to the select_field
+ select = self.custom_fields["select_field"]
+ select.extra_data = {"select_options": ["A", "B", "C"]}
+ select.save()
+
+ # Now we will create some test documents
+ self.documents = []
+
+ # CustomField.FieldDataType.STRING
+ self._create_document(string_field=None)
+ self._create_document(string_field="")
+ self._create_document(string_field="paperless")
+ self._create_document(string_field="Paperless")
+ self._create_document(string_field="PAPERLESS")
+ self._create_document(string_field="pointless")
+ self._create_document(string_field="pointy")
+
+ # CustomField.FieldDataType.URL
+ self._create_document(url_field=None)
+ self._create_document(url_field="")
+ self._create_document(url_field="https://docs.paperless-ngx.com/")
+ self._create_document(url_field="https://www.django-rest-framework.org/")
+ self._create_document(url_field="http://example.com/")
+
+ # A document to check if the filter correctly associates field names with values.
+ # E.g., ["url_field", "exact", "https://docs.paperless-ngx.com/"] should not
+ # yield this document.
+ self._create_document(
+ string_field="https://docs.paperless-ngx.com/",
+ url_field="http://example.com/",
+ )
+
+ # CustomField.FieldDataType.DATE
+ self._create_document(date_field=None)
+ self._create_document(date_field=date(2023, 8, 22))
+ self._create_document(date_field=date(2024, 8, 22))
+ self._create_document(date_field=date(2024, 11, 15))
+
+ # CustomField.FieldDataType.BOOL
+ self._create_document(boolean_field=None)
+ self._create_document(boolean_field=True)
+ self._create_document(boolean_field=False)
+
+ # CustomField.FieldDataType.INT
+ self._create_document(integer_field=None)
+ self._create_document(integer_field=-1)
+ self._create_document(integer_field=0)
+ self._create_document(integer_field=1)
+
+ # CustomField.FieldDataType.FLOAT
+ self._create_document(float_field=None)
+ self._create_document(float_field=-1e9)
+ self._create_document(float_field=0.05)
+ self._create_document(float_field=270.0)
+
+ # CustomField.FieldDataType.MONETARY
+ self._create_document(monetary_field=None)
+ self._create_document(monetary_field="USD100.00")
+ self._create_document(monetary_field="USD1.00")
+ self._create_document(monetary_field="EUR50.00")
+
+ # CustomField.FieldDataType.DOCUMENTLINK
+ self._create_document(documentlink_field=None)
+ self._create_document(documentlink_field=[])
+ self._create_document(
+ documentlink_field=[
+ self.documents[0].id,
+ self.documents[1].id,
+ self.documents[2].id,
+ ],
+ )
+ self._create_document(
+ documentlink_field=[self.documents[4].id, self.documents[5].id],
+ )
+
+ # CustomField.FieldDataType.SELECT
+ self._create_document(select_field=None)
+ self._create_document(select_field=0)
+ self._create_document(select_field=1)
+ self._create_document(select_field=2)
+
+ def _create_document(self, **kwargs):
+ title = str(kwargs)
+ document = Document.objects.create(
+ title=title,
+ checksum=title,
+ archive_serial_number=len(self.documents) + 1,
+ )
+ data = {
+ "custom_fields": [
+ {"field": self.custom_fields[name].id, "value": value}
+ for name, value in kwargs.items()
+ ],
+ }
+ serializer = DocumentSerializer(
+ document,
+ data=data,
+ partial=True,
+ context={"request": Mock()},
+ )
+ serializer.is_valid(raise_exception=True)
+ serializer.save()
+ self.documents.append(document)
+ return document
+
+ def _assert_query_match_predicate(
+ self,
+ query: list,
+ reference_predicate: Callable[[DocumentWrapper], bool],
+ match_nothing_ok=False,
+ ):
+ """
+ Checks the results of the query against a callable reference predicate.
+ """
+ reference_document_ids = [
+ document.id
+ for document in self.documents
+ if reference_predicate(DocumentWrapper(document))
+ ]
+        # First, sanity-check our test cases.
+ if not match_nothing_ok:
+ self.assertTrue(
+ reference_document_ids,
+ msg="Bad test case - should match at least one document.",
+ )
+ self.assertNotEqual(
+ len(reference_document_ids),
+ len(self.documents),
+ msg="Bad test case - should not match all documents.",
+ )
+
+ # Now make the API call.
+ query_string = quote(json.dumps(query), safe="")
+ response = self.client.get(
+ "/api/documents/?"
+ + "&".join(
+ (
+ f"custom_field_lookup={query_string}",
+ "ordering=archive_serial_number",
+ "page=1",
+ f"page_size={len(self.documents)}",
+ "truncate_content=true",
+ ),
+ ),
+ )
+ self.assertEqual(response.status_code, 200, msg=str(response.json()))
+ response_document_ids = [
+ document["id"] for document in response.json()["results"]
+ ]
+ self.assertEqual(reference_document_ids, response_document_ids)
+
+ def _assert_validation_error(self, query: str, path: list, keyword: str):
+ """
+ Asserts that the query raises a validation error.
+ Checks the message to make sure it points to the right place.
+ """
+ query_string = quote(query, safe="")
+ response = self.client.get(
+ "/api/documents/?"
+ + "&".join(
+ (
+ f"custom_field_lookup={query_string}",
+ "ordering=archive_serial_number",
+ "page=1",
+ f"page_size={len(self.documents)}",
+ "truncate_content=true",
+ ),
+ ),
+ )
+ self.assertEqual(response.status_code, 400)
+
+ exception_path = []
+ detail = response.json()
+ while not isinstance(detail, list):
+ path_item, detail = next(iter(detail.items()))
+ exception_path.append(path_item)
+
+ self.assertEqual(path, exception_path)
+ self.assertIn(keyword, " ".join(detail))
+
+ # ==========================================================#
+ # Sanity checks #
+ # ==========================================================#
+ def test_name_value_association(self):
+ """
+ GIVEN:
+ - A document with `{"string_field": "https://docs.paperless-ngx.com/",
+ "url_field": "http://example.com/"}`
+ WHEN:
+ - Filtering by `["url_field", "exact", "https://docs.paperless-ngx.com/"]`
+ THEN:
+ - That document should not get matched.
+ """
+ self._assert_query_match_predicate(
+ ["url_field", "exact", "https://docs.paperless-ngx.com/"],
+ lambda document: "url_field" in document
+ and document["url_field"] == "https://docs.paperless-ngx.com/",
+ )
+
+ def test_filter_by_multiple_fields(self):
+ """
+ GIVEN:
+ - A document with `{"string_field": "https://docs.paperless-ngx.com/",
+ "url_field": "http://example.com/"}`
+ WHEN:
+ - Filtering by `['AND', [["string_field", "exists", True], ["url_field", "exists", True]]]`
+ THEN:
+ - That document should get matched.
+ """
+ self._assert_query_match_predicate(
+ ["AND", [["string_field", "exists", True], ["url_field", "exists", True]]],
+ lambda document: "url_field" in document and "string_field" in document,
+ )
+
+ # ==========================================================#
+ # Basic expressions supported by all custom field types #
+ # ==========================================================#
+ def test_exact(self):
+ self._assert_query_match_predicate(
+ ["string_field", "exact", "paperless"],
+ lambda document: "string_field" in document
+ and document["string_field"] == "paperless",
+ )
+
+ def test_in(self):
+ self._assert_query_match_predicate(
+ ["string_field", "in", ["paperless", "Paperless"]],
+ lambda document: "string_field" in document
+ and document["string_field"] in ("paperless", "Paperless"),
+ )
+
+ def test_isnull(self):
+ self._assert_query_match_predicate(
+ ["string_field", "isnull", True],
+ lambda document: "string_field" in document
+ and document["string_field"] is None,
+ )
+
+ def test_exists(self):
+ self._assert_query_match_predicate(
+ ["string_field", "exists", True],
+ lambda document: "string_field" in document,
+ )
+
+ def test_select(self):
+        # For select fields, you can specify either the index or the name
+        # of the option; they behave identically.
+ self._assert_query_match_predicate(
+ ["select_field", "exact", 1],
+ lambda document: "select_field" in document
+ and document["select_field"] == 1,
+ )
+ # This is the same as:
+ self._assert_query_match_predicate(
+ ["select_field", "exact", "B"],
+ lambda document: "select_field" in document
+ and document["select_field"] == 1,
+ )
+
+ # ==========================================================#
+ # Expressions for string, URL, and monetary fields #
+ # ==========================================================#
+ @pytest.mark.skipif(
+ not string_expr_opted_in("iexact"),
+ reason="iexact expr is disabled.",
+ )
+ def test_iexact(self):
+ self._assert_query_match_predicate(
+ ["string_field", "iexact", "paperless"],
+ lambda document: "string_field" in document
+ and document["string_field"] is not None
+ and document["string_field"].lower() == "paperless",
+ )
+
+ @pytest.mark.skipif(
+ not string_expr_opted_in("contains"),
+ reason="contains expr is disabled.",
+ )
+ def test_contains(self):
+ # WARNING: SQLite treats "contains" as "icontains"!
+ # You should avoid "contains" unless you know what you are doing!
+ self._assert_query_match_predicate(
+ ["string_field", "contains", "aper"],
+ lambda document: "string_field" in document
+ and document["string_field"] is not None
+ and "aper" in document["string_field"],
+ )
+
+ def test_icontains(self):
+ self._assert_query_match_predicate(
+ ["string_field", "icontains", "aper"],
+ lambda document: "string_field" in document
+ and document["string_field"] is not None
+ and "aper" in document["string_field"].lower(),
+ )
+
+ @pytest.mark.skipif(
+ not string_expr_opted_in("startswith"),
+ reason="startswith expr is disabled.",
+ )
+ def test_startswith(self):
+ # WARNING: SQLite treats "startswith" as "istartswith"!
+ # You should avoid "startswith" unless you know what you are doing!
+ self._assert_query_match_predicate(
+ ["string_field", "startswith", "paper"],
+ lambda document: "string_field" in document
+ and document["string_field"] is not None
+ and document["string_field"].startswith("paper"),
+ )
+
+ def test_istartswith(self):
+ self._assert_query_match_predicate(
+ ["string_field", "istartswith", "paper"],
+ lambda document: "string_field" in document
+ and document["string_field"] is not None
+ and document["string_field"].lower().startswith("paper"),
+ )
+
+ @pytest.mark.skipif(
+ not string_expr_opted_in("endswith"),
+ reason="endswith expr is disabled.",
+ )
+ def test_endswith(self):
+ # WARNING: SQLite treats "endswith" as "iendswith"!
+ # You should avoid "endswith" unless you know what you are doing!
+ self._assert_query_match_predicate(
+ ["string_field", "iendswith", "less"],
+ lambda document: "string_field" in document
+ and document["string_field"] is not None
+ and document["string_field"].lower().endswith("less"),
+ )
+
+ def test_iendswith(self):
+ self._assert_query_match_predicate(
+ ["string_field", "iendswith", "less"],
+ lambda document: "string_field" in document
+ and document["string_field"] is not None
+ and document["string_field"].lower().endswith("less"),
+ )
+
+ @pytest.mark.skipif(
+ not string_expr_opted_in("regex"),
+ reason="regex expr is disabled.",
+ )
+ def test_regex(self):
+ # WARNING: the regex syntax is database dependent!
+ self._assert_query_match_predicate(
+ ["string_field", "regex", r"^p.+s$"],
+ lambda document: "string_field" in document
+ and document["string_field"] is not None
+ and re.match(r"^p.+s$", document["string_field"]),
+ )
+
+ @pytest.mark.skipif(
+ not string_expr_opted_in("iregex"),
+ reason="iregex expr is disabled.",
+ )
+ def test_iregex(self):
+ # WARNING: the regex syntax is database dependent!
+ self._assert_query_match_predicate(
+ ["string_field", "iregex", r"^p.+s$"],
+ lambda document: "string_field" in document
+ and document["string_field"] is not None
+ and re.match(r"^p.+s$", document["string_field"], re.IGNORECASE),
+ )
+
+ def test_url_field_istartswith(self):
+        # URL fields support all of the expressions above.
+ # Just showing one of them here.
+ self._assert_query_match_predicate(
+ ["url_field", "istartswith", "http://"],
+ lambda document: "url_field" in document
+ and document["url_field"] is not None
+ and document["url_field"].startswith("http://"),
+ )
+
+ @pytest.mark.skipif(
+ not string_expr_opted_in("iregex"),
+ reason="regex expr is disabled.",
+ )
+ def test_monetary_field_iregex(self):
+        # Monetary fields support all of the expressions above.
+ # Just showing one of them here.
+ #
+        # Unfortunately we can't do arithmetic comparisons on monetary fields,
+        # but regex can approximate some of them.
+ # E.g., USD between 100.00 and 999.99:
+ self._assert_query_match_predicate(
+ ["monetary_field", "regex", r"USD[1-9][0-9]{2}\.[0-9]{2}"],
+ lambda document: "monetary_field" in document
+ and document["monetary_field"] is not None
+ and re.match(
+ r"USD[1-9][0-9]{2}\.[0-9]{2}",
+ document["monetary_field"],
+ re.IGNORECASE,
+ ),
+ )
+
+ # ==========================================================#
+ # Arithmetic comparisons #
+ # ==========================================================#
+ def test_gt(self):
+ self._assert_query_match_predicate(
+ ["date_field", "gt", date(2024, 8, 22).isoformat()],
+ lambda document: "date_field" in document
+ and document["date_field"] is not None
+ and document["date_field"] > date(2024, 8, 22),
+ )
+
+ def test_gte(self):
+ self._assert_query_match_predicate(
+ ["date_field", "gte", date(2024, 8, 22).isoformat()],
+ lambda document: "date_field" in document
+ and document["date_field"] is not None
+ and document["date_field"] >= date(2024, 8, 22),
+ )
+
+ def test_lt(self):
+ self._assert_query_match_predicate(
+ ["integer_field", "lt", 0],
+ lambda document: "integer_field" in document
+ and document["integer_field"] is not None
+ and document["integer_field"] < 0,
+ )
+
+ def test_lte(self):
+ self._assert_query_match_predicate(
+ ["integer_field", "lte", 0],
+ lambda document: "integer_field" in document
+ and document["integer_field"] is not None
+ and document["integer_field"] <= 0,
+ )
+
+ def test_range(self):
+ self._assert_query_match_predicate(
+ ["float_field", "range", [-0.05, 0.05]],
+ lambda document: "float_field" in document
+ and document["float_field"] is not None
+ and -0.05 <= document["float_field"] <= 0.05,
+ )
+
+ def test_date_modifier(self):
+        # For date fields you can optionally prefix the operator
+        # with the date component you are comparing against.
+ self._assert_query_match_predicate(
+ ["date_field", "year__gte", 2024],
+ lambda document: "date_field" in document
+ and document["date_field"] is not None
+ and document["date_field"].year >= 2024,
+ )
+
+ # ==========================================================#
+ # Subset check (document link field only) #
+ # ==========================================================#
+ def test_document_link_contains(self):
+ # Document link field "contains" performs a subset check.
+ self._assert_query_match_predicate(
+ ["documentlink_field", "contains", [1, 2]],
+ lambda document: "documentlink_field" in document
+ and document["documentlink_field"] is not None
+ and set(document["documentlink_field"]) >= {1, 2},
+ )
+        # The order of IDs doesn't matter - this is the same as above.
+ self._assert_query_match_predicate(
+ ["documentlink_field", "contains", [2, 1]],
+ lambda document: "documentlink_field" in document
+ and document["documentlink_field"] is not None
+ and set(document["documentlink_field"]) >= {1, 2},
+ )
+
+ def test_document_link_contains_empty_set(self):
+ # An empty set is a subset of any set.
+ self._assert_query_match_predicate(
+ ["documentlink_field", "contains", []],
+ lambda document: "documentlink_field" in document
+ and document["documentlink_field"] is not None,
+ )
+
+ def test_document_link_contains_no_reverse_link(self):
+        # An edge case: a document in the value list has no document link
+        # field and thus no reverse links.
+ self._assert_query_match_predicate(
+ ["documentlink_field", "contains", [self.documents[6].id]],
+ lambda document: "documentlink_field" in document
+ and document["documentlink_field"] is not None
+ and set(document["documentlink_field"]) >= {self.documents[6].id},
+ match_nothing_ok=True,
+ )
+
+ # ==========================================================#
+ # Logical expressions #
+ # ==========================================================#
+ def test_logical_and(self):
+ self._assert_query_match_predicate(
+ [
+ "AND",
+ [["date_field", "year__exact", 2024], ["date_field", "month__lt", 9]],
+ ],
+ lambda document: "date_field" in document
+ and document["date_field"] is not None
+ and document["date_field"].year == 2024
+ and document["date_field"].month < 9,
+ )
+
+ def test_logical_or(self):
+        # This is also the recommended way to check for "empty" text, URL, and monetary fields.
+ self._assert_query_match_predicate(
+ [
+ "OR",
+ [["string_field", "exact", ""], ["string_field", "isnull", True]],
+ ],
+ lambda document: "string_field" in document
+ and not bool(document["string_field"]),
+ )
+
+ def test_logical_not(self):
+        # This means `NOT ((document has string_field) AND (string_field exact "paperless"))`,
+        # not `(document has string_field) AND (NOT (string_field exact "paperless"))`!
+ self._assert_query_match_predicate(
+ [
+ "NOT",
+ ["string_field", "exact", "paperless"],
+ ],
+ lambda document: not (
+ "string_field" in document and document["string_field"] == "paperless"
+ ),
+ )
+
+ # ==========================================================#
+ # Tests for invalid queries #
+ # ==========================================================#
+
+ def test_invalid_json(self):
+ self._assert_validation_error(
+ "not valid json",
+ ["custom_field_lookup"],
+ "must be valid JSON",
+ )
+
+ def test_invalid_expression(self):
+ self._assert_validation_error(
+ json.dumps("valid json but not valid expr"),
+ ["custom_field_lookup"],
+ "Invalid custom field lookup expression",
+ )
+
+ def test_invalid_custom_field_name(self):
+ self._assert_validation_error(
+ json.dumps(["invalid name", "iexact", "foo"]),
+ ["custom_field_lookup", "0"],
+ "is not a valid custom field",
+ )
+
+ def test_invalid_operator(self):
+ self._assert_validation_error(
+ json.dumps(["integer_field", "iexact", "foo"]),
+ ["custom_field_lookup", "1"],
+ "does not support lookup expr",
+ )
+
+ def test_invalid_value(self):
+ self._assert_validation_error(
+ json.dumps(["select_field", "exact", "not an option"]),
+ ["custom_field_lookup", "2"],
+ "integer",
+ )
+
+ def test_invalid_logical_operator(self):
+ self._assert_validation_error(
+ json.dumps(["invalid op", ["integer_field", "gt", 0]]),
+ ["custom_field_lookup", "0"],
+ "Invalid logical operator",
+ )
+
+ def test_invalid_expr_list(self):
+ self._assert_validation_error(
+ json.dumps(["AND", "not a list"]),
+ ["custom_field_lookup", "1"],
+ "Invalid expression list",
+ )
+
+ def test_invalid_operator_prefix(self):
+ self._assert_validation_error(
+ json.dumps(["integer_field", "foo__gt", 0]),
+ ["custom_field_lookup", "1"],
+ "does not support lookup expr",
+ )
+
+ @pytest.mark.skipif(
+ string_expr_opted_in("regex"),
+ reason="user opted into allowing regex expr",
+ )
+ def test_disabled_operator(self):
+ self._assert_validation_error(
+ json.dumps(["string_field", "regex", r"^p.+s$"]),
+ ["custom_field_lookup", "1"],
+ "disabled by default",
+ )
+
+ def test_query_too_deep(self):
+ query = ["string_field", "exact", "paperless"]
+ for _ in range(10):
+ query = ["NOT", query]
+ self._assert_validation_error(
+ json.dumps(query),
+ ["custom_field_lookup", *(["1"] * 10)],
+ "Maximum nesting depth exceeded",
+ )
+
+ def test_query_too_many_atoms(self):
+ atom = ["string_field", "exact", "paperless"]
+ query = ["AND", [atom for _ in range(21)]]
+ self._assert_validation_error(
+ json.dumps(query),
+ ["custom_field_lookup", "1", "20"],
+ "Maximum number of query conditions exceeded",
+ )