]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Also handles confirming returned predictions are still automatic matching, in case...
authorTrenton H <797416+stumpylog@users.noreply.github.com>
Mon, 24 Jul 2023 17:19:47 +0000 (10:19 -0700)
committerTrenton H <797416+stumpylog@users.noreply.github.com>
Mon, 24 Jul 2023 19:31:56 +0000 (12:31 -0700)
src/documents/matching.py
src/documents/tests/test_consumer.py

index 521d492844d4273d94c38e684ea26537e2010ec8..a7ceb5a5ae06a709fdcc5271744dc538072d51e4 100644 (file)
@@ -35,7 +35,11 @@ def match_correspondents(document, classifier, user=None):
         correspondents = Correspondent.objects.all()
 
     return list(
-        filter(lambda o: matches(o, document) or o.pk == pred_id, correspondents),
+        filter(
+            lambda o: matches(o, document)
+            or (o.pk == pred_id and o.matching_algorithm == MatchingModel.MATCH_AUTO),
+            correspondents,
+        ),
     )
 
 
@@ -55,7 +59,11 @@ def match_document_types(document, classifier, user=None):
         document_types = DocumentType.objects.all()
 
     return list(
-        filter(lambda o: matches(o, document) or o.pk == pred_id, document_types),
+        filter(
+            lambda o: matches(o, document)
+            or (o.pk == pred_id and o.matching_algorithm == MatchingModel.MATCH_AUTO),
+            document_types,
+        ),
     )
 
 
@@ -71,7 +79,14 @@ def match_tags(document, classifier, user=None):
         tags = Tag.objects.all()
 
     return list(
-        filter(lambda o: matches(o, document) or o.pk in predicted_tag_ids, tags),
+        filter(
+            lambda o: matches(o, document)
+            or (
+                o.matching_algorithm == MatchingModel.MATCH_AUTO
+                and o.pk in predicted_tag_ids
+            ),
+            tags,
+        ),
     )
 
 
@@ -92,7 +107,8 @@ def match_storage_paths(document, classifier, user=None):
 
     return list(
         filter(
-            lambda o: matches(o, document) or o.pk == pred_id,
+            lambda o: matches(o, document)
+            or (o.pk == pred_id and o.matching_algorithm == MatchingModel.MATCH_AUTO),
             storage_paths,
         ),
     )
index 441cffddfafcd26429c6af35e75a78aa1693a850..13806635582430a07e5f832712e966a54bfb3063 100644 (file)
@@ -561,10 +561,16 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
     @mock.patch("documents.consumer.load_classifier")
     def testClassifyDocument(self, m):
-        correspondent = Correspondent.objects.create(name="test")
-        dtype = DocumentType.objects.create(name="test")
-        t1 = Tag.objects.create(name="t1")
-        t2 = Tag.objects.create(name="t2")
+        correspondent = Correspondent.objects.create(
+            name="test",
+            matching_algorithm=Correspondent.MATCH_AUTO,
+        )
+        dtype = DocumentType.objects.create(
+            name="test",
+            matching_algorithm=DocumentType.MATCH_AUTO,
+        )
+        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO)
+        t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO)
 
         m.return_value = MagicMock()
         m.return_value.predict_correspondent.return_value = correspondent.pk