]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Adding more typing around the classification and matching
authorTrenton Holmes <797416+stumpylog@users.noreply.github.com>
Sun, 23 Jul 2023 23:49:20 +0000 (16:49 -0700)
committerTrenton H <797416+stumpylog@users.noreply.github.com>
Wed, 26 Jul 2023 14:03:43 +0000 (07:03 -0700)
src/documents/classifier.py
src/documents/matching.py
src/documents/signals/handlers.py

index 3dd1a60aa998247f7d7fdeac36eca5185ad65a8b..5ed203934a83e5e2f26f45e6976600b80912f46b 100644 (file)
@@ -5,6 +5,7 @@ import re
 import warnings
 from datetime import datetime
 from hashlib import sha256
+from pathlib import Path
 from typing import Iterator
 from typing import List
 from typing import Optional
@@ -81,7 +82,7 @@ class DocumentClassifier:
         self._stemmer = None
         self._stop_words = None
 
-    def load(self):
+    def load(self) -> None:
         # Catch warnings for processing
         with warnings.catch_warnings(record=True) as w:
             with open(settings.MODEL_FILE, "rb") as f:
@@ -120,19 +121,20 @@ class DocumentClassifier:
                         raise IncompatibleClassifierVersionError
 
     def save(self):
-        target_file = settings.MODEL_FILE
-        target_file_temp = settings.MODEL_FILE.with_suffix(".pickle.part")
+        target_file: Path = settings.MODEL_FILE
+        target_file_temp = target_file.with_suffix(".pickle.part")
 
         with open(target_file_temp, "wb") as f:
             pickle.dump(self.FORMAT_VERSION, f)
+
             pickle.dump(self.last_doc_change_time, f)
             pickle.dump(self.last_auto_type_hash, f)
 
             pickle.dump(self.data_vectorizer, f)
 
             pickle.dump(self.tags_binarizer, f)
-
             pickle.dump(self.tags_classifier, f)
+
             pickle.dump(self.correspondent_classifier, f)
             pickle.dump(self.document_type_classifier, f)
             pickle.dump(self.storage_path_classifier, f)
@@ -380,7 +382,7 @@ class DocumentClassifier:
 
         return content
 
-    def predict_correspondent(self, content: str):
+    def predict_correspondent(self, content: str) -> Optional[int]:
         if self.correspondent_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             correspondent_id = self.correspondent_classifier.predict(X)
@@ -391,7 +393,7 @@ class DocumentClassifier:
         else:
             return None
 
-    def predict_document_type(self, content: str):
+    def predict_document_type(self, content: str) -> Optional[int]:
         if self.document_type_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             document_type_id = self.document_type_classifier.predict(X)
@@ -402,7 +404,7 @@ class DocumentClassifier:
         else:
             return None
 
-    def predict_tags(self, content: str):
+    def predict_tags(self, content: str) -> List[int]:
         from sklearn.utils.multiclass import type_of_target
 
         if self.tags_classifier:
@@ -423,7 +425,7 @@ class DocumentClassifier:
         else:
             return []
 
-    def predict_storage_path(self, content: str):
+    def predict_storage_path(self, content: str) -> Optional[int]:
         if self.storage_path_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             storage_path_id = self.storage_path_classifier.predict(X)
index a7ceb5a5ae06a709fdcc5271744dc538072d51e4..eb0f4f8b527f944f8bf18a57398e5613521b1282 100644 (file)
@@ -1,7 +1,9 @@
 import logging
 import re
 
+from documents.classifier import DocumentClassifier
 from documents.models import Correspondent
+from documents.models import Document
 from documents.models import DocumentType
 from documents.models import MatchingModel
 from documents.models import StoragePath
@@ -11,7 +13,7 @@ from documents.permissions import get_objects_for_user_owner_aware
 logger = logging.getLogger("paperless.matching")
 
 
-def log_reason(matching_model, document, reason):
+def log_reason(matching_model: MatchingModel, document: Document, reason: str):
     class_name = type(matching_model).__name__
     logger.debug(
         f"{class_name} {matching_model.name} matched on document "
@@ -19,7 +21,7 @@ def log_reason(matching_model, document, reason):
     )
 
 
-def match_correspondents(document, classifier, user=None):
+def match_correspondents(document: Document, classifier: DocumentClassifier, user=None):
     pred_id = classifier.predict_correspondent(document.content) if classifier else None
 
     if user is None and document.owner is not None:
@@ -43,7 +45,7 @@ def match_correspondents(document, classifier, user=None):
     )
 
 
-def match_document_types(document, classifier, user=None):
+def match_document_types(document: Document, classifier: DocumentClassifier, user=None):
     pred_id = classifier.predict_document_type(document.content) if classifier else None
 
     if user is None and document.owner is not None:
@@ -67,7 +69,7 @@ def match_document_types(document, classifier, user=None):
     )
 
 
-def match_tags(document, classifier, user=None):
+def match_tags(document: Document, classifier: DocumentClassifier, user=None):
     predicted_tag_ids = classifier.predict_tags(document.content) if classifier else []
 
     if user is None and document.owner is not None:
@@ -90,7 +92,7 @@ def match_tags(document, classifier, user=None):
     )
 
 
-def match_storage_paths(document, classifier, user=None):
+def match_storage_paths(document: Document, classifier: DocumentClassifier, user=None):
     pred_id = classifier.predict_storage_path(document.content) if classifier else None
 
     if user is None and document.owner is not None:
@@ -114,7 +116,7 @@ def match_storage_paths(document, classifier, user=None):
     )
 
 
-def matches(matching_model, document):
+def matches(matching_model: MatchingModel, document: Document):
     search_kwargs = {}
 
     document_content = document.content
index 4a39d98eab5f469b87d569731f703f63f32d4b0a..4e0d13c209f6c33151ff1c72b09a1cf86b3bd7f5 100644 (file)
@@ -1,6 +1,7 @@
 import logging
 import os
 import shutil
+from typing import Optional
 
 from celery import states
 from celery.signals import before_task_publish
@@ -21,6 +22,7 @@ from django.utils import timezone
 from filelock import FileLock
 
 from documents import matching
+from documents.classifier import DocumentClassifier
 from documents.file_handling import create_source_path_directory
 from documents.file_handling import delete_empty_directories
 from documents.file_handling import generate_unique_filename
@@ -33,7 +35,7 @@ from documents.permissions import get_objects_for_user_owner_aware
 logger = logging.getLogger("paperless.handlers")
 
 
-def add_inbox_tags(sender, document=None, logging_group=None, **kwargs):
+def add_inbox_tags(sender, document: Document, logging_group=None, **kwargs):
     if document.owner is not None:
         tags = get_objects_for_user_owner_aware(
             document.owner,
@@ -48,9 +50,9 @@ def add_inbox_tags(sender, document=None, logging_group=None, **kwargs):
 
 def set_correspondent(
     sender,
-    document=None,
+    document: Document,
     logging_group=None,
-    classifier=None,
+    classifier: Optional[DocumentClassifier] = None,
     replace=False,
     use_first=True,
     suggest=False,
@@ -111,9 +113,9 @@ def set_correspondent(
 
 def set_document_type(
     sender,
-    document=None,
+    document: Document,
     logging_group=None,
-    classifier=None,
+    classifier: Optional[DocumentClassifier] = None,
     replace=False,
     use_first=True,
     suggest=False,
@@ -175,9 +177,9 @@ def set_document_type(
 
 def set_tags(
     sender,
-    document=None,
+    document: Document,
     logging_group=None,
-    classifier=None,
+    classifier: Optional[DocumentClassifier] = None,
     replace=False,
     suggest=False,
     base_url=None,
@@ -239,9 +241,9 @@ def set_tags(
 
 def set_storage_path(
     sender,
-    document=None,
+    document: Document,
     logging_group=None,
-    classifier=None,
+    classifier: Optional[DocumentClassifier] = None,
     replace=False,
     use_first=True,
     suggest=False,
@@ -491,7 +493,7 @@ def update_filename_and_move_files(sender, instance: Document, **kwargs):
             )
 
 
-def set_log_entry(sender, document=None, logging_group=None, **kwargs):
+def set_log_entry(sender, document: Document, logging_group=None, **kwargs):
     ct = ContentType.objects.get(model="document")
     user = User.objects.get(username="consumer")