import re
from django.conf import settings
+from django.core.cache import cache
from documents.models import Document, MatchingModel
)
return None
- try:
+ version = os.stat(settings.MODEL_FILE).st_mtime
+
+ classifier = cache.get("paperless-classifier", version=version)
+
+ if not classifier:
classifier = DocumentClassifier()
- classifier.reload()
- except (EOFError, IncompatibleClassifierVersionError) as e:
- # there's something wrong with the model file.
- logger.error(
- f"Unrecoverable error while loading document "
- f"classification model: {str(e)}, deleting model file."
- )
- os.unlink(settings.MODEL_FILE)
- classifier = None
- except OSError as e:
- logger.error(
- f"Error while loading document classification model: {str(e)}"
- )
- classifier = None
+ try:
+ classifier.load()
+ cache.set("paperless-classifier", classifier, version=version)
+ except (EOFError, IncompatibleClassifierVersionError) as e:
+ # there's something wrong with the model file.
+ logger.error(
+ f"Unrecoverable error while loading document "
+ f"classification model: {str(e)}, deleting model file."
+ )
+ os.unlink(settings.MODEL_FILE)
+ classifier = None
+ except OSError as e:
+ logger.error(
+ f"Error while loading document classification model: {str(e)}"
+ )
+ classifier = None
return classifier
FORMAT_VERSION = 6
def __init__(self):
- # mtime of the model file on disk. used to prevent reloading when
- # nothing has changed.
- self.classifier_version = 0
-
# hash of the training data. used to prevent re-training when the
# training data has not changed.
self.data_hash = None
self.correspondent_classifier = None
self.document_type_classifier = None
- def reload(self):
- if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version:
- with open(settings.MODEL_FILE, "rb") as f:
- schema_version = pickle.load(f)
-
- if schema_version != self.FORMAT_VERSION:
- raise IncompatibleClassifierVersionError(
- "Cannor load classifier, incompatible versions.")
- else:
- if self.classifier_version > 0:
- # Don't be confused by this check. It's simply here
- # so that we wont log anything on initial reload.
- logger.info("Classifier updated on disk, "
- "reloading classifier models")
- self.data_hash = pickle.load(f)
- self.data_vectorizer = pickle.load(f)
- self.tags_binarizer = pickle.load(f)
-
- self.tags_classifier = pickle.load(f)
- self.correspondent_classifier = pickle.load(f)
- self.document_type_classifier = pickle.load(f)
- self.classifier_version = os.path.getmtime(settings.MODEL_FILE)
-
- def save_classifier(self):
+    def load(self):
+        """Load the classifier state from ``settings.MODEL_FILE``.
+
+        The file is read as a sequence of pickled objects: the schema
+        version first, then the training data hash, the data vectorizer,
+        the tags binarizer and the tags, correspondent and document type
+        classifiers.
+
+        Raises:
+            IncompatibleClassifierVersionError: the schema version stored
+                in the file differs from ``FORMAT_VERSION``.
+            OSError: the model file could not be opened or read.
+            EOFError: the file ended before every object was read.
+        """
+        # NOTE: unpickling can execute arbitrary code, so the model file
+        # must only ever come from a trusted location.
+        with open(settings.MODEL_FILE, "rb") as f:
+            schema_version = pickle.load(f)
+
+            # Refuse to read any further objects from a file written with
+            # a different schema; the attributes below stay untouched.
+            if schema_version != self.FORMAT_VERSION:
+                raise IncompatibleClassifierVersionError(
+                    "Cannot load classifier, incompatible versions.")
+
+            self.data_hash = pickle.load(f)
+            self.data_vectorizer = pickle.load(f)
+            self.tags_binarizer = pickle.load(f)
+
+            self.tags_classifier = pickle.load(f)
+            self.correspondent_classifier = pickle.load(f)
+            self.document_type_classifier = pickle.load(f)
+
+ def save(self):
with open(settings.MODEL_FILE, "wb") as f:
pickle.dump(self.FORMAT_VERSION, f)
pickle.dump(self.data_hash, f)
import os
import tempfile
from pathlib import Path
-from time import sleep
from unittest import mock
from django.conf import settings
self.assertTrue(self.classifier.train())
self.assertFalse(self.classifier.train())
- self.classifier.save_classifier()
+ self.classifier.save()
classifier2 = DocumentClassifier()
current_ver = DocumentClassifier.FORMAT_VERSION
with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
# assure that we won't load old classifiers.
- self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload)
+ self.assertRaises(IncompatibleClassifierVersionError, classifier2.load)
- self.classifier.save_classifier()
+ self.classifier.save()
# assure that we can load the classifier after saving it.
- classifier2.reload()
-
- def testReload(self):
-
- self.generate_test_data()
- self.assertTrue(self.classifier.train())
- self.classifier.save_classifier()
-
- classifier2 = DocumentClassifier()
- classifier2.reload()
- v1 = classifier2.classifier_version
-
- # change the classifier after some time.
- sleep(1)
- self.classifier.save_classifier()
-
- classifier2.reload()
- v2 = classifier2.classifier_version
- self.assertNotEqual(v1, v2)
+ classifier2.load()
@override_settings(DATA_DIR=tempfile.mkdtemp())
def testSaveClassifier(self):
+ # Round trip: train and save the fixture classifier, load the saved
+ # model into a fresh DocumentClassifier, and check that train() on the
+ # loaded instance returns False.
self.classifier.train()
- self.classifier.save_classifier()
+ self.classifier.save()
new_classifier = DocumentClassifier()
- new_classifier.reload()
+ new_classifier.load()
self.assertFalse(new_classifier.train())
@override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
self.generate_test_data()
new_classifier = DocumentClassifier()
- new_classifier.reload()
+ new_classifier.load()
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
self.assertFalse(os.path.exists(settings.MODEL_FILE))
self.assertIsNone(load_classifier())
- @mock.patch("documents.classifier.DocumentClassifier.reload")
- def test_load_classifier(self, reload):
+ @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
+ @mock.patch("documents.classifier.DocumentClassifier.load")
+ def test_load_classifier(self, load):
+ # Runs with Django's dummy cache backend and DocumentClassifier.load
+ # mocked out: with a model file present, load_classifier() must return
+ # a classifier and must have called load() exactly once.
Path(settings.MODEL_FILE).touch()
self.assertIsNotNone(load_classifier())
+ load.assert_called_once()
+
+    @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.locmem.LocMemCache'}})
+    @override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
+    def test_load_classifier_cached(self):
+        # The first call is expected to read the bundled model file and
+        # leave the loaded classifier in the local-memory cache.
+        classifier = load_classifier()
+        self.assertIsNotNone(classifier)
+
+        # The second call must be served without touching the model file
+        # again, i.e. without calling DocumentClassifier.load.
+        with mock.patch("documents.classifier.DocumentClassifier.load") as load:
+            classifier2 = load_classifier()
+            load.assert_not_called()
+
+        # Skipping load() is only useful if a classifier still comes back.
+        self.assertIsNotNone(classifier2)
- @mock.patch("documents.classifier.DocumentClassifier.reload")
- def test_load_classifier_incompatible_version(self, reload):
+ @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
+ @mock.patch("documents.classifier.DocumentClassifier.load")
+ def test_load_classifier_incompatible_version(self, load):
+ # load() is mocked to raise IncompatibleClassifierVersionError:
+ # load_classifier() must return None and the model file must be gone
+ # afterwards.
Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE))
- reload.side_effect = IncompatibleClassifierVersionError()
+ load.side_effect = IncompatibleClassifierVersionError()
self.assertIsNone(load_classifier())
self.assertFalse(os.path.exists(settings.MODEL_FILE))
- @mock.patch("documents.classifier.DocumentClassifier.reload")
- def test_load_classifier_os_error(self, reload):
+ @override_settings(CACHES={'default': {'BACKEND': 'django.core.cache.backends.dummy.DummyCache'}})
+ @mock.patch("documents.classifier.DocumentClassifier.load")
+ def test_load_classifier_os_error(self, load):
+ # load() is mocked to raise OSError: load_classifier() must return
+ # None, and unlike the incompatible-version case the model file must
+ # still exist afterwards.
Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE))
- reload.side_effect = OSError()
+ load.side_effect = OSError()
self.assertIsNone(load_classifier())
self.assertTrue(os.path.exists(settings.MODEL_FILE))