From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Sun, 31 Aug 2025 21:18:31 +0000 (-0700) Subject: Try joblib X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=refs%2Fheads%2Ffix-suggestions-memory;p=thirdparty%2Fpaperless-ngx.git Try joblib --- diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 427f5ada0..4324308c0 100644 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -103,7 +103,8 @@ class DocumentClassifier: # v7 - Updated scikit-learn package version # v8 - Added storage path classifier # v9 - Changed from hashing to time/ids for re-train check - FORMAT_VERSION = 9 + # v10 - Switch persistence to joblib with memory-mapping to reduce load-time memory spikes + FORMAT_VERSION = 10 def __init__(self) -> None: # last time a document changed and therefore training might be required @@ -135,32 +136,51 @@ class DocumentClassifier: ).hexdigest() def load(self) -> None: + import joblib from sklearn.exceptions import InconsistentVersionWarning # Catch warnings for processing with warnings.catch_warnings(record=True) as w: - with Path(settings.MODEL_FILE).open("rb") as f: - schema_version = pickle.load(f) + try: + state = joblib.load(settings.MODEL_FILE, mmap_mode="r") + except Exception as err: + # As a fallback, try to detect old pickle-based and mark incompatible + try: + with Path(settings.MODEL_FILE).open("rb") as f: + _ = pickle.load(f) + raise IncompatibleClassifierVersionError( + "Cannot load classifier, incompatible versions.", + ) from err + except IncompatibleClassifierVersionError: + raise + except Exception: + # Not even a readable pickle header + raise ClassifierModelCorruptError from err - if schema_version != self.FORMAT_VERSION: + try: + if ( + not isinstance(state, dict) + or state.get("format_version") != self.FORMAT_VERSION + ): raise IncompatibleClassifierVersionError( "Cannot load classifier, incompatible versions.", ) - else: - try: - self.last_doc_change_time = pickle.load(f) - self.last_auto_type_hash = pickle.load(f) - - self.data_vectorizer = pickle.load(f) - self._update_data_vectorizer_hash() - self.tags_binarizer = pickle.load(f) - - self.tags_classifier = pickle.load(f) - self.correspondent_classifier = pickle.load(f) - self.document_type_classifier = pickle.load(f) - self.storage_path_classifier = pickle.load(f) - except Exception as err: - raise ClassifierModelCorruptError from err + + self.last_doc_change_time = state.get("last_doc_change_time") + self.last_auto_type_hash = state.get("last_auto_type_hash") + + self.data_vectorizer = state.get("data_vectorizer") + self._update_data_vectorizer_hash() + self.tags_binarizer = state.get("tags_binarizer") + + self.tags_classifier = state.get("tags_classifier") + self.correspondent_classifier = state.get("correspondent_classifier") + self.document_type_classifier = state.get("document_type_classifier") + self.storage_path_classifier = state.get("storage_path_classifier") + except IncompatibleClassifierVersionError: + raise + except Exception as err: + raise ClassifierModelCorruptError from err # Check for the warning about unpickling from differing versions # and consider it incompatible @@ -178,23 +198,24 @@ class DocumentClassifier: raise IncompatibleClassifierVersionError("sklearn version update") def save(self) -> None: - target_file: Path = settings.MODEL_FILE - target_file_temp: Path = target_file.with_suffix(".pickle.part") - - with target_file_temp.open("wb") as f: - pickle.dump(self.FORMAT_VERSION, f) - - pickle.dump(self.last_doc_change_time, f) - pickle.dump(self.last_auto_type_hash, f) + import joblib - pickle.dump(self.data_vectorizer, f) - - pickle.dump(self.tags_binarizer, f) - pickle.dump(self.tags_classifier, f) - - pickle.dump(self.correspondent_classifier, f) - pickle.dump(self.document_type_classifier, f) - pickle.dump(self.storage_path_classifier, f) + target_file: Path = settings.MODEL_FILE + target_file_temp: Path = target_file.with_suffix(".joblib.part") + + state = { + "format_version": self.FORMAT_VERSION, + "last_doc_change_time": self.last_doc_change_time, + "last_auto_type_hash": self.last_auto_type_hash, + "data_vectorizer": self.data_vectorizer, + "tags_binarizer": self.tags_binarizer, + "tags_classifier": self.tags_classifier, + "correspondent_classifier": self.correspondent_classifier, + "document_type_classifier": self.document_type_classifier, + "storage_path_classifier": self.storage_path_classifier, + } + + joblib.dump(state, target_file_temp, compress=3) target_file_temp.rename(target_file)