]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Refactor: Use django-filter logic for filtering full text search queries (#7507)
authorYichi Yang <helium6072@gmail.com>
Sun, 25 Aug 2024 04:20:43 +0000 (12:20 +0800)
committerGitHub <noreply@github.com>
Sun, 25 Aug 2024 04:20:43 +0000 (21:20 -0700)
src/documents/index.py
src/documents/tests/test_api_search.py
src/documents/tests/test_delayedquery.py
src/documents/views.py

index 98c43d1e883af64cbd8e34b9d5775c20679c5aa9..d95a80213f524de9b4395ccf16cda09f30d88f57 100644 (file)
@@ -8,8 +8,8 @@ from datetime import timezone
 from shutil import rmtree
 from typing import Optional
 
-from dateutil.parser import isoparse
 from django.conf import settings
+from django.db.models import QuerySet
 from django.utils import timezone as django_timezone
 from guardian.shortcuts import get_users_with_perms
 from whoosh import classify
@@ -22,6 +22,8 @@ from whoosh.fields import NUMERIC
 from whoosh.fields import TEXT
 from whoosh.fields import Schema
 from whoosh.highlight import HtmlFormatter
+from whoosh.idsets import BitSet
+from whoosh.idsets import DocIdSet
 from whoosh.index import FileIndex
 from whoosh.index import create_in
 from whoosh.index import exists_in
@@ -31,6 +33,7 @@ from whoosh.qparser import QueryParser
 from whoosh.qparser.dateparse import DateParserPlugin
 from whoosh.qparser.dateparse import English
 from whoosh.qparser.plugins import FieldsPlugin
+from whoosh.reading import IndexReader
 from whoosh.scoring import TF_IDF
 from whoosh.searching import ResultsPage
 from whoosh.searching import Searcher
@@ -201,114 +204,32 @@ def remove_document_from_index(document: Document):
         remove_document(writer, document)
 
 
-class DelayedQuery:
-    param_map = {
-        "correspondent": ("correspondent", ["id", "id__in", "id__none", "isnull"]),
-        "document_type": ("type", ["id", "id__in", "id__none", "isnull"]),
-        "storage_path": ("path", ["id", "id__in", "id__none", "isnull"]),
-        "owner": ("owner", ["id", "id__in", "id__none", "isnull"]),
-        "shared_by": ("shared_by", ["id"]),
-        "tags": ("tag", ["id__all", "id__in", "id__none"]),
-        "added": ("added", ["date__lt", "date__gt"]),
-        "created": ("created", ["date__lt", "date__gt"]),
-        "checksum": ("checksum", ["icontains", "istartswith"]),
-        "original_filename": ("original_filename", ["icontains", "istartswith"]),
-        "custom_fields": (
-            "custom_fields",
-            ["icontains", "istartswith", "id__all", "id__in", "id__none"],
-        ),
-    }
+class MappedDocIdSet(DocIdSet):
+    """
+    A DocIdSet backed by a set of `Document` IDs.
+    Supports efficiently looking up if a whoosh docnum is in the provided `filter_queryset`.
+    """
 
-    def _get_query(self):
-        raise NotImplementedError
-
-    def _get_query_filter(self):
-        criterias = []
-        for key, value in self.query_params.items():
-            # is_tagged is a special case
-            if key == "is_tagged":
-                criterias.append(query.Term("has_tag", self.evalBoolean(value)))
-                continue
-
-            if key == "has_custom_fields":
-                criterias.append(
-                    query.Term("has_custom_fields", self.evalBoolean(value)),
-                )
-                continue
-
-            # Don't process query params without a filter
-            if "__" not in key:
-                continue
-
-            # All other query params consist of a parameter and a query filter
-            param, query_filter = key.split("__", 1)
-            try:
-                field, supported_query_filters = self.param_map[param]
-            except KeyError:
-                logger.error(f"Unable to build a query filter for parameter {key}")
-                continue
-
-            # We only support certain filters per parameter
-            if query_filter not in supported_query_filters:
-                logger.info(
-                    f"Query filter {query_filter} not supported for parameter {param}",
-                )
-                continue
-
-            if query_filter == "id":
-                if param == "shared_by":
-                    criterias.append(query.Term("is_shared", True))
-                    criterias.append(query.Term("owner_id", value))
-                else:
-                    criterias.append(query.Term(f"{field}_id", value))
-            elif query_filter == "id__in":
-                in_filter = []
-                for object_id in value.split(","):
-                    in_filter.append(
-                        query.Term(f"{field}_id", object_id),
-                    )
-                criterias.append(query.Or(in_filter))
-            elif query_filter == "id__none":
-                for object_id in value.split(","):
-                    criterias.append(
-                        query.Not(query.Term(f"{field}_id", object_id)),
-                    )
-            elif query_filter == "isnull":
-                criterias.append(
-                    query.Term(f"has_{field}", self.evalBoolean(value) is False),
-                )
-            elif query_filter == "id__all":
-                for object_id in value.split(","):
-                    criterias.append(query.Term(f"{field}_id", object_id))
-            elif query_filter == "date__lt":
-                criterias.append(
-                    query.DateRange(field, start=None, end=isoparse(value)),
-                )
-            elif query_filter == "date__gt":
-                criterias.append(
-                    query.DateRange(field, start=isoparse(value), end=None),
-                )
-            elif query_filter == "icontains":
-                criterias.append(
-                    query.Term(field, value),
-                )
-            elif query_filter == "istartswith":
-                criterias.append(
-                    query.Prefix(field, value),
-                )
-
-        user_criterias = get_permissions_criterias(
-            user=self.user,
-        )
-        if len(criterias) > 0:
-            if len(user_criterias) > 0:
-                criterias.append(query.Or(user_criterias))
-            return query.And(criterias)
-        else:
-            return query.Or(user_criterias) if len(user_criterias) > 0 else None
+    def __init__(self, filter_queryset: QuerySet, ixreader: IndexReader) -> None:
+        super().__init__()
+        document_ids = filter_queryset.order_by("id").values_list("id", flat=True)
+        max_id = document_ids.last() or 0
+        self.document_ids = BitSet(document_ids, size=max_id)
+        self.ixreader = ixreader
 
-    def evalBoolean(self, val):
-        return val.lower() in {"true", "1"}
+    def __contains__(self, docnum):
+        document_id = self.ixreader.stored_fields(docnum)["id"]
+        return document_id in self.document_ids
+
+    def __bool__(self):
+        # searcher.search ignores a filter if it's "falsy".
+        # We use this hack so this DocIdSet, when used as a filter, is never ignored.
+        return True
+
+
+class DelayedQuery:
+    def _get_query(self):
+        raise NotImplementedError  # pragma: no cover
 
     def _get_query_sortedby(self):
         if "ordering" not in self.query_params:
@@ -339,13 +260,19 @@ class DelayedQuery:
         else:
             return sort_fields_map[field], reverse
 
-    def __init__(self, searcher: Searcher, query_params, page_size, user):
+    def __init__(
+        self,
+        searcher: Searcher,
+        query_params,
+        page_size,
+        filter_queryset: QuerySet,
+    ):
         self.searcher = searcher
         self.query_params = query_params
         self.page_size = page_size
         self.saved_results = dict()
         self.first_score = None
-        self.user = user
+        self.filter_queryset = filter_queryset
 
     def __len__(self):
         page = self[0:1]
@@ -361,7 +288,7 @@ class DelayedQuery:
         page: ResultsPage = self.searcher.search_page(
             q,
             mask=mask,
-            filter=self._get_query_filter(),
+            filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader),
             pagenum=math.floor(item.start / self.page_size) + 1,
             pagelen=self.page_size,
             sortedby=sortedby,
index c10d6c1bb2c8359ce62e0fca05be218a2f50628e..e524e7b91b3ab4bb63796ecadcc117e04cfd51ff 100644 (file)
@@ -15,6 +15,7 @@ from rest_framework.test import APITestCase
 from whoosh.writing import AsyncWriter
 
 from documents import index
+from documents.bulk_edit import set_permissions
 from documents.models import Correspondent
 from documents.models import CustomField
 from documents.models import CustomFieldInstance
@@ -1159,7 +1160,8 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
             [d3.id, d2.id, d1.id],
         )
 
-    def test_global_search(self):
+    @mock.patch("documents.bulk_edit.bulk_update_documents")
+    def test_global_search(self, m):
         """
         GIVEN:
             - Multiple documents and objects
@@ -1186,11 +1188,38 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
             checksum="C",
             pk=3,
         )
+        # The below two documents are owned by user2 and shouldn't show up in results!
+        d4 = Document.objects.create(
+            title="doc 4 owned by user2",
+            content="bank bank bank bank 4",
+            checksum="D",
+            pk=4,
+        )
+        d5 = Document.objects.create(
+            title="doc 5 owned by user2",
+            content="bank bank bank bank 5",
+            checksum="E",
+            pk=5,
+        )
+
+        user1 = User.objects.create_user("bank user1")
+        user2 = User.objects.create_superuser("user2")
+        group1 = Group.objects.create(name="bank group1")
+        Group.objects.create(name="group2")
+
+        user1.user_permissions.add(
+            *Permission.objects.filter(codename__startswith="view_").exclude(
+                content_type__app_label="admin",
+            ),
+        )
+        set_permissions([4, 5], set_permissions=[], owner=user2, merge=False)
 
         with index.open_index_writer() as writer:
             index.update_document(writer, d1)
             index.update_document(writer, d2)
             index.update_document(writer, d3)
+            index.update_document(writer, d4)
+            index.update_document(writer, d5)
 
         correspondent1 = Correspondent.objects.create(name="bank correspondent 1")
         Correspondent.objects.create(name="correspondent 2")
@@ -1200,10 +1229,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
         StoragePath.objects.create(name="path 2", path="path2")
         tag1 = Tag.objects.create(name="bank tag1")
         Tag.objects.create(name="tag2")
-        user1 = User.objects.create_superuser("bank user1")
-        User.objects.create_user("user2")
-        group1 = Group.objects.create(name="bank group1")
-        Group.objects.create(name="group2")
+
         SavedView.objects.create(
             name="bank view",
             show_on_dashboard=True,
index b0dfc2ed2b07f0bccbd814b9eb6224e4d162fd13..1895bd6c69440c221fd9489a8b6c1dabc26828b0 100644 (file)
@@ -1,8 +1,6 @@
-from dateutil.parser import isoparse
 from django.test import TestCase
 from whoosh import query
 
-from documents.index import DelayedQuery
 from documents.index import get_permissions_criterias
 from documents.models import User
 
@@ -58,162 +56,3 @@ class TestDelayedQuery(TestCase):
         )
         for user, expected in tests:
             self.assertEqual(get_permissions_criterias(user), expected)
-
-    def test_no_query_filters(self):
-        dq = DelayedQuery(None, {}, None, None)
-        self.assertEqual(dq._get_query_filter(), self.has_no_owner)
-
-    def test_date_query_filters(self):
-        def _get_testset(param: str):
-            date_str = "1970-01-01T02:44"
-            date_obj = isoparse(date_str)
-            return (
-                (
-                    {f"{param}__date__lt": date_str},
-                    query.And(
-                        [
-                            query.DateRange(param, start=None, end=date_obj),
-                            self.has_no_owner,
-                        ],
-                    ),
-                ),
-                (
-                    {f"{param}__date__gt": date_str},
-                    query.And(
-                        [
-                            query.DateRange(param, start=date_obj, end=None),
-                            self.has_no_owner,
-                        ],
-                    ),
-                ),
-            )
-
-        query_params = ["created", "added"]
-        for param in query_params:
-            for params, expected in _get_testset(param):
-                dq = DelayedQuery(None, params, None, None)
-                got = dq._get_query_filter()
-                self.assertCountEqual(got, expected)
-
-    def test_is_tagged_query_filter(self):
-        tests = (
-            ("True", True),
-            ("true", True),
-            ("1", True),
-            ("False", False),
-            ("false", False),
-            ("0", False),
-            ("foo", False),
-        )
-        for param, expected in tests:
-            dq = DelayedQuery(None, {"is_tagged": param}, None, None)
-            self.assertEqual(
-                dq._get_query_filter(),
-                query.And([query.Term("has_tag", expected), self.has_no_owner]),
-            )
-
-    def test_tags_query_filters(self):
-        # tests contains tuples of query_parameter dics and the expected whoosh query
-        param = "tags"
-        field, _ = DelayedQuery.param_map[param]
-        tests = (
-            (
-                {f"{param}__id__all": "42,43"},
-                query.And(
-                    [
-                        query.Term(f"{field}_id", "42"),
-                        query.Term(f"{field}_id", "43"),
-                        self.has_no_owner,
-                    ],
-                ),
-            ),
-            # tags does not allow __id
-            (
-                {f"{param}__id": "42"},
-                self.has_no_owner,
-            ),
-            # tags does not allow __isnull
-            (
-                {f"{param}__isnull": "true"},
-                self.has_no_owner,
-            ),
-            self._get_testset__id__in(param, field),
-            self._get_testset__id__none(param, field),
-        )
-
-        for params, expected in tests:
-            dq = DelayedQuery(None, params, None, None)
-            got = dq._get_query_filter()
-            self.assertCountEqual(got, expected)
-
-    def test_generic_query_filters(self):
-        def _get_testset(param: str):
-            field, _ = DelayedQuery.param_map[param]
-            return (
-                (
-                    {f"{param}__id": "42"},
-                    query.And(
-                        [
-                            query.Term(f"{field}_id", "42"),
-                            self.has_no_owner,
-                        ],
-                    ),
-                ),
-                self._get_testset__id__in(param, field),
-                self._get_testset__id__none(param, field),
-                (
-                    {f"{param}__isnull": "true"},
-                    query.And(
-                        [
-                            query.Term(f"has_{field}", False),
-                            self.has_no_owner,
-                        ],
-                    ),
-                ),
-                (
-                    {f"{param}__isnull": "false"},
-                    query.And(
-                        [
-                            query.Term(f"has_{field}", True),
-                            self.has_no_owner,
-                        ],
-                    ),
-                ),
-            )
-
-        query_params = ["correspondent", "document_type", "storage_path", "owner"]
-        for param in query_params:
-            for params, expected in _get_testset(param):
-                dq = DelayedQuery(None, params, None, None)
-                got = dq._get_query_filter()
-                self.assertCountEqual(got, expected)
-
-    def test_char_query_filter(self):
-        def _get_testset(param: str):
-            return (
-                (
-                    {f"{param}__icontains": "foo"},
-                    query.And(
-                        [
-                            query.Term(f"{param}", "foo"),
-                            self.has_no_owner,
-                        ],
-                    ),
-                ),
-                (
-                    {f"{param}__istartswith": "foo"},
-                    query.And(
-                        [
-                            query.Prefix(f"{param}", "foo"),
-                            self.has_no_owner,
-                        ],
-                    ),
-                ),
-            )
-
-        query_params = ["checksum", "original_filename"]
-        for param in query_params:
-            for params, expected in _get_testset(param):
-                dq = DelayedQuery(None, params, None, None)
-                got = dq._get_query_filter()
-                self.assertCountEqual(got, expected)
index df54546e198986061d2edaa25f4f5267560d6910..c0ceef4a384f61d793f38adea5aef2d784ed1b17 100644 (file)
@@ -852,6 +852,8 @@ class UnifiedSearchViewSet(DocumentViewSet):
         )
 
     def filter_queryset(self, queryset):
+        filtered_queryset = super().filter_queryset(queryset)
+
         if self._is_search_request():
             from documents import index
 
@@ -866,10 +868,10 @@ class UnifiedSearchViewSet(DocumentViewSet):
                 self.searcher,
                 self.request.query_params,
                 self.paginator.get_page_size(self.request),
-                self.request.user,
+                filter_queryset=filtered_queryset,
             )
         else:
-            return super().filter_queryset(queryset)
+            return filtered_queryset
 
     def list(self, request, *args, **kwargs):
         if self._is_search_request():
@@ -1199,14 +1201,16 @@ class GlobalSearchView(PassUserMixin):
                 from documents import index
 
                 with index.open_index_searcher() as s:
-                    q, _ = index.DelayedFullTextQuery(
+                    fts_query = index.DelayedFullTextQuery(
                         s,
                         request.query_params,
-                        10,
-                        request.user,
-                    )._get_query()
-                    results = s.search(q, limit=OBJECT_LIMIT)
-                    docs = docs | all_docs.filter(id__in=[r["id"] for r in results])
+                        OBJECT_LIMIT,
+                        filter_queryset=all_docs,
+                    )
+                    results = fts_query[0:1]
+                    docs = docs | Document.objects.filter(
+                        id__in=[r["id"] for r in results],
+                    )
             docs = docs[:OBJECT_LIMIT]
         saved_views = (
             SavedView.objects.filter(owner=request.user, name__icontains=query)
@@ -1452,12 +1456,12 @@ class StatisticsView(APIView):
             {
                 "documents_total": documents_total,
                 "documents_inbox": documents_inbox,
-                "inbox_tag": inbox_tags.first().pk
-                if inbox_tags.exists()
-                else None,  # backwards compatibility
-                "inbox_tags": [tag.pk for tag in inbox_tags]
-                if inbox_tags.exists()
-                else None,
+                "inbox_tag": (
+                    inbox_tags.first().pk if inbox_tags.exists() else None
+                ),  # backwards compatibility
+                "inbox_tags": (
+                    [tag.pk for tag in inbox_tags] if inbox_tags.exists() else None
+                ),
                 "document_file_type_counts": document_file_type_counts,
                 "character_count": character_count,
                 "tag_count": len(tags),