import logging
import math
import os
+from collections import Counter
from contextlib import contextmanager
from dateutil.parser import isoparse
from whoosh.fields import TEXT
from whoosh.fields import Schema
from whoosh.highlight import HtmlFormatter
+from whoosh.index import FileIndex
from whoosh.index import create_in
from whoosh.index import exists_in
from whoosh.index import open_dir
from whoosh.qparser import MultifieldParser
+from whoosh.qparser import QueryParser
from whoosh.qparser.dateparse import DateParserPlugin
+from whoosh.scoring import TF_IDF
from whoosh.searching import ResultsPage
from whoosh.searching import Searcher
from whoosh.writing import AsyncWriter
from documents.models import Document
from documents.models import Note
+from documents.models import User
logger = logging.getLogger("paperless.index")
elif k == "storage_path__isnull":
criterias.append(query.Term("has_path", v == "false"))
- user_criterias = [query.Term("has_owner", False)]
- if "user" in self.query_params:
- if self.query_params["is_superuser"]: # superusers see all docs
- user_criterias = []
- else:
- user_criterias.append(query.Term("owner_id", self.query_params["user"]))
- user_criterias.append(
- query.Term("viewer_id", str(self.query_params["user"])),
- )
+ user_criterias = get_permissions_criterias(
+ user=self.user,
+ )
if len(criterias) > 0:
if len(user_criterias) > 0:
criterias.append(query.Or(user_criterias))
else:
return sort_fields_map[field], reverse
- def __init__(self, searcher: Searcher, query_params, page_size):
+ def __init__(self, searcher: Searcher, query_params, page_size, user):
self.searcher = searcher
self.query_params = query_params
self.page_size = page_size
self.saved_results = dict()
self.first_score = None
+ self.user = user
def __len__(self):
page = self[0:1]
return q, mask
-def autocomplete(ix, term, limit=10):
- with ix.reader() as reader:
- terms = []
- for score, t in reader.most_distinctive_terms(
- "content",
- number=limit,
- prefix=term.lower(),
- ):
- terms.append(t)
- return terms
+def autocomplete(ix: FileIndex, term: str, limit: int = 10, user: User = None):
+ """
+ Mimics whoosh.reading.IndexReader.most_distinctive_terms with permissions
+ and without scoring
+ """
+ terms = []
+
+ with ix.searcher(weighting=TF_IDF()) as s:
+ qp = QueryParser("content", schema=ix.schema)
+ q = qp.parse(f"{term.lower()}*")
+ user_criterias = get_permissions_criterias(user)
+
+ results = s.search(
+ q,
+ terms=True,
+ filter=query.Or(user_criterias) if user_criterias is not None else None,
+ )
+
+ termCounts = Counter()
+ if results.has_matched_terms():
+ for hit in results:
+ for _, term in hit.matched_terms():
+ termCounts[term] += 1
+ terms = [t for t, _ in termCounts.most_common(limit)]
+
+ return terms
+
+
+def get_permissions_criterias(user: User = None):
+ user_criterias = [query.Term("has_owner", False)]
+ if user is not None:
+ if user.is_superuser: # superusers see all docs
+ user_criterias = []
+ else:
+ user_criterias.append(query.Term("owner_id", user.id))
+ user_criterias.append(
+ query.Term("viewer_id", str(user.id)),
+ )
+ return user_criterias
@mock.patch("documents.index.autocomplete")
def test_search_autocomplete(self, m):
- m.side_effect = lambda ix, term, limit: [term for _ in range(limit)]
+ m.side_effect = lambda ix, term, limit, user: [term for _ in range(limit)]
response = self.client.get("/api/search/autocomplete/?term=test")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 10)
+ def test_search_autocomplete_respect_permissions(self):
+ """
+ GIVEN:
+ - Multiple users and documents with & without permissions
+ WHEN:
+ - API reuqest for autocomplete is made by user with or without permissions
+ THEN:
+ - Terms only within docs user has access to are returned
+ """
+ u1 = User.objects.create_user("user1")
+ u2 = User.objects.create_user("user2")
+
+ self.client.force_authenticate(user=u1)
+
+ d1 = Document.objects.create(
+ title="doc1",
+ content="apples",
+ checksum="1",
+ owner=u1,
+ )
+ d2 = Document.objects.create(
+ title="doc2",
+ content="applebaum",
+ checksum="2",
+ owner=u1,
+ )
+ d3 = Document.objects.create(
+ title="doc3",
+ content="appletini",
+ checksum="3",
+ owner=u1,
+ )
+
+ with AsyncWriter(index.open_index()) as writer:
+ index.update_document(writer, d1)
+ index.update_document(writer, d2)
+ index.update_document(writer, d3)
+
+ response = self.client.get("/api/search/autocomplete/?term=app")
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, [b"apples", b"applebaum", b"appletini"])
+
+ d3.owner = u2
+
+ with AsyncWriter(index.open_index()) as writer:
+ index.update_document(writer, d3)
+
+ response = self.client.get("/api/search/autocomplete/?term=app")
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, [b"apples", b"applebaum"])
+
+ assign_perm("view_document", u1, d3)
+
+ with AsyncWriter(index.open_index()) as writer:
+ index.update_document(writer, d3)
+
+ response = self.client.get("/api/search/autocomplete/?term=app")
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, [b"apples", b"applebaum", b"appletini"])
+
@pytest.mark.skip(reason="Not implemented yet")
def test_search_spelling_correction(self):
with AsyncWriter(index.open_index()) as writer:
self.assertListEqual(
index.autocomplete(ix, "tes"),
- [b"test3", b"test", b"test2"],
+ [b"test2", b"test", b"test3"],
)
self.assertListEqual(
index.autocomplete(ix, "tes", limit=3),
- [b"test3", b"test", b"test2"],
+ [b"test2", b"test", b"test3"],
)
- self.assertListEqual(index.autocomplete(ix, "tes", limit=1), [b"test3"])
+ self.assertListEqual(index.autocomplete(ix, "tes", limit=1), [b"test2"])
self.assertListEqual(index.autocomplete(ix, "tes", limit=0), [])
def test_archive_serial_number_ranging(self):
if self._is_search_request():
from documents import index
- if hasattr(self.request, "user"):
- # pass user to query for perms
- self.request.query_params._mutable = True
- self.request.query_params["user"] = self.request.user.id
- self.request.query_params[
- "is_superuser"
- ] = self.request.user.is_superuser
- self.request.query_params._mutable = False
-
if "query" in self.request.query_params:
query_class = index.DelayedFullTextQuery
elif "more_like_id" in self.request.query_params:
self.searcher,
self.request.query_params,
self.paginator.get_page_size(self.request),
+ self.request.user,
)
else:
return super().filter_queryset(queryset)
permission_classes = (IsAuthenticated,)
def get(self, request, format=None):
+ user = self.request.user if hasattr(self.request, "user") else None
+
if "term" in request.query_params:
term = request.query_params["term"]
else:
ix = index.open_index()
- return Response(index.autocomplete(ix, term, limit))
+ return Response(
+ index.autocomplete(
+ ix,
+ term,
+ limit,
+ user,
+ ),
+ )
class StatisticsView(APIView):