import pickle
import re
import shutil
+import warnings
+from typing import Optional
from django.conf import settings
from documents.models import Document
logger = logging.getLogger("paperless.classifier")
def preprocess_content(content: str) -> str:
    """
    Normalize document text before it is handed to the classifier.

    The text is lower-cased, leading and trailing whitespace is removed, and
    every remaining run of whitespace (spaces, tabs, newlines, ...) is
    collapsed into a single space.

    :param content: raw text to normalize
    :return: the normalized text
    """
    content = content.lower().strip()
    # Collapse interior whitespace runs so layout differences do not matter.
    return re.sub(r"\s+", " ", content)
-def load_classifier():
+def load_classifier() -> Optional["DocumentClassifier"]:
if not os.path.isfile(settings.MODEL_FILE):
logger.debug(
"Document classification model does not exist (yet), not "
try:
classifier.load()
- except (ClassifierModelCorruptError, IncompatibleClassifierVersionError):
+ except IncompatibleClassifierVersionError:
+ logger.info("Classifier version updated, will re-train")
+ os.unlink(settings.MODEL_FILE)
+ classifier = None
+ except ClassifierModelCorruptError:
# there's something wrong with the model file.
logger.exception(
"Unrecoverable error while loading document "
class DocumentClassifier:
+ # v7 - Updated scikit-learn package version
# v8 - Added storage path classifier
FORMAT_VERSION = 8
def __init__(self):
# hash of the training data. used to prevent re-training when the
# training data has not changed.
- self.data_hash = None
+ self.data_hash: Optional[bytes] = None
self.data_vectorizer = None
self.tags_binarizer = None
self.storage_path_classifier = None
def load(self):
    """
    Restore the classifier state from the pickle file at settings.MODEL_FILE.

    The file is read as a sequence of pickled objects: first the schema
    version, then the data hash, the data vectorizer, the tags binarizer and
    the tags, correspondent, document type and storage path classifiers.

    Warnings emitted while unpickling are recorded rather than displayed, and
    are inspected after the file has been read.

    :raises IncompatibleClassifierVersionError: if the stored schema version
        differs from FORMAT_VERSION, or if a UserWarning mentioning the
        scikit-learn model persistence URL was emitted while unpickling
    :raises ClassifierModelCorruptError: if reading any object after the
        schema version fails; the original error is attached as the cause
    """
    sk_learn_warning_url = (
        "https://scikit-learn.org/stable/"
        "model_persistence.html"
        "#security-maintainability-limitations"
    )

    # Catch warnings for processing
    with warnings.catch_warnings(record=True) as caught_warnings:
        # Without this, an "ignore" filter configured elsewhere would keep
        # the version warning out of the recorded list.
        warnings.simplefilter("always")

        # NOTE(review): unpickling can execute arbitrary code; presumably
        # MODEL_FILE is only ever written by this application - confirm.
        with open(settings.MODEL_FILE, "rb") as f:
            schema_version = pickle.load(f)

            if schema_version != self.FORMAT_VERSION:
                raise IncompatibleClassifierVersionError(
                    "Cannot load classifier, incompatible versions.",
                )

            try:
                self.data_hash = pickle.load(f)
                self.data_vectorizer = pickle.load(f)
                self.tags_binarizer = pickle.load(f)

                self.tags_classifier = pickle.load(f)
                self.correspondent_classifier = pickle.load(f)
                self.document_type_classifier = pickle.load(f)
                self.storage_path_classifier = pickle.load(f)
            except Exception as err:
                # Keep the underlying failure visible in the traceback.
                raise ClassifierModelCorruptError() from err

    # Check for the warning about unpickling from differing versions
    # and consider it incompatible
    for warning in caught_warnings:
        if issubclass(
            warning.category,
            UserWarning,
        ) and sk_learn_warning_url in str(warning.message):
            raise IncompatibleClassifierVersionError()
def save(self):
target_file = settings.MODEL_FILE
from pathlib import Path
from unittest import mock
+import documents
import pytest
from django.conf import settings
from django.test import override_settings
from django.test import TestCase
+from documents.classifier import ClassifierModelCorruptError
from documents.classifier import DocumentClassifier
from documents.classifier import IncompatibleClassifierVersionError
from documents.classifier import load_classifier
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
@override_settings(
    MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
)
@mock.patch("documents.classifier.pickle.load")
def test_load_corrupt_file(self, patched_pickle_load):
    """
    GIVEN:
        - A classifier pickle file whose contents cannot be read past
          the schema version
    WHEN:
        - The classifier is loaded
    THEN:
        - ClassifierModelCorruptError is raised
    """
    # The first pickle.load call yields the schema version; the second
    # one fails, standing in for a corrupted file.
    schema_version = DocumentClassifier.FORMAT_VERSION
    patched_pickle_load.side_effect = [schema_version, OSError()]

    self.assertRaises(ClassifierModelCorruptError, self.classifier.load)
+
@override_settings(
    MODEL_FILE=os.path.join(
        os.path.dirname(__file__),
        "data",
        "v1.0.2.model.pickle",
    ),
)
def test_load_new_scikit_learn_version(self):
    """
    GIVEN:
        - classifier pickle file created with a different scikit-learn version
    WHEN:
        - The classifier is loaded
    THEN:
        - IncompatibleClassifierVersionError is raised
    """
    self.assertRaises(IncompatibleClassifierVersionError, self.classifier.load)
+
def test_one_correspondent_predict(self):
c1 = Correspondent.objects.create(
name="c1",