]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Try joblib fix-suggestions-memory
authorshamoon <4887959+shamoon@users.noreply.github.com>
Sun, 31 Aug 2025 21:18:31 +0000 (14:18 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Sun, 31 Aug 2025 21:18:31 +0000 (14:18 -0700)
src/documents/classifier.py

index 427f5ada0e8c29210ecb27b9b7232b017396904f..4324308c0b33036d682386d0280a987fd51bb0fb 100644 (file)
@@ -103,7 +103,8 @@ class DocumentClassifier:
     # v7 - Updated scikit-learn package version
     # v8 - Added storage path classifier
     # v9 - Changed from hashing to time/ids for re-train check
-    FORMAT_VERSION = 9
+    # v10 - Switch persistence to joblib with memory-mapping to reduce load-time memory spikes
+    FORMAT_VERSION = 10
 
     def __init__(self) -> None:
         # last time a document changed and therefore training might be required
@@ -135,32 +136,51 @@ class DocumentClassifier:
         ).hexdigest()
 
     def load(self) -> None:
+        import joblib
         from sklearn.exceptions import InconsistentVersionWarning
 
         # Catch warnings for processing
         with warnings.catch_warnings(record=True) as w:
-            with Path(settings.MODEL_FILE).open("rb") as f:
-                schema_version = pickle.load(f)
+            try:
+                state = joblib.load(settings.MODEL_FILE, mmap_mode="r")
+            except Exception as err:
+                # As a fallback, try to detect old pickle-based and mark incompatible
+                try:
+                    with Path(settings.MODEL_FILE).open("rb") as f:
+                        _ = pickle.load(f)
+                    raise IncompatibleClassifierVersionError(
+                        "Cannot load classifier, incompatible versions.",
+                    ) from err
+                except IncompatibleClassifierVersionError:
+                    raise
+                except Exception:
+                    # Not even a readable pickle header
+                    raise ClassifierModelCorruptError from err
 
-                if schema_version != self.FORMAT_VERSION:
+            try:
+                if (
+                    not isinstance(state, dict)
+                    or state.get("format_version") != self.FORMAT_VERSION
+                ):
                     raise IncompatibleClassifierVersionError(
                         "Cannot load classifier, incompatible versions.",
                     )
-                else:
-                    try:
-                        self.last_doc_change_time = pickle.load(f)
-                        self.last_auto_type_hash = pickle.load(f)
-
-                        self.data_vectorizer = pickle.load(f)
-                        self._update_data_vectorizer_hash()
-                        self.tags_binarizer = pickle.load(f)
-
-                        self.tags_classifier = pickle.load(f)
-                        self.correspondent_classifier = pickle.load(f)
-                        self.document_type_classifier = pickle.load(f)
-                        self.storage_path_classifier = pickle.load(f)
-                    except Exception as err:
-                        raise ClassifierModelCorruptError from err
+
+                self.last_doc_change_time = state.get("last_doc_change_time")
+                self.last_auto_type_hash = state.get("last_auto_type_hash")
+
+                self.data_vectorizer = state.get("data_vectorizer")
+                self._update_data_vectorizer_hash()
+                self.tags_binarizer = state.get("tags_binarizer")
+
+                self.tags_classifier = state.get("tags_classifier")
+                self.correspondent_classifier = state.get("correspondent_classifier")
+                self.document_type_classifier = state.get("document_type_classifier")
+                self.storage_path_classifier = state.get("storage_path_classifier")
+            except IncompatibleClassifierVersionError:
+                raise
+            except Exception as err:
+                raise ClassifierModelCorruptError from err
 
             # Check for the warning about unpickling from differing versions
             # and consider it incompatible
@@ -178,23 +198,24 @@ class DocumentClassifier:
                     raise IncompatibleClassifierVersionError("sklearn version update")
 
     def save(self) -> None:
-        target_file: Path = settings.MODEL_FILE
-        target_file_temp: Path = target_file.with_suffix(".pickle.part")
-
-        with target_file_temp.open("wb") as f:
-            pickle.dump(self.FORMAT_VERSION, f)
-
-            pickle.dump(self.last_doc_change_time, f)
-            pickle.dump(self.last_auto_type_hash, f)
+        import joblib
 
-            pickle.dump(self.data_vectorizer, f)
-
-            pickle.dump(self.tags_binarizer, f)
-            pickle.dump(self.tags_classifier, f)
-
-            pickle.dump(self.correspondent_classifier, f)
-            pickle.dump(self.document_type_classifier, f)
-            pickle.dump(self.storage_path_classifier, f)
+        target_file: Path = settings.MODEL_FILE
+        target_file_temp: Path = target_file.with_suffix(".joblib.part")
+
+        state = {
+            "format_version": self.FORMAT_VERSION,
+            "last_doc_change_time": self.last_doc_change_time,
+            "last_auto_type_hash": self.last_auto_type_hash,
+            "data_vectorizer": self.data_vectorizer,
+            "tags_binarizer": self.tags_binarizer,
+            "tags_classifier": self.tags_classifier,
+            "correspondent_classifier": self.correspondent_classifier,
+            "document_type_classifier": self.document_type_classifier,
+            "storage_path_classifier": self.storage_path_classifier,
+        }
+
+        joblib.dump(state, target_file_temp, compress=3)
 
         target_file_temp.rename(target_file)