git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Updates the classifier to catch warnings from scikit-learn and rebuild the model...
author Trenton Holmes <holmes.trenton@gmail.com>
Thu, 2 Jun 2022 20:58:38 +0000 (13:58 -0700)
committer Johann Bauer <bauerj@bauerj.eu>
Tue, 5 Jul 2022 06:20:35 +0000 (08:20 +0200)
src/documents/classifier.py
src/documents/tests/data/model.pickle
src/documents/tests/data/v1.0.2.model.pickle [new file with mode: 0644]
src/documents/tests/test_classifier.py

index 4bae1830b4b90379208d250f0fac43c066f34bac..1c2ccea07903253ed844b3ceff0b66a3fabe0c10 100644 (file)
@@ -4,6 +4,8 @@ import os
 import pickle
 import re
 import shutil
+import warnings
+from typing import Optional
 
 from django.conf import settings
 from documents.models import Document
@@ -21,13 +23,13 @@ class ClassifierModelCorruptError(Exception):
 logger = logging.getLogger("paperless.classifier")
 
 
-def preprocess_content(content):
+def preprocess_content(content: str) -> str:
     content = content.lower().strip()
     content = re.sub(r"\s+", " ", content)
     return content
 
 
-def load_classifier():
+def load_classifier() -> Optional["DocumentClassifier"]:
     if not os.path.isfile(settings.MODEL_FILE):
         logger.debug(
             "Document classification model does not exist (yet), not "
@@ -39,7 +41,11 @@ def load_classifier():
     try:
         classifier.load()
 
-    except (ClassifierModelCorruptError, IncompatibleClassifierVersionError):
+    except IncompatibleClassifierVersionError:
+        logger.info("Classifier version updated, will re-train")
+        os.unlink(settings.MODEL_FILE)
+        classifier = None
+    except ClassifierModelCorruptError:
         # there's something wrong with the model file.
         logger.exception(
             "Unrecoverable error while loading document "
@@ -59,13 +65,14 @@ def load_classifier():
 
 class DocumentClassifier:
 
+    # v7 - Updated scikit-learn package version
     # v8 - Added storage path classifier
     FORMAT_VERSION = 8
 
     def __init__(self):
         # hash of the training data. used to prevent re-training when the
         # training data has not changed.
-        self.data_hash = None
+        self.data_hash: Optional[bytes] = None
 
         self.data_vectorizer = None
         self.tags_binarizer = None
@@ -75,25 +82,41 @@ class DocumentClassifier:
         self.storage_path_classifier = None
 
     def load(self):
-        with open(settings.MODEL_FILE, "rb") as f:
-            schema_version = pickle.load(f)
-
-            if schema_version != self.FORMAT_VERSION:
-                raise IncompatibleClassifierVersionError(
-                    "Cannot load classifier, incompatible versions.",
+        # Catch warnings for processing
+        with warnings.catch_warnings(record=True) as w:
+            with open(settings.MODEL_FILE, "rb") as f:
+                schema_version = pickle.load(f)
+
+                if schema_version != self.FORMAT_VERSION:
+                    raise IncompatibleClassifierVersionError(
+                        "Cannot load classifier, incompatible versions.",
+                    )
+                else:
+                    try:
+                        self.data_hash = pickle.load(f)
+                        self.data_vectorizer = pickle.load(f)
+                        self.tags_binarizer = pickle.load(f)
+
+                        self.tags_classifier = pickle.load(f)
+                        self.correspondent_classifier = pickle.load(f)
+                        self.document_type_classifier = pickle.load(f)
+                        self.storage_path_classifier = pickle.load(f)
+                    except Exception:
+                        raise ClassifierModelCorruptError()
+
+            # Check for the warning about unpickling from differing versions
+            # and consider it incompatible
+            if len(w) > 0:
+                sk_learn_warning_url = (
+                    "https://scikit-learn.org/stable/"
+                    "model_persistence.html"
+                    "#security-maintainability-limitations"
                 )
-            else:
-                try:
-                    self.data_hash = pickle.load(f)
-                    self.data_vectorizer = pickle.load(f)
-                    self.tags_binarizer = pickle.load(f)
-
-                    self.tags_classifier = pickle.load(f)
-                    self.correspondent_classifier = pickle.load(f)
-                    self.document_type_classifier = pickle.load(f)
-                    self.storage_path_classifier = pickle.load(f)
-                except Exception:
-                    raise ClassifierModelCorruptError()
+                for warning in w:
+                    if issubclass(warning.category, UserWarning):
+                        w_msg = str(warning.message)
+                        if sk_learn_warning_url in w_msg:
+                            raise IncompatibleClassifierVersionError()
 
     def save(self):
         target_file = settings.MODEL_FILE
index 8a0e1829ce9902c0e4b50b23fc78e39515174103..ff88b8894eca097a7f185301ea2985c23ee0a6f9 100644 (file)
Binary files a/src/documents/tests/data/model.pickle and b/src/documents/tests/data/model.pickle differ
diff --git a/src/documents/tests/data/v1.0.2.model.pickle b/src/documents/tests/data/v1.0.2.model.pickle
new file mode 100644 (file)
index 0000000..8a0e182
Binary files /dev/null and b/src/documents/tests/data/v1.0.2.model.pickle differ
index dcc503f97fdb4302decd2922c8e3e7cb01bd97bf..cfa662c0234f69931329b8fea860701a39e9e56e 100644 (file)
@@ -3,10 +3,12 @@ import tempfile
 from pathlib import Path
 from unittest import mock
 
+import documents
 import pytest
 from django.conf import settings
 from django.test import override_settings
 from django.test import TestCase
+from documents.classifier import ClassifierModelCorruptError
 from documents.classifier import DocumentClassifier
 from documents.classifier import IncompatibleClassifierVersionError
 from documents.classifier import load_classifier
@@ -216,6 +218,45 @@ class TestClassifier(DirectoriesMixin, TestCase):
 
         self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
 
+    @override_settings(
+        MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
+    )
+    @mock.patch("documents.classifier.pickle.load")
+    def test_load_corrupt_file(self, patched_pickle_load):
+        """
+        GIVEN:
+            - Corrupted classifier pickle file
+        WHEN:
+            - An attempt is made to load the classifier
+        THEN:
+            - The ClassifierModelCorruptError is raised
+        """
+        # First load is the schema version
+        patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
+
+        with self.assertRaises(ClassifierModelCorruptError):
+            self.classifier.load()
+
+    @override_settings(
+        MODEL_FILE=os.path.join(
+            os.path.dirname(__file__),
+            "data",
+            "v1.0.2.model.pickle",
+        ),
+    )
+    def test_load_new_scikit_learn_version(self):
+        """
+        GIVEN:
+            - classifier pickle file created with a different scikit-learn version
+        WHEN:
+            - An attempt is made to load the classifier
+        THEN:
+            - The classifier reports the warning was captured and processed
+        """
+
+        with self.assertRaises(IncompatibleClassifierVersionError):
+            self.classifier.load()
+
     def test_one_correspondent_predict(self):
         c1 = Correspondent.objects.create(
             name="c1",