self.correspondent_classifier = None
self.document_type_classifier = None
self.storage_path_classifier = None
+ self.custom_fields_binarizer = None
+ self.custom_fields_classifier = None
self._stemmer = None
self._stop_words = None
self.data_vectorizer = pickle.load(f)
self.tags_binarizer = pickle.load(f)
-
self.tags_classifier = pickle.load(f)
self.correspondent_classifier = pickle.load(f)
self.document_type_classifier = pickle.load(f)
self.storage_path_classifier = pickle.load(f)
+ self.custom_fields_binarizer = pickle.load(f)
+ self.custom_fields_classifier = pickle.load(f)
except Exception as err:
raise ClassifierModelCorruptError from err
pickle.dump(self.document_type_classifier, f)
pickle.dump(self.storage_path_classifier, f)
+ pickle.dump(self.custom_fields_binarizer, f)
+ pickle.dump(self.custom_fields_classifier, f)
+
target_file_temp.rename(target_file)
def train(self) -> bool:
labels_correspondent = []
labels_document_type = []
labels_storage_path = []
+ labels_custom_fields = []
# Step 1: Extract and preprocess training data from the database.
logger.debug("Gathering data from database...")
hasher.update(y.to_bytes(4, "little", signed=True))
labels_storage_path.append(y)
- labels_tags_unique = {tag for tags in labels_tags for tag in tags}
+ custom_fields = sorted(
+ cf.pk
+ for cf in doc.custom_fields.filter(
+ field__matching_algorithm=MatchingModel.MATCH_AUTO,
+ )
+ )
+ for cf in custom_fields:
+ hasher.update(cf.to_bytes(4, "little", signed=True))
+ labels_custom_fields.append(custom_fields)
+ labels_tags_unique = {tag for tags in labels_tags for tag in tags}
num_tags = len(labels_tags_unique)
+ labels_custom_fields_unique = {cf for cfs in labels_custom_fields for cf in cfs}
+ num_custom_fields = len(labels_custom_fields_unique)
+
# Check if retraining is actually required.
# A document has been updated since the classifier was trained
- # New auto tags, types, correspondent, storage paths exist
+ # New auto tags, types, correspondent, storage paths or custom fields exist
latest_doc_change = docs_queryset.latest("modified").modified
if (
self.last_doc_change_time is not None
logger.debug(
f"{docs_queryset.count()} documents, {num_tags} tag(s), {num_correspondents} correspondent(s), "
- f"{num_document_types} document type(s). {num_storage_paths} storage path(s)",
+ f"{num_document_types} document type(s), {num_storage_paths} storage path(s), "
+ f"{num_custom_fields} custom field(s)",
)
from sklearn.feature_extraction.text import CountVectorizer
"There are no storage paths. Not training storage path classifier.",
)
+ if num_custom_fields > 0:
+ logger.debug("Training custom fields classifier...")
+
+ if num_custom_fields == 1:
+ # Special case where only one custom field has auto:
+ # Fallback to binary classification.
+ labels_custom_fields = [
+ label[0] if len(label) == 1 else -1
+ for label in labels_custom_fields
+ ]
+ self.custom_fields_binarizer = LabelBinarizer()
+ labels_custom_fields_vectorized = (
+ self.custom_fields_binarizer.fit_transform(
+ labels_custom_fields,
+ ).ravel()
+ )
+ else:
+ self.custom_fields_binarizer = MultiLabelBinarizer()
+ labels_custom_fields_vectorized = (
+ self.custom_fields_binarizer.fit_transform(labels_custom_fields)
+ )
+
+ self.custom_fields_classifier = MLPClassifier(tol=0.01)
+ self.custom_fields_classifier.fit(
+ data_vectorized,
+ labels_custom_fields_vectorized,
+ )
+ else:
+ self.custom_fields_classifier = None
+ logger.debug(
+ "There are no custom fields. Not training custom fields classifier.",
+ )
+
self.last_doc_change_time = latest_doc_change
self.last_auto_type_hash = hasher.digest()
return None
else:
return None
+
def predict_custom_fields(self, content: str) -> list[int]:
    """Predict which custom fields should be assigned to a document.

    Args:
        content: Raw document text. It is run through
            ``self.preprocess_content`` and ``self.data_vectorizer`` before
            being handed to the classifier.

    Returns:
        The predicted custom field ids as plain ``int`` values. The list is
        empty when no custom fields classifier is available or when the
        classifier predicts that nothing should be assigned.
    """
    if not self.custom_fields_classifier:
        # Nothing was trained (or loaded), so there is nothing to predict.
        return []

    # Imported only once a prediction is really needed, so the untrained
    # path does not pay for (or depend on) the scikit-learn import.
    from sklearn.utils.multiclass import type_of_target

    X = self.data_vectorizer.transform([self.preprocess_content(content)])
    y = self.custom_fields_classifier.predict(X)
    custom_fields_ids = self.custom_fields_binarizer.inverse_transform(y)[0]
    target_type = type_of_target(y)

    if target_type.startswith("multilabel"):
        # The usual case when there are multiple custom fields: the
        # binarizer hands back the collection of ids for this sample.
        return [int(pk) for pk in custom_fields_ids]

    if target_type == "binary" and custom_fields_ids != -1:
        # Binary classification with only one custom field, and the result
        # is to assign it. -1 is the "do not assign" label used in training.
        return [int(custom_fields_ids)]

    # Usually binary as well with -1 as the result, but every other target
    # type is deliberately caught here too.
    return []