self.first_score = None
self.filter_queryset = filter_queryset
self.suggested_correction = None
+ self._manual_hits_cache: list | None = None
def __len__(self) -> int:
    """Return the number of results this query reports.

    When ordering by a custom field is requested, the count is the length
    of the list returned by ``_manual_hits()``. Otherwise the first page
    slice (``self[0:1]``) is fetched and the length reported by that page
    object is returned.
    """
    if not self._manual_sort_requested():
        # Regular path: ask for a minimal slice and let the returned page
        # object report its own length.
        first_page = self[0:1]
        return len(first_page)

    # Manual-sort path: the full, re-ordered hit list is the source of truth.
    return len(self._manual_hits())
def _manual_sort_requested(self) -> bool:
    """Tell whether the request asks for ordering by a custom field.

    Reads the ``ordering`` entry of ``self.query_params`` and returns
    ``True`` when, after removing any leading ``-`` characters, the value
    starts with ``custom_field_``. A missing, empty or ``None`` value yields
    ``False``.
    """
    # ``or ""`` also covers an explicit ``None`` value, which would
    # otherwise raise AttributeError on ``.lstrip``.
    ordering = self.query_params.get("ordering") or ""
    return ordering.lstrip("-").startswith("custom_field_")
+
def _manual_hits(self) -> list:
    """Return every search hit, arranged in the order of ``filter_queryset``.

    The search is run once with ``limit=None`` and the resulting list is
    kept in ``self._manual_hits_cache``; later calls return that same list.

    Side effects on the first call:
        * ``self.suggested_correction`` is set from ``self._get_query()``.
        * ``self.first_score`` is set to the score of the first result when
          it is not already set and there is at least one result.
        * When ``self.first_score`` is truthy, every score in
          ``results.top_n`` is divided by it.

    NOTE(review): the final order is whatever order ``self.filter_queryset``
    yields its ids in -- presumably the caller has already ordered that
    queryset by the requested custom field; confirm at the call site.
    NOTE(review): all matching ids are passed in a single ``id__in`` lookup;
    verify this is acceptable for very large result sets on the database
    backends in use.
    """
    if self._manual_hits_cache is not None:
        return self._manual_hits_cache

    q, mask, suggested_correction = self._get_query()
    self.suggested_correction = suggested_correction

    results = self.searcher.search(
        q,
        mask=mask,
        filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader),
        limit=None,
    )
    results.fragmenter = highlight.ContextFragmenter(surround=50)
    results.formatter = HtmlFormatter(tagname="span", between=" ... ")

    if not self.first_score and len(results) > 0:
        self.first_score = results[0].score

    if self.first_score:
        # Scale scores relative to the first result's score. The guard above
        # already rules out a falsy divisor, so no per-item check is needed.
        results.top_n = [
            (hit[0] / self.first_score, hit[1]) for hit in results.top_n
        ]

    hits_by_id = {hit["id"]: hit for hit in results}

    # Let the queryset decide the order of the matching ids.
    ordered_ids = self.filter_queryset.filter(
        id__in=list(hits_by_id),
    ).values_list("id", flat=True)

    # ``dict.fromkeys`` drops repeated ids while keeping first-seen order.
    self._manual_hits_cache = [
        hits_by_id[doc_id]
        for doc_id in dict.fromkeys(ordered_ids)
        if doc_id in hits_by_id
    ]
    return self._manual_hits_cache
+
def __getitem__(self, item):
if item.start in self.saved_results:
return self.saved_results[item.start]
+ if self._manual_sort_requested():
+ manual_hits = self._manual_hits()
+ start = 0 if item.start is None else item.start
+ stop = item.stop
+ hits = manual_hits[start:stop] if stop is not None else manual_hits[start:]
+ page = ManualResultsPage(hits)
+ self.saved_results[start] = page
+ return page
+
q, mask, suggested_correction = self._get_query()
self.suggested_correction = suggested_correction
sortedby, reverse = self._get_query_sortedby()
if not self.first_score and len(page.results) > 0 and sortedby is None:
self.first_score = page.results[0].score
- page.results.top_n = list(
- map(
- lambda hit: (
- (hit[0] / self.first_score) if self.first_score else None,
- hit[1],
- ),
- page.results.top_n,
- ),
- )
+ page.results.top_n = [
+ (
+ (hit[0] / self.first_score) if self.first_score else None,
+ hit[1],
+ )
+ for hit in page.results.top_n
+ ]
self.saved_results[item.start] = page
return page
class ManualResultsPage(list):
    """A list of hits that also exposes a ``results`` attribute.

    ``results`` is a ``ManualResults`` built from the same hits, so code that
    calls ``page.results.docs()`` keeps working with a manually assembled
    page.

    NOTE(review): ``results.docs()`` only covers the hits handed to this
    page. If a consumer expects the document numbers of *every* match of the
    search rather than of this slice, the full hit list has to be supplied
    by the caller -- confirm against the code that reads ``results.docs()``.
    """

    def __init__(self, hits):
        super().__init__(hits)
        # Build from the materialised list (``self``), not from ``hits``:
        # ``hits`` may be a one-shot iterator that the line above has
        # already exhausted.
        self.results = ManualResults(self)


class ManualResults:
    """Minimal results holder exposing the ``docs()`` accessor."""

    def __init__(self, hits):
        # Document numbers, in the order of the given hits.
        self._docnums = [hit.docnum for hit in hits]

    def docs(self) -> list:
        """Return the document numbers of the wrapped hits, in hit order."""
        # Hand out a copy so callers cannot alter the stored list.
        return list(self._docnums)
+
class LocalDateParser(English):
def reverse_timezone_offset(self, d):
return (d.replace(tzinfo=django_timezone.get_current_timezone())).astimezone(
self.assertEqual(len(results), 0)
self.assertCountEqual(response.data["all"], [])
def test_search_custom_field_ordering(self):
    """Search results can be ordered by an integer custom field.

    Three documents matching the same query get the custom field values
    30, 10 and 20. The documents API must return their ids sorted by that
    value in ascending order, and in descending order when the ordering
    parameter is prefixed with ``-``.
    """
    custom_field = CustomField.objects.create(
        name="Sortable field",
        data_type=CustomField.FieldDataType.INT,
    )

    # (title, checksum) for each document; all share the same content.
    document_specs = (("first", "A1"), ("second", "B2"), ("third", "C3"))
    doc_first, doc_second, doc_third = (
        Document.objects.create(title=title, content="match", checksum=checksum)
        for title, checksum in document_specs
    )
    documents = (doc_first, doc_second, doc_third)

    # Field values are deliberately not in creation order.
    for document, field_value in zip(documents, (30, 10, 20)):
        CustomFieldInstance.objects.create(
            document=document,
            field=custom_field,
            value_int=field_value,
        )

    with AsyncWriter(index.open_index()) as writer:
        for document in documents:
            index.update_document(writer, document)

    ascending = self.client.get(
        f"/api/documents/?query=match&ordering=custom_field_{custom_field.pk}",
    )
    self.assertEqual(ascending.status_code, status.HTTP_200_OK)
    ascending_ids = [entry["id"] for entry in ascending.data["results"]]
    self.assertEqual(
        ascending_ids,
        [doc_second.id, doc_third.id, doc_first.id],
    )

    descending = self.client.get(
        f"/api/documents/?query=match&ordering=-custom_field_{custom_field.pk}",
    )
    self.assertEqual(descending.status_code, status.HTTP_200_OK)
    descending_ids = [entry["id"] for entry in descending.data["results"]]
    self.assertEqual(
        descending_ids,
        [doc_first.id, doc_third.id, doc_second.id],
    )
+
def test_search_multi_page(self):
with AsyncWriter(index.open_index()) as writer:
for i in range(55):