]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
classifier caching
authorjonaswinkler <jonas.winkler@jpwinkler.de>
Sat, 6 Feb 2021 19:54:58 +0000 (20:54 +0100)
committerjonaswinkler <jonas.winkler@jpwinkler.de>
Sat, 6 Feb 2021 19:54:58 +0000 (20:54 +0100)
src/documents/classifier.py
src/documents/tasks.py
src/documents/tests/test_classifier.py
src/documents/tests/test_tasks.py

index 2acebe04a91086811446888c2a444437ba526d70..5151d453c225d8edfbb2271c16cfbe02d6abca23 100755 (executable)
@@ -5,6 +5,7 @@ import pickle
 import re
 
 from django.conf import settings
+from django.core.cache import cache
 
 from documents.models import Document, MatchingModel
 
@@ -30,22 +31,28 @@ def load_classifier():
         )
         return None
 
-    try:
+    version = os.stat(settings.MODEL_FILE).st_mtime
+
+    classifier = cache.get("paperless-classifier", version=version)
+
+    if not classifier:
         classifier = DocumentClassifier()
-        classifier.reload()
-    except (EOFError, IncompatibleClassifierVersionError) as e:
-        # there's something wrong with the model file.
-        logger.error(
-            f"Unrecoverable error while loading document "
-            f"classification model: {str(e)}, deleting model file."
-        )
-        os.unlink(settings.MODEL_FILE)
-        classifier = None
-    except OSError as e:
-        logger.error(
-            f"Error while loading document classification model: {str(e)}"
-        )
-        classifier = None
+        try:
+            classifier.load()
+            cache.set("paperless-classifier", classifier, version=version)
+        except (EOFError, IncompatibleClassifierVersionError) as e:
+            # there's something wrong with the model file.
+            logger.error(
+                f"Unrecoverable error while loading document "
+                f"classification model: {str(e)}, deleting model file."
+            )
+            os.unlink(settings.MODEL_FILE)
+            classifier = None
+        except OSError as e:
+            logger.error(
+                f"Error while loading document classification model: {str(e)}"
+            )
+            classifier = None
 
     return classifier
 
@@ -55,10 +62,6 @@ class DocumentClassifier(object):
     FORMAT_VERSION = 6
 
     def __init__(self):
-        # mtime of the model file on disk. used to prevent reloading when
-        # nothing has changed.
-        self.classifier_version = 0
-
         # hash of the training data. used to prevent re-training when the
         # training data has not changed.
         self.data_hash = None
@@ -69,30 +72,23 @@ class DocumentClassifier(object):
         self.correspondent_classifier = None
         self.document_type_classifier = None
 
-    def reload(self):
-        if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version:
-            with open(settings.MODEL_FILE, "rb") as f:
-                schema_version = pickle.load(f)
-
-                if schema_version != self.FORMAT_VERSION:
-                    raise IncompatibleClassifierVersionError(
-                        "Cannor load classifier, incompatible versions.")
-                else:
-                    if self.classifier_version > 0:
-                        # Don't be confused by this check. It's simply here
-                        # so that we wont log anything on initial reload.
-                        logger.info("Classifier updated on disk, "
-                                    "reloading classifier models")
-                    self.data_hash = pickle.load(f)
-                    self.data_vectorizer = pickle.load(f)
-                    self.tags_binarizer = pickle.load(f)
-
-                    self.tags_classifier = pickle.load(f)
-                    self.correspondent_classifier = pickle.load(f)
-                    self.document_type_classifier = pickle.load(f)
-            self.classifier_version = os.path.getmtime(settings.MODEL_FILE)
-
-    def save_classifier(self):
+    def load(self):
+        with open(settings.MODEL_FILE, "rb") as f:
+            schema_version = pickle.load(f)
+
+            if schema_version != self.FORMAT_VERSION:
+                raise IncompatibleClassifierVersionError(
+                    "Cannor load classifier, incompatible versions.")
+            else:
+                self.data_hash = pickle.load(f)
+                self.data_vectorizer = pickle.load(f)
+                self.tags_binarizer = pickle.load(f)
+
+                self.tags_classifier = pickle.load(f)
+                self.correspondent_classifier = pickle.load(f)
+                self.document_type_classifier = pickle.load(f)
+
+    def save(self):
         with open(settings.MODEL_FILE, "wb") as f:
             pickle.dump(self.FORMAT_VERSION, f)
             pickle.dump(self.data_hash, f)
index 8c7d91585f25cfc9a312276aa611c6b926fe4925..f74a3d4208167b91e2e50767c7b81619619f66cb 100644 (file)
@@ -52,7 +52,7 @@ def train_classifier():
                 "Saving updated classifier model to {}...".format(
                     settings.MODEL_FILE)
             )
-            classifier.save_classifier()
+            classifier.save()
         else:
             logger.debug(
                 "Training data unchanged."
index 14673ae655610d66332b41011fd4147f857b9ea4..1efe564d1c1859486c1b6990c943debb07d6b552 100644 (file)
@@ -1,7 +1,6 @@
 import os
 import tempfile
 from pathlib import Path
-from time import sleep
 from unittest import mock
 
 from django.conf import settings
@@ -85,37 +84,19 @@ class TestClassifier(DirectoriesMixin, TestCase):
         self.assertTrue(self.classifier.train())
         self.assertFalse(self.classifier.train())
 
-        self.classifier.save_classifier()
+        self.classifier.save()
 
         classifier2 = DocumentClassifier()
 
         current_ver = DocumentClassifier.FORMAT_VERSION
         with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
             # assure that we won't load old classifiers.
-            self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload)
+            self.assertRaises(IncompatibleClassifierVersionError, classifier2.load)
 
-            self.classifier.save_classifier()
+            self.classifier.save()
 
             # assure that we can load the classifier after saving it.
-            classifier2.reload()
-
-    def testReload(self):
-
-        self.generate_test_data()
-        self.assertTrue(self.classifier.train())
-        self.classifier.save_classifier()
-
-        classifier2 = DocumentClassifier()
-        classifier2.reload()
-        v1 = classifier2.classifier_version
-
-        # change the classifier after some time.
-        sleep(1)
-        self.classifier.save_classifier()
-
-        classifier2.reload()
-        v2 = classifier2.classifier_version
-        self.assertNotEqual(v1, v2)
+            classifier2.load()
 
     @override_settings(DATA_DIR=tempfile.mkdtemp())
     def testSaveClassifier(self):
@@ -124,10 +105,10 @@ class TestClassifier(DirectoriesMixin, TestCase):
 
         self.classifier.train()
 
-        self.classifier.save_classifier()
+        self.classifier.save()
 
         new_classifier = DocumentClassifier()
-        new_classifier.reload()
+        new_classifier.load()
         self.assertFalse(new_classifier.train())
 
     @override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
@@ -135,7 +116,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
         self.generate_test_data()
 
         new_classifier = DocumentClassifier()
-        new_classifier.reload()
+        new_classifier.load()
 
         self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
 
@@ -252,25 +233,39 @@ class TestClassifier(DirectoriesMixin, TestCase):
         self.assertFalse(os.path.exists(settings.MODEL_FILE))
         self.assertIsNone(load_classifier())
 
-    @mock.patch("documents.classifier.DocumentClassifier.reload")
-    def test_load_classifier(self, reload):
+    @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
+    @mock.patch("documents.classifier.DocumentClassifier.load")
+    def test_load_classifier(self, load):
         Path(settings.MODEL_FILE).touch()
         self.assertIsNotNone(load_classifier())
+        load.assert_called_once()
+
+    @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.locmem.LocMemCache'}})
+    @override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
+    def test_load_classifier_cached(self):
+        classifier = load_classifier()
+        self.assertIsNotNone(classifier)
+
+        with mock.patch("documents.classifier.DocumentClassifier.load") as load:
+            classifier2 = load_classifier()
+            load.assert_not_called()
 
-    @mock.patch("documents.classifier.DocumentClassifier.reload")
-    def test_load_classifier_incompatible_version(self, reload):
+    @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
+    @mock.patch("documents.classifier.DocumentClassifier.load")
+    def test_load_classifier_incompatible_version(self, load):
         Path(settings.MODEL_FILE).touch()
         self.assertTrue(os.path.exists(settings.MODEL_FILE))
 
-        reload.side_effect = IncompatibleClassifierVersionError()
+        load.side_effect = IncompatibleClassifierVersionError()
         self.assertIsNone(load_classifier())
         self.assertFalse(os.path.exists(settings.MODEL_FILE))
 
-    @mock.patch("documents.classifier.DocumentClassifier.reload")
-    def test_load_classifier_os_error(self, reload):
+    @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
+    @mock.patch("documents.classifier.DocumentClassifier.load")
+    def test_load_classifier_os_error(self, load):
         Path(settings.MODEL_FILE).touch()
         self.assertTrue(os.path.exists(settings.MODEL_FILE))
 
-        reload.side_effect = OSError()
+        load.side_effect = OSError()
         self.assertIsNone(load_classifier())
         self.assertTrue(os.path.exists(settings.MODEL_FILE))
index d008f995a2dcbd5fca9f16c0c4e2298c29ba7f0c..ed280441fb9b837d1dc4e502c9c969a8aabdf3f0 100644 (file)
@@ -2,7 +2,7 @@ import os
 from unittest import mock
 
 from django.conf import settings
-from django.test import TestCase
+from django.test import TestCase, override_settings
 from django.utils import timezone
 
 from documents import tasks
@@ -52,6 +52,7 @@ class TestTasks(DirectoriesMixin, TestCase):
         load_classifier.assert_called_once()
         self.assertFalse(os.path.isfile(settings.MODEL_FILE))
 
+    @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
     def test_train_classifier(self):
         c = Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
         doc = Document.objects.create(correspondent=c, content="test", title="test")