# v7 - Updated scikit-learn package version
# v8 - Added storage path classifier
# v9 - Changed from hashing to time/ids for re-train check
- FORMAT_VERSION = 9
+ # v10 - Switch persistence to joblib with memory-mapping to reduce load-time memory spikes
+ FORMAT_VERSION = 10
def __init__(self) -> None:
# last time a document changed and therefore training might be required
).hexdigest()
def load(self) -> None:
+ import joblib
from sklearn.exceptions import InconsistentVersionWarning
# Catch warnings for processing
with warnings.catch_warnings(record=True) as w:
- with Path(settings.MODEL_FILE).open("rb") as f:
- schema_version = pickle.load(f)
+ try:
+ state = joblib.load(settings.MODEL_FILE, mmap_mode="r")
+ except Exception as err:
+ # As a fallback, try to detect old pickle-based and mark incompatible
+ try:
+ with Path(settings.MODEL_FILE).open("rb") as f:
+ _ = pickle.load(f)
+ raise IncompatibleClassifierVersionError(
+ "Cannot load classifier, incompatible versions.",
+ ) from err
+ except IncompatibleClassifierVersionError:
+ raise
+ except Exception:
+ # Not even a readable pickle header
+ raise ClassifierModelCorruptError from err
- if schema_version != self.FORMAT_VERSION:
+ try:
+ if (
+ not isinstance(state, dict)
+ or state.get("format_version") != self.FORMAT_VERSION
+ ):
raise IncompatibleClassifierVersionError(
"Cannot load classifier, incompatible versions.",
)
- else:
- try:
- self.last_doc_change_time = pickle.load(f)
- self.last_auto_type_hash = pickle.load(f)
-
- self.data_vectorizer = pickle.load(f)
- self._update_data_vectorizer_hash()
- self.tags_binarizer = pickle.load(f)
-
- self.tags_classifier = pickle.load(f)
- self.correspondent_classifier = pickle.load(f)
- self.document_type_classifier = pickle.load(f)
- self.storage_path_classifier = pickle.load(f)
- except Exception as err:
- raise ClassifierModelCorruptError from err
+
+ self.last_doc_change_time = state.get("last_doc_change_time")
+ self.last_auto_type_hash = state.get("last_auto_type_hash")
+
+ self.data_vectorizer = state.get("data_vectorizer")
+ self._update_data_vectorizer_hash()
+ self.tags_binarizer = state.get("tags_binarizer")
+
+ self.tags_classifier = state.get("tags_classifier")
+ self.correspondent_classifier = state.get("correspondent_classifier")
+ self.document_type_classifier = state.get("document_type_classifier")
+ self.storage_path_classifier = state.get("storage_path_classifier")
+ except IncompatibleClassifierVersionError:
+ raise
+ except Exception as err:
+ raise ClassifierModelCorruptError from err
# Check for the warning about unpickling from differing versions
# and consider it incompatible
raise IncompatibleClassifierVersionError("sklearn version update")
def save(self) -> None:
- target_file: Path = settings.MODEL_FILE
- target_file_temp: Path = target_file.with_suffix(".pickle.part")
-
- with target_file_temp.open("wb") as f:
- pickle.dump(self.FORMAT_VERSION, f)
-
- pickle.dump(self.last_doc_change_time, f)
- pickle.dump(self.last_auto_type_hash, f)
+ import joblib
- pickle.dump(self.data_vectorizer, f)
-
- pickle.dump(self.tags_binarizer, f)
- pickle.dump(self.tags_classifier, f)
-
- pickle.dump(self.correspondent_classifier, f)
- pickle.dump(self.document_type_classifier, f)
- pickle.dump(self.storage_path_classifier, f)
+ target_file: Path = settings.MODEL_FILE
+ target_file_temp: Path = target_file.with_suffix(".joblib.part")
+
+ state = {
+ "format_version": self.FORMAT_VERSION,
+ "last_doc_change_time": self.last_doc_change_time,
+ "last_auto_type_hash": self.last_auto_type_hash,
+ "data_vectorizer": self.data_vectorizer,
+ "tags_binarizer": self.tags_binarizer,
+ "tags_classifier": self.tags_classifier,
+ "correspondent_classifier": self.correspondent_classifier,
+ "document_type_classifier": self.document_type_classifier,
+ "storage_path_classifier": self.storage_path_classifier,
+ }
+
+ joblib.dump(state, target_file_temp, compress=3)
target_file_temp.rename(target_file)