]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Add custom fields to classifier
authorshamoon <4887959+shamoon@users.noreply.github.com>
Fri, 13 Dec 2024 21:39:19 +0000 (13:39 -0800)
committershamoon <4887959+shamoon@users.noreply.github.com>
Thu, 20 Mar 2025 23:21:34 +0000 (16:21 -0700)
src/documents/classifier.py

index 728c8322898377c1b319ff6d6aa0d5258dd52720..58c2058b5f737713c95f015a4500e567c3f9a8b9 100644 (file)
@@ -97,6 +97,8 @@ class DocumentClassifier:
         self.correspondent_classifier = None
         self.document_type_classifier = None
         self.storage_path_classifier = None
+        self.custom_fields_binarizer = None
+        self.custom_fields_classifier = None
 
         self._stemmer = None
         self._stop_words = None
@@ -120,11 +122,12 @@ class DocumentClassifier:
 
                         self.data_vectorizer = pickle.load(f)
                         self.tags_binarizer = pickle.load(f)
-
                         self.tags_classifier = pickle.load(f)
                         self.correspondent_classifier = pickle.load(f)
                         self.document_type_classifier = pickle.load(f)
                         self.storage_path_classifier = pickle.load(f)
+                        self.custom_fields_binarizer = pickle.load(f)
+                        self.custom_fields_classifier = pickle.load(f)
                     except Exception as err:
                         raise ClassifierModelCorruptError from err
 
@@ -162,6 +165,9 @@ class DocumentClassifier:
             pickle.dump(self.document_type_classifier, f)
             pickle.dump(self.storage_path_classifier, f)
 
+            pickle.dump(self.custom_fields_binarizer, f)
+            pickle.dump(self.custom_fields_classifier, f)
+
         target_file_temp.rename(target_file)
 
     def train(self) -> bool:
@@ -183,6 +189,7 @@ class DocumentClassifier:
         labels_correspondent = []
         labels_document_type = []
         labels_storage_path = []
+        labels_custom_fields = []
 
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
@@ -218,13 +225,25 @@ class DocumentClassifier:
             hasher.update(y.to_bytes(4, "little", signed=True))
             labels_storage_path.append(y)
 
-        labels_tags_unique = {tag for tags in labels_tags for tag in tags}
+            custom_fields = sorted(
+                cf.pk
+                for cf in doc.custom_fields.filter(
+                    field__matching_algorithm=MatchingModel.MATCH_AUTO,
+                )
+            )
+            for cf in custom_fields:
+                hasher.update(cf.to_bytes(4, "little", signed=True))
+            labels_custom_fields.append(custom_fields)
 
+        labels_tags_unique = {tag for tags in labels_tags for tag in tags}
         num_tags = len(labels_tags_unique)
 
+        labels_custom_fields_unique = {cf for cfs in labels_custom_fields for cf in cfs}
+        num_custom_fields = len(labels_custom_fields_unique)
+
         # Check if retraining is actually required.
         # A document has been updated since the classifier was trained
-        # New auto tags, types, correspondent, storage paths exist
+        # New auto tags, types, correspondent, storage paths or custom fields exist
         latest_doc_change = docs_queryset.latest("modified").modified
         if (
             self.last_doc_change_time is not None
@@ -253,7 +272,8 @@ class DocumentClassifier:
 
         logger.debug(
             f"{docs_queryset.count()} documents, {num_tags} tag(s), {num_correspondents} correspondent(s), "
-            f"{num_document_types} document type(s). {num_storage_paths} storage path(s)",
+            f"{num_document_types} document type(s), {num_storage_paths} storage path(s), "
+            f"{num_custom_fields} custom field(s)",
         )
 
         from sklearn.feature_extraction.text import CountVectorizer
@@ -345,6 +365,39 @@ class DocumentClassifier:
                 "There are no storage paths. Not training storage path classifier.",
             )
 
+        if num_custom_fields > 0:
+            logger.debug("Training custom fields classifier...")
+
+            if num_custom_fields == 1:
+                # Special case where only one custom field has auto:
+                # Fallback to binary classification.
+                labels_custom_fields = [
+                    label[0] if len(label) == 1 else -1
+                    for label in labels_custom_fields
+                ]
+                self.custom_fields_binarizer = LabelBinarizer()
+                labels_custom_fields_vectorized = (
+                    self.custom_fields_binarizer.fit_transform(
+                        labels_custom_fields,
+                    ).ravel()
+                )
+            else:
+                self.custom_fields_binarizer = MultiLabelBinarizer()
+                labels_custom_fields_vectorized = (
+                    self.custom_fields_binarizer.fit_transform(labels_custom_fields)
+                )
+
+            self.custom_fields_classifier = MLPClassifier(tol=0.01)
+            self.custom_fields_classifier.fit(
+                data_vectorized,
+                labels_custom_fields_vectorized,
+            )
+        else:
+            self.custom_fields_classifier = None
+            logger.debug(
+                "There are no custom fields. Not training custom fields classifier.",
+            )
+
         self.last_doc_change_time = latest_doc_change
         self.last_auto_type_hash = hasher.digest()
 
@@ -472,3 +525,24 @@ class DocumentClassifier:
                 return None
         else:
             return None
+
+    def predict_custom_fields(self, content: str) -> list[int]:
+        from sklearn.utils.multiclass import type_of_target
+
+        if self.custom_fields_classifier:
+            X = self.data_vectorizer.transform([self.preprocess_content(content)])
+            y = self.custom_fields_classifier.predict(X)
+            custom_fields_ids = self.custom_fields_binarizer.inverse_transform(y)[0]
+            if type_of_target(y).startswith("multilabel"):
+                # the usual case when there are multiple custom fields.
+                return list(custom_fields_ids)
+            elif type_of_target(y) == "binary" and custom_fields_ids != -1:
+                # This is for when we have binary classification with only one
+                # custom field and the result is to assign this custom field.
+                return [custom_fields_ids]
+            else:
+                # Usually binary as well with -1 as the result, but we're
+                # going to catch everything else here as well.
+                return []
+        else:
+            return []