git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Fix: support for custom field ordering w advanced search (#11383)
author shamoon <4887959+shamoon@users.noreply.github.com>
Mon, 17 Nov 2025 20:47:55 +0000 (12:47 -0800)
committer GitHub <noreply@github.com>
Mon, 17 Nov 2025 20:47:55 +0000 (20:47 +0000)
src/documents/index.py
src/documents/tests/test_api_search.py

index 90cbb8000c239d6fe99de38e1d139a88792be300..9446c7db17d99f2c8bb6a30618d829f6ec2b2d12 100644 (file)
@@ -287,15 +287,75 @@ class DelayedQuery:
         self.first_score = None
         self.filter_queryset = filter_queryset
         self.suggested_correction = None
+        self._manual_hits_cache: list | None = None
 
     def __len__(self) -> int:
+        if self._manual_sort_requested():
+            manual_hits = self._manual_hits()
+            return len(manual_hits)
+
         page = self[0:1]
         return len(page)
 
+    def _manual_sort_requested(self):
+        ordering = self.query_params.get("ordering", "")
+        return ordering.lstrip("-").startswith("custom_field_")
+
+    def _manual_hits(self):
+        if self._manual_hits_cache is None:
+            q, mask, suggested_correction = self._get_query()
+            self.suggested_correction = suggested_correction
+
+            results = self.searcher.search(
+                q,
+                mask=mask,
+                filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader),
+                limit=None,
+            )
+            results.fragmenter = highlight.ContextFragmenter(surround=50)
+            results.formatter = HtmlFormatter(tagname="span", between=" ... ")
+
+            if not self.first_score and len(results) > 0:
+                self.first_score = results[0].score
+
+            if self.first_score:
+                results.top_n = [
+                    (
+                        (hit[0] / self.first_score) if self.first_score else None,
+                        hit[1],
+                    )
+                    for hit in results.top_n
+                ]
+
+            hits_by_id = {hit["id"]: hit for hit in results}
+            matching_ids = list(hits_by_id.keys())
+
+            ordered_ids = list(
+                self.filter_queryset.filter(id__in=matching_ids).values_list(
+                    "id",
+                    flat=True,
+                ),
+            )
+            ordered_ids = list(dict.fromkeys(ordered_ids))
+
+            self._manual_hits_cache = [
+                hits_by_id[_id] for _id in ordered_ids if _id in hits_by_id
+            ]
+        return self._manual_hits_cache
+
     def __getitem__(self, item):
         if item.start in self.saved_results:
             return self.saved_results[item.start]
 
+        if self._manual_sort_requested():
+            manual_hits = self._manual_hits()
+            start = 0 if item.start is None else item.start
+            stop = item.stop
+            hits = manual_hits[start:stop] if stop is not None else manual_hits[start:]
+            page = ManualResultsPage(hits)
+            self.saved_results[start] = page
+            return page
+
         q, mask, suggested_correction = self._get_query()
         self.suggested_correction = suggested_correction
         sortedby, reverse = self._get_query_sortedby()
@@ -315,21 +375,33 @@ class DelayedQuery:
         if not self.first_score and len(page.results) > 0 and sortedby is None:
             self.first_score = page.results[0].score
 
-        page.results.top_n = list(
-            map(
-                lambda hit: (
-                    (hit[0] / self.first_score) if self.first_score else None,
-                    hit[1],
-                ),
-                page.results.top_n,
-            ),
-        )
+        page.results.top_n = [
+            (
+                (hit[0] / self.first_score) if self.first_score else None,
+                hit[1],
+            )
+            for hit in page.results.top_n
+        ]
 
         self.saved_results[item.start] = page
 
         return page
 
 
+class ManualResultsPage(list):
+    def __init__(self, hits):
+        super().__init__(hits)
+        self.results = ManualResults(hits)
+
+
+class ManualResults:
+    def __init__(self, hits):
+        self._docnums = [hit.docnum for hit in hits]
+
+    def docs(self):
+        return self._docnums
+
+
 class LocalDateParser(English):
     def reverse_timezone_offset(self, d):
         return (d.replace(tzinfo=django_timezone.get_current_timezone())).astimezone(
index 8f316c1451c623ce9fff78de32f691fb8b229365..5a2fc9b52cb9e9f42ffec32a099691a8ceae656c 100644 (file)
@@ -89,6 +89,65 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
         self.assertEqual(len(results), 0)
         self.assertCountEqual(response.data["all"], [])
 
+    def test_search_custom_field_ordering(self):
+        custom_field = CustomField.objects.create(
+            name="Sortable field",
+            data_type=CustomField.FieldDataType.INT,
+        )
+        d1 = Document.objects.create(
+            title="first",
+            content="match",
+            checksum="A1",
+        )
+        d2 = Document.objects.create(
+            title="second",
+            content="match",
+            checksum="B2",
+        )
+        d3 = Document.objects.create(
+            title="third",
+            content="match",
+            checksum="C3",
+        )
+        CustomFieldInstance.objects.create(
+            document=d1,
+            field=custom_field,
+            value_int=30,
+        )
+        CustomFieldInstance.objects.create(
+            document=d2,
+            field=custom_field,
+            value_int=10,
+        )
+        CustomFieldInstance.objects.create(
+            document=d3,
+            field=custom_field,
+            value_int=20,
+        )
+
+        with AsyncWriter(index.open_index()) as writer:
+            index.update_document(writer, d1)
+            index.update_document(writer, d2)
+            index.update_document(writer, d3)
+
+        response = self.client.get(
+            f"/api/documents/?query=match&ordering=custom_field_{custom_field.pk}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(
+            [doc["id"] for doc in response.data["results"]],
+            [d2.id, d3.id, d1.id],
+        )
+
+        response = self.client.get(
+            f"/api/documents/?query=match&ordering=-custom_field_{custom_field.pk}",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(
+            [doc["id"] for doc in response.data["results"]],
+            [d1.id, d3.id, d2.id],
+        )
+
     def test_search_multi_page(self):
         with AsyncWriter(index.open_index()) as writer:
             for i in range(55):