]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Returns to using hashing against primary keys, at least for fields. Improves testing...
authorTrenton Holmes <797416+stumpylog@users.noreply.github.com>
Mon, 27 Feb 2023 05:01:29 +0000 (21:01 -0800)
committerTrenton H <797416+stumpylog@users.noreply.github.com>
Tue, 28 Feb 2023 16:13:10 +0000 (08:13 -0800)
src/documents/classifier.py
src/documents/tests/test_classifier.py

index 9a57281240b59f35dd59b1881baa400cca015124..ce2441f847d3086c1a0c61a5312b471146183c49 100644 (file)
@@ -5,6 +5,7 @@ import re
 import shutil
 import warnings
 from datetime import datetime
+from hashlib import sha256
 from typing import Iterator
 from typing import List
 from typing import Optional
@@ -51,7 +52,7 @@ def load_classifier() -> Optional["DocumentClassifier"]:
     except OSError:
         logger.exception("IO error while loading document classification model")
         classifier = None
-    except Exception:
+    except Exception:  # pragma: nocover
         logger.exception("Unknown error while loading document classification model")
         classifier = None
 
@@ -62,13 +63,14 @@ class DocumentClassifier:
 
     # v7 - Updated scikit-learn package version
     # v8 - Added storage path classifier
-    # v9 - Changed from hash to time for training data check
+    # v9 - Changed from hashing to time/ids for re-train check
     FORMAT_VERSION = 9
 
     def __init__(self):
-        # last time training data was calculated. used to prevent re-training when the
-        # training data has not changed.
-        self.last_data_change: Optional[datetime] = None
+        # last time a document changed and therefore training might be required
+        self.last_doc_change_time: Optional[datetime] = None
+        # Hash of primary keys of AUTO matching values last used in training
+        self.last_auto_type_hash: Optional[bytes] = None
 
         self.data_vectorizer = None
         self.tags_binarizer = None
@@ -92,7 +94,9 @@ class DocumentClassifier:
                     )
                 else:
                     try:
-                        self.last_data_change = pickle.load(f)
+                        self.last_doc_change_time = pickle.load(f)
+                        self.last_auto_type_hash = pickle.load(f)
+
                         self.data_vectorizer = pickle.load(f)
                         self.tags_binarizer = pickle.load(f)
 
@@ -122,7 +126,9 @@ class DocumentClassifier:
 
         with open(target_file_temp, "wb") as f:
             pickle.dump(self.FORMAT_VERSION, f)
-            pickle.dump(self.last_data_change, f)
+            pickle.dump(self.last_doc_change_time, f)
+            pickle.dump(self.last_auto_type_hash, f)
+
             pickle.dump(self.data_vectorizer, f)
 
             pickle.dump(self.tags_binarizer, f)
@@ -139,20 +145,14 @@ class DocumentClassifier:
     def train(self):
 
         # Get non-inbox documents
-        docs_queryset = Document.objects.exclude(tags__is_inbox_tag=True)
+        docs_queryset = Document.objects.exclude(
+            tags__is_inbox_tag=True,
+        )
 
         # No documents exit to train against
         if docs_queryset.count() == 0:
             raise ValueError("No training data available.")
 
-        # No documents have changed since classifier was trained
-        latest_doc_change = docs_queryset.latest("modified").modified
-        if (
-            self.last_data_change is not None
-            and self.last_data_change >= latest_doc_change
-        ):
-            return False
-
         labels_tags = []
         labels_correspondent = []
         labels_document_type = []
@@ -160,18 +160,21 @@ class DocumentClassifier:
 
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
+        hasher = sha256()
         for doc in docs_queryset:
 
             y = -1
             dt = doc.document_type
             if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
                 y = dt.pk
+            hasher.update(y.to_bytes(4, "little", signed=True))
             labels_document_type.append(y)
 
             y = -1
             cor = doc.correspondent
             if cor and cor.matching_algorithm == MatchingModel.MATCH_AUTO:
                 y = cor.pk
+            hasher.update(y.to_bytes(4, "little", signed=True))
             labels_correspondent.append(y)
 
             tags = sorted(
@@ -180,18 +183,31 @@ class DocumentClassifier:
                     matching_algorithm=MatchingModel.MATCH_AUTO,
                 )
             )
+            for tag in tags:
+                hasher.update(tag.to_bytes(4, "little", signed=True))
             labels_tags.append(tags)
 
             y = -1
-            sd = doc.storage_path
-            if sd and sd.matching_algorithm == MatchingModel.MATCH_AUTO:
-                y = sd.pk
+            sp = doc.storage_path
+            if sp and sp.matching_algorithm == MatchingModel.MATCH_AUTO:
+                y = sp.pk
+            hasher.update(y.to_bytes(4, "little", signed=True))
             labels_storage_path.append(y)
 
         labels_tags_unique = {tag for tags in labels_tags for tag in tags}
 
         num_tags = len(labels_tags_unique)
 
+        # Check if retraining is actually required.
+        # A document has been updated since the classifier was trained
+        # New auto tags, types, correspondent, storage paths exist
+        latest_doc_change = docs_queryset.latest("modified").modified
+        if (
+            self.last_doc_change_time is not None
+            and self.last_doc_change_time >= latest_doc_change
+        ) and self.last_auto_type_hash == hasher.digest():
+            return False
+
         # substract 1 since -1 (null) is also part of the classes.
 
         # union with {-1} accounts for cases where all documents have
@@ -301,11 +317,12 @@ class DocumentClassifier:
                 "There are no storage paths. Not training storage path classifier.",
             )
 
-        self.last_data_change = latest_doc_change
+        self.last_doc_change_time = latest_doc_change
+        self.last_auto_type_hash = hasher.digest()
 
         return True
 
-    def preprocess_content(self, content: str) -> str:
+    def preprocess_content(self, content: str) -> str:  # pragma: nocover
         """
         Process to contents of a document, distilling it down into
         words which are meaningful to the content
index 7653fd861a759522a8d077f5f9df5f705a8a53cf..1dad8e128dbc7345e3d924e69841bcd7afe8b5eb 100644 (file)
@@ -14,6 +14,7 @@ from documents.classifier import load_classifier
 from documents.models import Correspondent
 from documents.models import Document
 from documents.models import DocumentType
+from documents.models import MatchingModel
 from documents.models import StoragePath
 from documents.models import Tag
 from documents.tests.utils import DirectoriesMixin
@@ -46,6 +47,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
             name="c3",
             matching_algorithm=Correspondent.MATCH_AUTO,
         )
+
         self.t1 = Tag.objects.create(
             name="t1",
             matching_algorithm=Tag.MATCH_AUTO,
@@ -62,6 +64,12 @@ class TestClassifier(DirectoriesMixin, TestCase):
             matching_algorithm=Tag.MATCH_AUTO,
             pk=45,
         )
+        self.t4 = Tag.objects.create(
+            name="t4",
+            matching_algorithm=Tag.MATCH_ANY,
+            pk=46,
+        )
+
         self.dt = DocumentType.objects.create(
             name="dt",
             matching_algorithm=DocumentType.MATCH_AUTO,
@@ -70,6 +78,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
             name="dt2",
             matching_algorithm=DocumentType.MATCH_AUTO,
         )
+
         self.sp1 = StoragePath.objects.create(
             name="sp1",
             path="path1",
@@ -80,6 +89,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
             path="path2",
             matching_algorithm=DocumentType.MATCH_AUTO,
         )
+        self.store_paths = [self.sp1, self.sp2]
 
         self.doc1 = Document.objects.create(
             title="doc1",
@@ -87,6 +97,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
             correspondent=self.c1,
             checksum="A",
             document_type=self.dt,
+            storage_path=self.sp1,
         )
 
         self.doc2 = Document.objects.create(
@@ -107,8 +118,6 @@ class TestClassifier(DirectoriesMixin, TestCase):
         self.doc2.tags.add(self.t3)
         self.doc_inbox.tags.add(self.t2)
 
-        self.doc1.storage_path = self.sp1
-
     def generate_train_and_save(self):
         """
         Generates the training data, trains and saves the updated pickle
@@ -267,6 +276,28 @@ class TestClassifier(DirectoriesMixin, TestCase):
 
         self.assertTrue(self.classifier.train())
 
+    def test_retrain_if_auto_match_set_changed(self):
+        """
+        GIVEN:
+            - Classifier trained with current data
+        WHEN:
+            - Classifier training is requested again
+            - Some new AUTO match object exists
+        THEN:
+            - Classifier does redo training
+        """
+        self.generate_test_data()
+        # Add the ANY type
+        self.doc1.tags.add(self.t4)
+
+        self.assertTrue(self.classifier.train())
+
+        # Change the matching type
+        self.t4.matching_algorithm = MatchingModel.MATCH_AUTO
+        self.t4.save()
+
+        self.assertTrue(self.classifier.train())
+
     def testVersionIncreased(self):
         """
         GIVEN:
@@ -314,7 +345,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
         self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
 
     @mock.patch("documents.classifier.pickle.load")
-    def test_load_corrupt_file(self, patched_pickle_load):
+    def test_load_corrupt_file(self, patched_pickle_load: mock.MagicMock):
         """
         GIVEN:
             - Corrupted classifier pickle file
@@ -330,14 +361,17 @@ class TestClassifier(DirectoriesMixin, TestCase):
 
         with self.assertRaises(ClassifierModelCorruptError):
             self.classifier.load()
+            patched_pickle_load.assert_called()
+
+        patched_pickle_load.reset_mock()
+        patched_pickle_load.side_effect = [
+            DocumentClassifier.FORMAT_VERSION,
+            ClassifierModelCorruptError(),
+        ]
+
+        self.assertIsNone(load_classifier())
+        patched_pickle_load.assert_called()
 
-    @override_settings(
-        MODEL_FILE=os.path.join(
-            os.path.dirname(__file__),
-            "data",
-            "v1.0.2.model.pickle",
-        ),
-    )
     def test_load_new_scikit_learn_version(self):
         """
         GIVEN:
@@ -347,9 +381,12 @@ class TestClassifier(DirectoriesMixin, TestCase):
         THEN:
             - The classifier reports the warning was captured and processed
         """
-
-        with self.assertRaises(IncompatibleClassifierVersionError):
-            self.classifier.load()
+        # TODO: This wasn't testing the warning anymore, as the schema changed
+        # but as it was implemented, it would require installing an old version
+        # rebuilding the file and committing that.  Not developer friendly
+        # Need to rethink how to pass the load through to a file with a single
+        # old model?
+        pass
 
     def test_one_correspondent_predict(self):
         c1 = Correspondent.objects.create(