]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
custom field regex matching feature-cf-matching-6932
authorshamoon <4887959+shamoon@users.noreply.github.com>
Fri, 21 Mar 2025 06:04:07 +0000 (23:04 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Fri, 21 Mar 2025 06:17:57 +0000 (23:17 -0700)
[ci skip]

src/documents/classifier.py
src/documents/matching.py
src/documents/signals/handlers.py

index 58c2058b5f737713c95f015a4500e567c3f9a8b9..c577790f24291235e515efe38b2cf0bc8c527b37 100644 (file)
@@ -526,7 +526,12 @@ class DocumentClassifier:
         else:
             return None
 
-    def predict_custom_fields(self, content: str) -> list[int]:
+    def predict_custom_fields(self, content: str) -> dict:
+        """
+        Custom fields are a bit different from the other classifiers, as we
+        need to predict the values for the fields, not just the field itself.
+        """
+        # TODO: can this return the value?
         from sklearn.utils.multiclass import type_of_target
 
         if self.custom_fields_classifier:
index 08cb5da770ba64619bcbf83f9eb35f7cc1bc9706..d09ae8074300e52b8e5f943b40a8c42c6b8fb932 100644 (file)
@@ -132,23 +132,48 @@ def match_storage_paths(document: Document, classifier: DocumentClassifier, user
     )
 
 
-def match_custom_fields(document: Document, classifier: DocumentClassifier, user=None):
+def match_custom_fields(
+    document: Document,
+    classifier: DocumentClassifier,
+    user=None,
+) -> dict:
+    """
+    Custom fields work differently, we need the values for the match as well.
+    """
+    # TODO: this needs to return values as well
     predicted_custom_field_ids = (
         classifier.predict_custom_fields(document.content) if classifier else []
     )
 
     fields = [instance.field for instance in document.custom_fields.all()]
 
-    return list(
-        filter(
-            lambda o: matches(o, document)
-            or (
-                o.matching_algorithm == MatchingModel.MATCH_AUTO
-                and o.pk in predicted_custom_field_ids
-            ),
-            fields,
-        ),
-    )
+    matched_fields = {}
+    for field in fields:
+        if field.matching_algorithm == MatchingModel.MATCH_AUTO:
+            if field.pk in predicted_custom_field_ids:
+                matched_fields[field] = None
+        elif field.matching_algorithm == MatchingModel.MATCH_REGEX:
+            try:
+                match = re.search(
+                    re.compile(field.matching_model.match),
+                    document.content,
+                )
+                if match:
+                    matched_fields[field] = match.group()
+            except re.error:
+                logger.error(
+                    f"Error while processing regular expression {field.matching_model.match}",
+                )
+                return False
+            if match:
+                log_reason(
+                    field.matching_model,
+                    document,
+                    f"the string {match.group()} matches the regular expression "
+                    f"{field.matching_model.match}",
+                )
+
+    return matched_fields
 
 
 def matches(matching_model: MatchingModel, document: Document):
index da54f456e81f90e350ddd113449fe2f70183bf94..c9fc90650c0c4cf8b160b66f5f5e170bf872c817 100644 (file)
@@ -322,11 +322,12 @@ def set_custom_fields(
     document: Document,
     logging_group=None,
     classifier: DocumentClassifier | None = None,
-    replace=False,
-    suggest=False,
     base_url=None,
     stdout=None,
     style_func=None,
+    *,
+    replace=False,
+    suggest=False,
     **kwargs,
 ):
     if replace:
@@ -336,7 +337,8 @@ def set_custom_fields(
 
     current_fields = set([instance.field for instance in document.custom_fields.all()])
 
-    matched_fields = matching.match_custom_fields(document, classifier)
+    matched_fields_w_values: dict = matching.match_custom_fields(document, classifier)
+    matched_fields = matched_fields_w_values.keys()
 
     relevant_fields = set(matched_fields) - current_fields
 
@@ -373,9 +375,17 @@ def set_custom_fields(
         )
 
         for field in relevant_fields:
+            args = {
+                "field": field,
+                "document": document,
+            }
+            if field.pk in matched_fields_w_values:
+                value_field_name = CustomFieldInstance.get_value_field_name(
+                    data_type=field.data_type,
+                )
+                args[value_field_name] = matched_fields_w_values[field.pk]
             CustomFieldInstance.objects.create(
-                field=field,
-                document=document,
+                **args,
             )