]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Changes from a hash based system to a time based system to prevent extra retrains
authorTrenton Holmes <797416+stumpylog@users.noreply.github.com>
Thu, 23 Feb 2023 04:03:23 +0000 (20:03 -0800)
committerTrenton H <797416+stumpylog@users.noreply.github.com>
Tue, 28 Feb 2023 16:13:10 +0000 (08:13 -0800)
src/documents/classifier.py
src/documents/tests/data/model.pickle [deleted file]
src/documents/tests/test_classifier.py

index 66958087a1d4d467976d59e5bd653fdf57163e0b..9a57281240b59f35dd59b1881baa400cca015124 100644 (file)
@@ -1,10 +1,10 @@
-import hashlib
 import logging
 import os
 import pickle
 import re
 import shutil
 import warnings
+from datetime import datetime
 from typing import Iterator
 from typing import List
 from typing import Optional
@@ -62,12 +62,13 @@ class DocumentClassifier:
 
     # v7 - Updated scikit-learn package version
     # v8 - Added storage path classifier
-    FORMAT_VERSION = 8
+    # v9 - Changed from hash to time for training data check
+    FORMAT_VERSION = 9
 
     def __init__(self):
-        # hash of the training data. used to prevent re-training when the
+        # last time training data was calculated. used to prevent re-training when the
         # training data has not changed.
-        self.data_hash: Optional[bytes] = None
+        self.last_data_change: Optional[datetime] = None
 
         self.data_vectorizer = None
         self.tags_binarizer = None
@@ -91,7 +92,7 @@ class DocumentClassifier:
                     )
                 else:
                     try:
-                        self.data_hash = pickle.load(f)
+                        self.last_data_change = pickle.load(f)
                         self.data_vectorizer = pickle.load(f)
                         self.tags_binarizer = pickle.load(f)
 
@@ -121,7 +122,7 @@ class DocumentClassifier:
 
         with open(target_file_temp, "wb") as f:
             pickle.dump(self.FORMAT_VERSION, f)
-            pickle.dump(self.data_hash, f)
+            pickle.dump(self.last_data_change, f)
             pickle.dump(self.data_vectorizer, f)
 
             pickle.dump(self.tags_binarizer, f)
@@ -137,35 +138,40 @@ class DocumentClassifier:
 
     def train(self):
 
+        # Get non-inbox documents
+        docs_queryset = Document.objects.exclude(tags__is_inbox_tag=True)
+
+        # No documents exit to train against
+        if docs_queryset.count() == 0:
+            raise ValueError("No training data available.")
+
+        # No documents have changed since classifier was trained
+        latest_doc_change = docs_queryset.latest("modified").modified
+        if (
+            self.last_data_change is not None
+            and self.last_data_change >= latest_doc_change
+        ):
+            return False
+
         labels_tags = []
         labels_correspondent = []
         labels_document_type = []
         labels_storage_path = []
 
-        docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True)
-
-        if docs_queryset.count() == 0:
-            raise ValueError("No training data available.")
-
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
-        m = hashlib.sha1()
         for doc in docs_queryset:
-            preprocessed_content = self.preprocess_content(doc.content)
-            m.update(preprocessed_content.encode("utf-8"))
 
             y = -1
             dt = doc.document_type
             if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
                 y = dt.pk
-            m.update(y.to_bytes(4, "little", signed=True))
             labels_document_type.append(y)
 
             y = -1
             cor = doc.correspondent
             if cor and cor.matching_algorithm == MatchingModel.MATCH_AUTO:
                 y = cor.pk
-            m.update(y.to_bytes(4, "little", signed=True))
             labels_correspondent.append(y)
 
             tags = sorted(
@@ -174,22 +180,14 @@ class DocumentClassifier:
                     matching_algorithm=MatchingModel.MATCH_AUTO,
                 )
             )
-            for tag in tags:
-                m.update(tag.to_bytes(4, "little", signed=True))
             labels_tags.append(tags)
 
             y = -1
             sd = doc.storage_path
             if sd and sd.matching_algorithm == MatchingModel.MATCH_AUTO:
                 y = sd.pk
-            m.update(y.to_bytes(4, "little", signed=True))
             labels_storage_path.append(y)
 
-        new_data_hash = m.digest()
-
-        if self.data_hash and new_data_hash == self.data_hash:
-            return False
-
         labels_tags_unique = {tag for tags in labels_tags for tag in tags}
 
         num_tags = len(labels_tags_unique)
@@ -216,12 +214,16 @@ class DocumentClassifier:
 
         from sklearn.feature_extraction.text import CountVectorizer
         from sklearn.neural_network import MLPClassifier
-        from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
+        from sklearn.preprocessing import LabelBinarizer
+        from sklearn.preprocessing import MultiLabelBinarizer
 
         # Step 2: vectorize data
         logger.debug("Vectorizing data...")
 
         def content_generator() -> Iterator[str]:
+            """
+            Generates the content for documents, but once at a time
+            """
             for doc in docs_queryset:
                 yield self.preprocess_content(doc.content)
 
@@ -299,7 +301,7 @@ class DocumentClassifier:
                 "There are no storage paths. Not training storage path classifier.",
             )
 
-        self.data_hash = new_data_hash
+        self.last_data_change = latest_doc_change
 
         return True
 
diff --git a/src/documents/tests/data/model.pickle b/src/documents/tests/data/model.pickle
deleted file mode 100644 (file)
index ff88b88..0000000
Binary files a/src/documents/tests/data/model.pickle and /dev/null differ
index 057f67204dec730bf68f377c062b542832692a0e..7653fd861a759522a8d077f5f9df5f705a8a53cf 100644 (file)
@@ -1,7 +1,5 @@
 import os
 import re
-import shutil
-import tempfile
 from pathlib import Path
 from unittest import mock
 
@@ -22,15 +20,15 @@ from documents.tests.utils import DirectoriesMixin
 
 
 def dummy_preprocess(content: str):
+    """
+    Simpler, faster pre-processing for testing purposes
+    """
     content = content.lower().strip()
     content = re.sub(r"\s+", " ", content)
     return content
 
 
 class TestClassifier(DirectoriesMixin, TestCase):
-
-    SAMPLE_MODEL_FILE = os.path.join(os.path.dirname(__file__), "data", "model.pickle")
-
     def setUp(self):
         super().setUp()
         self.classifier = DocumentClassifier()
@@ -111,17 +109,68 @@ class TestClassifier(DirectoriesMixin, TestCase):
 
         self.doc1.storage_path = self.sp1
 
-    def testNoTrainingData(self):
-        try:
+    def generate_train_and_save(self):
+        """
+        Generates the training data, trains and saves the updated pickle
+        file. This ensures the test is using the same scikit learn version
+        and eliminates a warning from the test suite
+        """
+        self.generate_test_data()
+        self.classifier.train()
+        self.classifier.save()
+
+    def test_no_training_data(self):
+        """
+        GIVEN:
+            - No documents exist to train
+        WHEN:
+            - Classifier training is requested
+        THEN:
+            - Exception is raised
+        """
+        with self.assertRaisesMessage(ValueError, "No training data available."):
+            self.classifier.train()
+
+    def test_no_non_inbox_tags(self):
+        """
+        GIVEN:
+            - No documents without an inbox tag exist
+        WHEN:
+            - Classifier training is requested
+        THEN:
+            - Exception is raised
+        """
+
+        t1 = Tag.objects.create(
+            name="t1",
+            matching_algorithm=Tag.MATCH_ANY,
+            pk=34,
+            is_inbox_tag=True,
+        )
+
+        doc1 = Document.objects.create(
+            title="doc1",
+            content="this is a document from c1",
+            checksum="A",
+        )
+        doc1.tags.add(t1)
+
+        with self.assertRaisesMessage(ValueError, "No training data available."):
             self.classifier.train()
-        except ValueError as e:
-            self.assertEqual(str(e), "No training data available.")
-        else:
-            self.fail("Should raise exception")
 
     def testEmpty(self):
+        """
+        GIVEN:
+            - A document exists
+            - No tags/not enough data to predict
+        WHEN:
+            - Classifier prediction is requested
+        THEN:
+            - Classifier returns no predictions
+        """
         Document.objects.create(title="WOW", checksum="3457", content="ASD")
         self.classifier.train()
+
         self.assertIsNone(self.classifier.document_type_classifier)
         self.assertIsNone(self.classifier.tags_classifier)
         self.assertIsNone(self.classifier.correspondent_classifier)
@@ -131,8 +180,18 @@ class TestClassifier(DirectoriesMixin, TestCase):
         self.assertIsNone(self.classifier.predict_correspondent(""))
 
     def testTrain(self):
+        """
+        GIVEN:
+            - Test data
+        WHEN:
+            - Classifier is trained
+        THEN:
+            - Classifier uses correct values for correspondent learning
+            - Classifier uses correct values for tags learning
+        """
         self.generate_test_data()
         self.classifier.train()
+
         self.assertListEqual(
             list(self.classifier.correspondent_classifier.classes_),
             [-1, self.c1.pk],
@@ -143,8 +202,17 @@ class TestClassifier(DirectoriesMixin, TestCase):
         )
 
     def testPredict(self):
+        """
+        GIVEN:
+            - Classifier trained against test data
+        WHEN:
+            - Prediction requested for correspondent, tags, type
+        THEN:
+            - Expected predictions based on training set
+        """
         self.generate_test_data()
         self.classifier.train()
+
         self.assertEqual(
             self.classifier.predict_correspondent(self.doc1.content),
             self.c1.pk,
@@ -164,20 +232,51 @@ class TestClassifier(DirectoriesMixin, TestCase):
         )
         self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None)
 
-    def testDatasetHashing(self):
+    def test_no_retrain_if_no_change(self):
+        """
+        GIVEN:
+            - Classifier trained with current data
+        WHEN:
+            - Classifier training is requested again
+        THEN:
+            - Classifier does not redo training
+        """
 
         self.generate_test_data()
 
         self.assertTrue(self.classifier.train())
         self.assertFalse(self.classifier.train())
 
-    def testVersionIncreased(self):
+    def test_retrain_if_change(self):
+        """
+        GIVEN:
+            - Classifier trained with current data
+        WHEN:
+            - Classifier training is requested again
+            - Documents have changed
+        THEN:
+            - Classifier does not redo training
+        """
 
         self.generate_test_data()
+
         self.assertTrue(self.classifier.train())
-        self.assertFalse(self.classifier.train())
 
-        self.classifier.save()
+        self.doc1.correspondent = self.c2
+        self.doc1.save()
+
+        self.assertTrue(self.classifier.train())
+
+    def testVersionIncreased(self):
+        """
+        GIVEN:
+            - Existing classifier model saved at a version
+        WHEN:
+            - Attempt to load classifier file from newer version
+        THEN:
+            - Exception is raised
+        """
+        self.generate_train_and_save()
 
         classifier2 = DocumentClassifier()
 
@@ -194,14 +293,9 @@ class TestClassifier(DirectoriesMixin, TestCase):
             # assure that we can load the classifier after saving it.
             classifier2.load()
 
-    @override_settings(DATA_DIR=tempfile.mkdtemp())
     def testSaveClassifier(self):
 
-        self.generate_test_data()
-
-        self.classifier.train()
-
-        self.classifier.save()
+        self.generate_train_and_save()
 
         new_classifier = DocumentClassifier()
         new_classifier.load()
@@ -209,25 +303,9 @@ class TestClassifier(DirectoriesMixin, TestCase):
 
         self.assertFalse(new_classifier.train())
 
-    # @override_settings(
-    #     MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
-    # )
-    # def test_create_test_load_and_classify(self):
-    #     self.generate_test_data()
-    #     self.classifier.train()
-    #     self.classifier.save()
-
     def test_load_and_classify(self):
-        # Generate test data, train and save to the model file
-        # This ensures the model file sklearn version matches
-        # and eliminates a warning
-        shutil.copy(
-            self.SAMPLE_MODEL_FILE,
-            os.path.join(self.dirs.data_dir, "classification_model.pickle"),
-        )
-        self.generate_test_data()
-        self.classifier.train()
-        self.classifier.save()
+
+        self.generate_train_and_save()
 
         new_classifier = DocumentClassifier()
         new_classifier.load()
@@ -245,11 +323,9 @@ class TestClassifier(DirectoriesMixin, TestCase):
         THEN:
             - The ClassifierModelCorruptError is raised
         """
-        shutil.copy(
-            self.SAMPLE_MODEL_FILE,
-            os.path.join(self.dirs.data_dir, "classification_model.pickle"),
-        )
-        # First load is the schema version
+        self.generate_train_and_save()
+
+        # First load is the schema version,allow it
         patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
 
         with self.assertRaises(ClassifierModelCorruptError):