Chore: Switch from os.path to pathlib.Path (#8325)
author    Sebastian Steinbeißer <33968289+gothicVI@users.noreply.github.com>
Mon, 6 Jan 2025 20:12:27 +0000 (21:12 +0100)
committer GitHub <noreply@github.com>
Mon, 6 Jan 2025 20:12:27 +0000 (12:12 -0800)
---------

Co-authored-by: shamoon <4887959+shamoon@users.noreply.github.com>
.ruff.toml
src/documents/barcodes.py
src/documents/classifier.py
src/documents/index.py
src/documents/management/commands/decrypt_documents.py
src/documents/management/commands/document_importer.py
src/documents/migrations/1037_webp_encrypted_thumbnail_conversion.py
src/documents/tests/test_api_status.py
src/documents/tests/test_classifier.py
src/documents/tests/test_management.py
src/documents/views.py

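For orientation, the set of equivalences this commit applies is small. Below is a minimal, self-contained sketch (illustrative only, using a hypothetical temporary file rather than any paperless-ngx path) of the os.path idioms removed in the hunks that follow and their pathlib.Path replacements:

    import tempfile
    from pathlib import Path

    # Hypothetical file used only for illustration; not part of the commit.
    base = Path(tempfile.mkdtemp())
    p = base / "classifier.pickle"

    p.parent.mkdir(parents=True, exist_ok=True)  # was: os.makedirs(..., exist_ok=True)
    with p.open("wb") as f:                      # was: open(p, "wb")
        f.write(b"\x00")
    print(p.is_file())    # was: os.path.isfile(p)
    print(p.suffix)       # was: os.path.splitext(p)[1]      -> ".pickle"
    print(p.stem)         # was: os.path.splitext(p.name)[0] -> "classifier"
    print(base / "sub")   # was: os.path.join(base, "sub")
    p.unlink()            # was: os.unlink(p)
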
diff --git a/.ruff.toml b/.ruff.toml
index d805f3add3452468b3b0668d8755adbe0f593fe8..96ee7430b10c234cff99ac3e232d15cd2c96249c 100644 (file)
@@ -38,20 +38,14 @@ ignore = ["DJ001", "SIM105", "RUF012"]
 [lint.per-file-ignores]
 ".github/scripts/*.py" = ["E501", "INP001", "SIM117"]
 "docker/wait-for-redis.py" = ["INP001", "T201"]
-"src/documents/barcodes.py" = ["PTH"]  # TODO Enable & remove
-"src/documents/classifier.py" = ["PTH"]  # TODO Enable & remove
 "src/documents/consumer.py" = ["PTH"]  # TODO Enable & remove
 "src/documents/file_handling.py" = ["PTH"]  # TODO Enable & remove
-"src/documents/index.py" = ["PTH"]  # TODO Enable & remove
-"src/documents/management/commands/decrypt_documents.py" = ["PTH"]  # TODO Enable & remove
 "src/documents/management/commands/document_consumer.py" = ["PTH"]  # TODO Enable & remove
 "src/documents/management/commands/document_exporter.py" = ["PTH"]  # TODO Enable & remove
-"src/documents/management/commands/document_importer.py" = ["PTH"]  # TODO Enable & remove
 "src/documents/migrations/0012_auto_20160305_0040.py" = ["PTH"]  # TODO Enable & remove
 "src/documents/migrations/0014_document_checksum.py" = ["PTH"]  # TODO Enable & remove
 "src/documents/migrations/1003_mime_types.py" = ["PTH"]  # TODO Enable & remove
 "src/documents/migrations/1012_fix_archive_files.py" = ["PTH"]  # TODO Enable & remove
-"src/documents/migrations/1037_webp_encrypted_thumbnail_conversion.py" = ["PTH"]  # TODO Enable & remove
 "src/documents/models.py" = ["SIM115", "PTH"]  # TODO PTH Enable & remove
 "src/documents/parsers.py" = ["PTH"]  # TODO Enable & remove
 "src/documents/signals/handlers.py" = ["PTH"]  # TODO Enable & remove
diff --git a/src/documents/barcodes.py b/src/documents/barcodes.py
index 132e853b0b93deda2e61c4d3cb3e01300d00f1ac..4fe0670af90804b545dd20cc8cc3bf1a6a90d50e 100644 (file)
@@ -3,6 +3,7 @@ import re
 import tempfile
 from dataclasses import dataclass
 from pathlib import Path
+from typing import TYPE_CHECKING
 
 from django.conf import settings
 from pdf2image import convert_from_path
@@ -21,6 +22,9 @@ from documents.utils import copy_basic_file_stats
 from documents.utils import copy_file_with_basic_stats
 from documents.utils import maybe_override_pixel_limit
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 logger = logging.getLogger("paperless.barcodes")
 
 
@@ -61,7 +65,7 @@ class BarcodePlugin(ConsumeTaskPlugin):
           - Barcode support is enabled and the mime type is supported
         """
         if settings.CONSUMER_BARCODE_TIFF_SUPPORT:
-            supported_mimes = {"application/pdf", "image/tiff"}
+            supported_mimes: set[str] = {"application/pdf", "image/tiff"}
         else:
             supported_mimes = {"application/pdf"}
 
@@ -71,16 +75,16 @@ class BarcodePlugin(ConsumeTaskPlugin):
             or settings.CONSUMER_ENABLE_TAG_BARCODE
         ) and self.input_doc.mime_type in supported_mimes
 
-    def setup(self):
+    def setup(self) -> None:
         self.temp_dir = tempfile.TemporaryDirectory(
             dir=self.base_tmp_dir,
             prefix="barcode",
         )
-        self.pdf_file = self.input_doc.original_file
+        self.pdf_file: Path = self.input_doc.original_file
         self._tiff_conversion_done = False
         self.barcodes: list[Barcode] = []
 
-    def run(self) -> str | None:
+    def run(self) -> None:
         # Some operations may use PIL, override pixel setting if needed
         maybe_override_pixel_limit()
 
@@ -158,7 +162,7 @@ class BarcodePlugin(ConsumeTaskPlugin):
     def cleanup(self) -> None:
         self.temp_dir.cleanup()
 
-    def convert_from_tiff_to_pdf(self):
+    def convert_from_tiff_to_pdf(self) -> None:
         """
         May convert a TIFF image into a PDF, if the input is a TIFF and
         the TIFF has not been made into a PDF
@@ -223,7 +227,7 @@ class BarcodePlugin(ConsumeTaskPlugin):
 
         # Choose the library for reading
         if settings.CONSUMER_BARCODE_SCANNER == "PYZBAR":
-            reader = self.read_barcodes_pyzbar
+            reader: Callable[[Image.Image], list[str]] = self.read_barcodes_pyzbar
             logger.debug("Scanning for barcodes using PYZBAR")
         else:
             reader = self.read_barcodes_zxing
@@ -236,7 +240,7 @@ class BarcodePlugin(ConsumeTaskPlugin):
             logger.debug(f"PDF has {num_of_pages} pages")
 
             # Get limit from configuration
-            barcode_max_pages = (
+            barcode_max_pages: int = (
                 num_of_pages
                 if settings.CONSUMER_BARCODE_MAX_PAGES == 0
                 else settings.CONSUMER_BARCODE_MAX_PAGES
@@ -311,7 +315,7 @@ class BarcodePlugin(ConsumeTaskPlugin):
         self.detect()
 
         # get the first barcode that starts with CONSUMER_ASN_BARCODE_PREFIX
-        asn_text = next(
+        asn_text: str | None = next(
             (x.value for x in self.barcodes if x.is_asn),
             None,
         )
@@ -333,36 +337,36 @@ class BarcodePlugin(ConsumeTaskPlugin):
         return asn
 
     @property
-    def tags(self) -> list[int] | None:
+    def tags(self) -> list[int]:
         """
         Search the parsed barcodes for any tags.
         Returns the detected tag ids (or empty list)
         """
-        tags = []
+        tags: list[int] = []
 
         # Ensure the barcodes have been read
         self.detect()
 
         for x in self.barcodes:
-            tag_texts = x.value
+            tag_texts: str = x.value
 
             for raw in tag_texts.split(","):
                 try:
-                    tag = None
+                    tag_str: str | None = None
                     for regex in settings.CONSUMER_TAG_BARCODE_MAPPING:
                         if re.match(regex, raw, flags=re.IGNORECASE):
                             sub = settings.CONSUMER_TAG_BARCODE_MAPPING[regex]
-                            tag = (
+                            tag_str = (
                                 re.sub(regex, sub, raw, flags=re.IGNORECASE)
                                 if sub
                                 else raw
                             )
                             break
 
-                    if tag:
+                    if tag_str:
                         tag, _ = Tag.objects.get_or_create(
-                            name__iexact=tag,
-                            defaults={"name": tag},
+                            name__iexact=tag_str,
+                            defaults={"name": tag_str},
                         )
 
                         logger.debug(
@@ -413,7 +417,7 @@ class BarcodePlugin(ConsumeTaskPlugin):
         """
 
         document_paths = []
-        fname = self.input_doc.original_file.stem
+        fname: str = self.input_doc.original_file.stem
         with Pdf.open(self.pdf_file) as input_pdf:
             # Start with an empty document
             current_document: list[Page] = []
@@ -432,7 +436,7 @@ class BarcodePlugin(ConsumeTaskPlugin):
                 logger.debug(f"Starting new document at idx {idx}")
                 current_document = []
                 documents.append(current_document)
-                keep_page = pages_to_split_on[idx]
+                keep_page: bool = pages_to_split_on[idx]
                 if keep_page:
                     # Keep the page
                     # (new document is started by asn barcode)
@@ -451,7 +455,7 @@ class BarcodePlugin(ConsumeTaskPlugin):
 
                 logger.debug(f"pdf no:{doc_idx} has {len(dst.pages)} pages")
                 savepath = Path(self.temp_dir.name) / output_filename
-                with open(savepath, "wb") as out:
+                with savepath.open("wb") as out:
                     dst.save(out)
 
                 copy_basic_file_stats(self.input_doc.original_file, savepath)
diff --git a/src/documents/classifier.py b/src/documents/classifier.py
index 26a1ae4782b1530db005f9c99a7cdfa1113e58ef..b46d0e138f9bde500bf3e95fccbb1dda04f7aa60 100644 (file)
@@ -1,16 +1,17 @@
 import logging
-import os
 import pickle
 import re
 import warnings
 from collections.abc import Iterator
 from hashlib import sha256
+from pathlib import Path
 from typing import TYPE_CHECKING
 from typing import Optional
 
 if TYPE_CHECKING:
     from datetime import datetime
-    from pathlib import Path
+
+    from numpy import ndarray
 
 from django.conf import settings
 from django.core.cache import cache
@@ -28,7 +29,7 @@ logger = logging.getLogger("paperless.classifier")
 
 class IncompatibleClassifierVersionError(Exception):
     def __init__(self, message: str, *args: object) -> None:
-        self.message = message
+        self.message: str = message
         super().__init__(*args)
 
 
@@ -36,8 +37,8 @@ class ClassifierModelCorruptError(Exception):
     pass
 
 
-def load_classifier() -> Optional["DocumentClassifier"]:
-    if not os.path.isfile(settings.MODEL_FILE):
+def load_classifier(*, raise_exception: bool = False) -> Optional["DocumentClassifier"]:
+    if not settings.MODEL_FILE.is_file():
         logger.debug(
             "Document classification model does not exist (yet), not "
             "performing automatic matching.",
@@ -50,22 +51,30 @@ def load_classifier() -> Optional["DocumentClassifier"]:
 
     except IncompatibleClassifierVersionError as e:
         logger.info(f"Classifier version incompatible: {e.message}, will re-train")
-        os.unlink(settings.MODEL_FILE)
+        Path(settings.MODEL_FILE).unlink()
         classifier = None
-    except ClassifierModelCorruptError:
+        if raise_exception:
+            raise e
+    except ClassifierModelCorruptError as e:
         # there's something wrong with the model file.
         logger.exception(
             "Unrecoverable error while loading document "
             "classification model, deleting model file.",
         )
-        os.unlink(settings.MODEL_FILE)
+        Path(settings.MODEL_FILE).unlink()
         classifier = None
-    except OSError:
+        if raise_exception:
+            raise e
+    except OSError as e:
         logger.exception("IO error while loading document classification model")
         classifier = None
-    except Exception:  # pragma: no cover
+        if raise_exception:
+            raise e
+    except Exception as e:  # pragma: no cover
         logger.exception("Unknown error while loading document classification model")
         classifier = None
+        if raise_exception:
+            raise e
 
     return classifier
 
@@ -76,7 +85,7 @@ class DocumentClassifier:
     # v9 - Changed from hashing to time/ids for re-train check
     FORMAT_VERSION = 9
 
-    def __init__(self):
+    def __init__(self) -> None:
         # last time a document changed and therefore training might be required
         self.last_doc_change_time: datetime | None = None
         # Hash of primary keys of AUTO matching values last used in training
@@ -95,7 +104,7 @@ class DocumentClassifier:
     def load(self) -> None:
         # Catch warnings for processing
         with warnings.catch_warnings(record=True) as w:
-            with open(settings.MODEL_FILE, "rb") as f:
+            with Path(settings.MODEL_FILE).open("rb") as f:
                 schema_version = pickle.load(f)
 
                 if schema_version != self.FORMAT_VERSION:
@@ -132,11 +141,11 @@ class DocumentClassifier:
                 ):
                     raise IncompatibleClassifierVersionError("sklearn version update")
 
-    def save(self):
+    def save(self) -> None:
         target_file: Path = settings.MODEL_FILE
-        target_file_temp = target_file.with_suffix(".pickle.part")
+        target_file_temp: Path = target_file.with_suffix(".pickle.part")
 
-        with open(target_file_temp, "wb") as f:
+        with target_file_temp.open("wb") as f:
             pickle.dump(self.FORMAT_VERSION, f)
 
             pickle.dump(self.last_doc_change_time, f)
@@ -153,7 +162,7 @@ class DocumentClassifier:
 
         target_file_temp.rename(target_file)
 
-    def train(self):
+    def train(self) -> bool:
         # Get non-inbox documents
         docs_queryset = (
             Document.objects.exclude(
@@ -190,7 +199,7 @@ class DocumentClassifier:
             hasher.update(y.to_bytes(4, "little", signed=True))
             labels_correspondent.append(y)
 
-            tags = sorted(
+            tags: list[int] = sorted(
                 tag.pk
                 for tag in doc.tags.filter(
                     matching_algorithm=MatchingModel.MATCH_AUTO,
@@ -236,9 +245,9 @@ class DocumentClassifier:
         # union with {-1} accounts for cases where all documents have
         # correspondents and types assigned, so -1 isn't part of labels_x, which
         # it usually is.
-        num_correspondents = len(set(labels_correspondent) | {-1}) - 1
-        num_document_types = len(set(labels_document_type) | {-1}) - 1
-        num_storage_paths = len(set(labels_storage_path) | {-1}) - 1
+        num_correspondents: int = len(set(labels_correspondent) | {-1}) - 1
+        num_document_types: int = len(set(labels_document_type) | {-1}) - 1
+        num_storage_paths: int = len(set(labels_storage_path) | {-1}) - 1
 
         logger.debug(
             f"{docs_queryset.count()} documents, {num_tags} tag(s), {num_correspondents} correspondent(s), "
@@ -266,7 +275,9 @@ class DocumentClassifier:
             min_df=0.01,
         )
 
-        data_vectorized = self.data_vectorizer.fit_transform(content_generator())
+        data_vectorized: ndarray = self.data_vectorizer.fit_transform(
+            content_generator(),
+        )
 
         # See the notes here:
         # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html
@@ -284,7 +295,7 @@ class DocumentClassifier:
                     label[0] if len(label) == 1 else -1 for label in labels_tags
                 ]
                 self.tags_binarizer = LabelBinarizer()
-                labels_tags_vectorized = self.tags_binarizer.fit_transform(
+                labels_tags_vectorized: ndarray = self.tags_binarizer.fit_transform(
                     labels_tags,
                 ).ravel()
             else:
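The save() hunk above keeps the existing write-then-rename pattern, only expressed through pathlib. A minimal sketch of that pattern in isolation (hypothetical object and target path, not the classifier's real payload):

    import pickle
    import tempfile
    from pathlib import Path

    def atomic_pickle_dump(obj: object, target: Path) -> None:
        # Write to a sibling ".part" file first, then rename it over the target,
        # so an interrupted write never leaves a truncated file at the final path.
        tmp = target.with_suffix(target.suffix + ".part")
        with tmp.open("wb") as f:
            pickle.dump(obj, f)
        tmp.rename(target)

    atomic_pickle_dump({"format_version": 9}, Path(tempfile.mkdtemp()) / "model.pickle")
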
diff --git a/src/documents/index.py b/src/documents/index.py
index eacd1f99b77738bddfb48ed5cf605e74448a377f..4c5afb5054b025d21e13f8e21182b175a6f60662 100644 (file)
@@ -1,11 +1,11 @@
 import logging
 import math
-import os
 from collections import Counter
 from contextlib import contextmanager
 from datetime import datetime
 from datetime import timezone
 from shutil import rmtree
+from typing import Literal
 
 from django.conf import settings
 from django.db.models import QuerySet
@@ -47,7 +47,7 @@ from documents.models import User
 logger = logging.getLogger("paperless.index")
 
 
-def get_schema():
+def get_schema() -> Schema:
     return Schema(
         id=NUMERIC(stored=True, unique=True),
         title=TEXT(sortable=True),
@@ -93,7 +93,7 @@ def open_index(recreate=False) -> FileIndex:
         logger.exception("Error while opening the index, recreating.")
 
     # create_in doesn't handle corrupted indexes very well, remove the directory entirely first
-    if os.path.isdir(settings.INDEX_DIR):
+    if settings.INDEX_DIR.is_dir():
         rmtree(settings.INDEX_DIR)
     settings.INDEX_DIR.mkdir(parents=True, exist_ok=True)
 
@@ -123,7 +123,7 @@ def open_index_searcher() -> Searcher:
         searcher.close()
 
 
-def update_document(writer: AsyncWriter, doc: Document):
+def update_document(writer: AsyncWriter, doc: Document) -> None:
     tags = ",".join([t.name for t in doc.tags.all()])
     tags_ids = ",".join([str(t.id) for t in doc.tags.all()])
     notes = ",".join([str(c.note) for c in Note.objects.filter(document=doc)])
@@ -133,7 +133,7 @@ def update_document(writer: AsyncWriter, doc: Document):
     custom_fields_ids = ",".join(
         [str(f.field.id) for f in CustomFieldInstance.objects.filter(document=doc)],
     )
-    asn = doc.archive_serial_number
+    asn: int | None = doc.archive_serial_number
     if asn is not None and (
         asn < Document.ARCHIVE_SERIAL_NUMBER_MIN
         or asn > Document.ARCHIVE_SERIAL_NUMBER_MAX
@@ -149,7 +149,7 @@ def update_document(writer: AsyncWriter, doc: Document):
         doc,
         only_with_perms_in=["view_document"],
     )
-    viewer_ids = ",".join([str(u.id) for u in users_with_perms])
+    viewer_ids: str = ",".join([str(u.id) for u in users_with_perms])
     writer.update_document(
         id=doc.pk,
         title=doc.title,
@@ -187,20 +187,20 @@ def update_document(writer: AsyncWriter, doc: Document):
     )
 
 
-def remove_document(writer: AsyncWriter, doc: Document):
+def remove_document(writer: AsyncWriter, doc: Document) -> None:
     remove_document_by_id(writer, doc.pk)
 
 
-def remove_document_by_id(writer: AsyncWriter, doc_id):
+def remove_document_by_id(writer: AsyncWriter, doc_id) -> None:
     writer.delete_by_term("id", doc_id)
 
 
-def add_or_update_document(document: Document):
+def add_or_update_document(document: Document) -> None:
     with open_index_writer() as writer:
         update_document(writer, document)
 
 
-def remove_document_from_index(document: Document):
+def remove_document_from_index(document: Document) -> None:
     with open_index_writer() as writer:
         remove_document(writer, document)
 
@@ -218,11 +218,11 @@ class MappedDocIdSet(DocIdSet):
         self.document_ids = BitSet(document_ids, size=max_id)
         self.ixreader = ixreader
 
-    def __contains__(self, docnum):
+    def __contains__(self, docnum) -> bool:
         document_id = self.ixreader.stored_fields(docnum)["id"]
         return document_id in self.document_ids
 
-    def __bool__(self):
+    def __bool__(self) -> Literal[True]:
         # searcher.search ignores a filter if it's "falsy".
         # We use this hack so this DocIdSet, when used as a filter, is never ignored.
         return True
@@ -232,13 +232,13 @@ class DelayedQuery:
     def _get_query(self):
         raise NotImplementedError  # pragma: no cover
 
-    def _get_query_sortedby(self):
+    def _get_query_sortedby(self) -> tuple[None, Literal[False]] | tuple[str, bool]:
         if "ordering" not in self.query_params:
             return None, False
 
         field: str = self.query_params["ordering"]
 
-        sort_fields_map = {
+        sort_fields_map: dict[str, str] = {
             "created": "created",
             "modified": "modified",
             "added": "added",
@@ -268,7 +268,7 @@ class DelayedQuery:
         query_params,
         page_size,
         filter_queryset: QuerySet,
-    ):
+    ) -> None:
         self.searcher = searcher
         self.query_params = query_params
         self.page_size = page_size
@@ -276,7 +276,7 @@ class DelayedQuery:
         self.first_score = None
         self.filter_queryset = filter_queryset
 
-    def __len__(self):
+    def __len__(self) -> int:
         page = self[0:1]
         return len(page)
 
@@ -334,7 +334,7 @@ class LocalDateParser(English):
 
 
 class DelayedFullTextQuery(DelayedQuery):
-    def _get_query(self):
+    def _get_query(self) -> tuple:
         q_str = self.query_params["query"]
         qp = MultifieldParser(
             [
@@ -364,7 +364,7 @@ class DelayedFullTextQuery(DelayedQuery):
 
 
 class DelayedMoreLikeThisQuery(DelayedQuery):
-    def _get_query(self):
+    def _get_query(self) -> tuple:
         more_like_doc_id = int(self.query_params["more_like_id"])
         content = Document.objects.get(id=more_like_doc_id).content
 
@@ -379,7 +379,7 @@ class DelayedMoreLikeThisQuery(DelayedQuery):
         q = query.Or(
             [query.Term("content", word, boost=weight) for word, weight in kts],
         )
-        mask = {docnum}
+        mask: set = {docnum}
 
         return q, mask
 
@@ -389,7 +389,7 @@ def autocomplete(
     term: str,
     limit: int = 10,
     user: User | None = None,
-):
+) -> list:
     """
     Mimics whoosh.reading.IndexReader.most_distinctive_terms with permissions
     and without scoring
@@ -402,7 +402,7 @@ def autocomplete(
         # content field query instead and return bogus, not text data
         qp.remove_plugin_class(FieldsPlugin)
         q = qp.parse(f"{term.lower()}*")
-        user_criterias = get_permissions_criterias(user)
+        user_criterias: list = get_permissions_criterias(user)
 
         results = s.search(
             q,
@@ -417,14 +417,14 @@ def autocomplete(
                     termCounts[match] += 1
             terms = [t for t, _ in termCounts.most_common(limit)]
 
-        term_encoded = term.encode("UTF-8")
+        term_encoded: bytes = term.encode("UTF-8")
         if term_encoded in terms:
             terms.insert(0, terms.pop(terms.index(term_encoded)))
 
     return terms
 
 
-def get_permissions_criterias(user: User | None = None):
+def get_permissions_criterias(user: User | None = None) -> list:
     user_criterias = [query.Term("has_owner", False)]
     if user is not None:
         if user.is_superuser:  # superusers see all docs
diff --git a/src/documents/management/commands/decrypt_documents.py b/src/documents/management/commands/decrypt_documents.py
index c53147d07ac64dfb964da220e30f26abc3401ecf..aa8e7d18427508c490311c0082c6fe1256fb57b0 100644 (file)
@@ -1,4 +1,4 @@
-import os
+from pathlib import Path
 
 from django.conf import settings
 from django.core.management.base import BaseCommand
@@ -14,7 +14,7 @@ class Command(BaseCommand):
         "state to an unencrypted one (or vice-versa)"
     )
 
-    def add_arguments(self, parser):
+    def add_arguments(self, parser) -> None:
         parser.add_argument(
             "--passphrase",
             help=(
@@ -23,7 +23,7 @@ class Command(BaseCommand):
             ),
         )
 
-    def handle(self, *args, **options):
+    def handle(self, *args, **options) -> None:
         try:
             self.stdout.write(
                 self.style.WARNING(
@@ -52,7 +52,7 @@ class Command(BaseCommand):
 
         self.__gpg_to_unencrypted(passphrase)
 
-    def __gpg_to_unencrypted(self, passphrase: str):
+    def __gpg_to_unencrypted(self, passphrase: str) -> None:
         encrypted_files = Document.objects.filter(
             storage_type=Document.STORAGE_TYPE_GPG,
         )
@@ -69,7 +69,7 @@ class Command(BaseCommand):
 
             document.storage_type = Document.STORAGE_TYPE_UNENCRYPTED
 
-            ext = os.path.splitext(document.filename)[1]
+            ext: str = Path(document.filename).suffix
 
             if not ext == ".gpg":
                 raise CommandError(
@@ -77,12 +77,12 @@ class Command(BaseCommand):
                     f"end with .gpg",
                 )
 
-            document.filename = os.path.splitext(document.filename)[0]
+            document.filename = Path(document.filename).stem
 
-            with open(document.source_path, "wb") as f:
+            with document.source_path.open("wb") as f:
                 f.write(raw_document)
 
-            with open(document.thumbnail_path, "wb") as f:
+            with document.thumbnail_path.open("wb") as f:
                 f.write(raw_thumb)
 
             Document.objects.filter(id=document.id).update(
@@ -91,4 +91,4 @@ class Command(BaseCommand):
             )
 
             for path in old_paths:
-                os.unlink(path)
+                path.unlink()
diff --git a/src/documents/management/commands/document_importer.py b/src/documents/management/commands/document_importer.py
index f56159c811c4169486de148d3b5150e8a1865157..9e3af47e772af020951d1fc70fa476a2f6bf11c1 100644 (file)
@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+from collections.abc import Generator
 from contextlib import contextmanager
 from pathlib import Path
 
@@ -44,7 +45,7 @@ if settings.AUDIT_LOG_ENABLED:
 
 
 @contextmanager
-def disable_signal(sig, receiver, sender):
+def disable_signal(sig, receiver, sender) -> Generator:
     try:
         sig.disconnect(receiver=receiver, sender=sender)
         yield
@@ -58,7 +59,7 @@ class Command(CryptMixin, BaseCommand):
         "documents it refers to."
     )
 
-    def add_arguments(self, parser):
+    def add_arguments(self, parser) -> None:
         parser.add_argument("source")
 
         parser.add_argument(
@@ -90,7 +91,7 @@ class Command(CryptMixin, BaseCommand):
         - Are there existing users or documents in the database?
         """
 
-        def pre_check_maybe_not_empty():
+        def pre_check_maybe_not_empty() -> None:
             # Skip this check if operating only on the database
             # We can expect data to exist in that case
             if not self.data_only:
@@ -122,7 +123,7 @@ class Command(CryptMixin, BaseCommand):
                     ),
                 )
 
-        def pre_check_manifest_exists():
+        def pre_check_manifest_exists() -> None:
             if not (self.source / "manifest.json").exists():
                 raise CommandError(
                     "That directory doesn't appear to contain a manifest.json file.",
@@ -141,7 +142,7 @@ class Command(CryptMixin, BaseCommand):
         """
         Loads manifest data from the various JSON files for parsing and loading the database
         """
-        main_manifest_path = self.source / "manifest.json"
+        main_manifest_path: Path = self.source / "manifest.json"
 
         with main_manifest_path.open() as infile:
             self.manifest = json.load(infile)
@@ -158,8 +159,8 @@ class Command(CryptMixin, BaseCommand):
 
         Must account for the old style of export as well, with just version.json
         """
-        version_path = self.source / "version.json"
-        metadata_path = self.source / "metadata.json"
+        version_path: Path = self.source / "version.json"
+        metadata_path: Path = self.source / "metadata.json"
         if not version_path.exists() and not metadata_path.exists():
             self.stdout.write(
                 self.style.NOTICE("No version.json or metadata.json file located"),
@@ -221,7 +222,7 @@ class Command(CryptMixin, BaseCommand):
                 )
                 raise e
 
-    def handle(self, *args, **options):
+    def handle(self, *args, **options) -> None:
         logging.getLogger().handlers[0].level = logging.ERROR
 
         self.source = Path(options["source"]).resolve()
@@ -290,13 +291,13 @@ class Command(CryptMixin, BaseCommand):
             no_progress_bar=self.no_progress_bar,
         )
 
-    def check_manifest_validity(self):
+    def check_manifest_validity(self) -> None:
         """
         Attempts to verify the manifest is valid.  Namely checking the files
         referred to exist and the files can be read from
         """
 
-        def check_document_validity(document_record: dict):
+        def check_document_validity(document_record: dict) -> None:
             if EXPORTER_FILE_NAME not in document_record:
                 raise CommandError(
                     "The manifest file contains a record which does not "
@@ -341,7 +342,7 @@ class Command(CryptMixin, BaseCommand):
             if not self.data_only and record["model"] == "documents.document":
                 check_document_validity(record)
 
-    def _import_files_from_manifest(self):
+    def _import_files_from_manifest(self) -> None:
         settings.ORIGINALS_DIR.mkdir(parents=True, exist_ok=True)
         settings.THUMBNAIL_DIR.mkdir(parents=True, exist_ok=True)
         settings.ARCHIVE_DIR.mkdir(parents=True, exist_ok=True)
@@ -356,24 +357,24 @@ class Command(CryptMixin, BaseCommand):
             document = Document.objects.get(pk=record["pk"])
 
             doc_file = record[EXPORTER_FILE_NAME]
-            document_path = os.path.join(self.source, doc_file)
+            document_path = self.source / doc_file
 
             if EXPORTER_THUMBNAIL_NAME in record:
                 thumb_file = record[EXPORTER_THUMBNAIL_NAME]
-                thumbnail_path = Path(os.path.join(self.source, thumb_file)).resolve()
+                thumbnail_path = (self.source / thumb_file).resolve()
             else:
                 thumbnail_path = None
 
             if EXPORTER_ARCHIVE_NAME in record:
                 archive_file = record[EXPORTER_ARCHIVE_NAME]
-                archive_path = os.path.join(self.source, archive_file)
+                archive_path = self.source / archive_file
             else:
                 archive_path = None
 
             document.storage_type = Document.STORAGE_TYPE_UNENCRYPTED
 
             with FileLock(settings.MEDIA_LOCK):
-                if os.path.isfile(document.source_path):
+                if Path(document.source_path).is_file():
                     raise FileExistsError(document.source_path)
 
                 create_source_path_directory(document.source_path)
@@ -418,8 +419,8 @@ class Command(CryptMixin, BaseCommand):
             had_at_least_one_record = False
 
             for crypt_config in self.CRYPT_FIELDS:
-                importer_model = crypt_config["model_name"]
-                crypt_fields = crypt_config["fields"]
+                importer_model: str = crypt_config["model_name"]
+                crypt_fields: str = crypt_config["fields"]
                 for record in filter(
                     lambda x: x["model"] == importer_model,
                     self.manifest,
diff --git a/src/documents/migrations/1037_webp_encrypted_thumbnail_conversion.py b/src/documents/migrations/1037_webp_encrypted_thumbnail_conversion.py
index 6b4f06ec71b23ad9b7aad620c488408732bff413..a706de412c1ad59530e8a768273b95105d644886 100644 (file)
@@ -15,7 +15,7 @@ from documents.parsers import run_convert
 logger = logging.getLogger("paperless.migrations")
 
 
-def _do_convert(work_package):
+def _do_convert(work_package) -> None:
     (
         existing_encrypted_thumbnail,
         converted_encrypted_thumbnail,
@@ -30,13 +30,13 @@ def _do_convert(work_package):
         # Decrypt png
         decrypted_thumbnail = existing_encrypted_thumbnail.with_suffix("").resolve()
 
-        with open(existing_encrypted_thumbnail, "rb") as existing_encrypted_file:
+        with existing_encrypted_thumbnail.open("rb") as existing_encrypted_file:
             raw_thumb = gpg.decrypt_file(
                 existing_encrypted_file,
                 passphrase=passphrase,
                 always_trust=True,
             ).data
-            with open(decrypted_thumbnail, "wb") as decrypted_file:
+            with Path(decrypted_thumbnail).open("wb") as decrypted_file:
                 decrypted_file.write(raw_thumb)
 
         converted_decrypted_thumbnail = Path(
@@ -62,7 +62,7 @@ def _do_convert(work_package):
         )
 
         # Encrypt webp
-        with open(converted_decrypted_thumbnail, "rb") as converted_decrypted_file:
+        with Path(converted_decrypted_thumbnail).open("rb") as converted_decrypted_file:
             encrypted = gpg.encrypt_file(
                 fileobj_or_path=converted_decrypted_file,
                 recipients=None,
@@ -71,7 +71,9 @@ def _do_convert(work_package):
                 always_trust=True,
             ).data
 
-            with open(converted_encrypted_thumbnail, "wb") as converted_encrypted_file:
+            with Path(converted_encrypted_thumbnail).open(
+                "wb",
+            ) as converted_encrypted_file:
                 converted_encrypted_file.write(encrypted)
 
         # Copy newly created thumbnail to thumbnail directory
@@ -95,8 +97,8 @@ def _do_convert(work_package):
         logger.error(f"Error converting thumbnail (existing file unchanged): {e}")
 
 
-def _convert_encrypted_thumbnails_to_webp(apps, schema_editor):
-    start = time.time()
+def _convert_encrypted_thumbnails_to_webp(apps, schema_editor) -> None:
+    start: float = time.time()
 
     with tempfile.TemporaryDirectory() as tempdir:
         work_packages = []
@@ -111,15 +113,15 @@ def _convert_encrypted_thumbnails_to_webp(apps, schema_editor):
                 )
 
             for file in Path(settings.THUMBNAIL_DIR).glob("*.png.gpg"):
-                existing_thumbnail = file.resolve()
+                existing_thumbnail: Path = file.resolve()
 
                 # Change the existing filename suffix from png to webp
-                converted_thumbnail_name = Path(
+                converted_thumbnail_name: str = Path(
                     str(existing_thumbnail).replace(".png.gpg", ".webp.gpg"),
                 ).name
 
                 # Create the expected output filename in the tempdir
-                converted_thumbnail = (
+                converted_thumbnail: Path = (
                     Path(tempdir) / Path(converted_thumbnail_name)
                 ).resolve()
 
@@ -143,8 +145,8 @@ def _convert_encrypted_thumbnails_to_webp(apps, schema_editor):
                 ) as pool:
                     pool.map(_do_convert, work_packages)
 
-                    end = time.time()
-                    duration = end - start
+                    end: float = time.time()
+                    duration: float = end - start
 
                 logger.info(f"Conversion completed in {duration:.3f}s")
 
diff --git a/src/documents/tests/test_api_status.py b/src/documents/tests/test_api_status.py
index 5e1c5a0bda7839709a4e769be9c73a1f6444cb8e..89bc5ef8c5c75b3b5588cad10e55db50580d7504 100644 (file)
@@ -173,7 +173,7 @@ class TestSystemStatus(APITestCase):
         self.assertEqual(response.data["tasks"]["index_status"], "OK")
         self.assertIsNotNone(response.data["tasks"]["index_last_modified"])
 
-    @override_settings(INDEX_DIR="/tmp/index/")
+    @override_settings(INDEX_DIR=Path("/tmp/index/"))
     @mock.patch("documents.index.open_index", autospec=True)
     def test_system_status_index_error(self, mock_open_index):
         """
@@ -193,7 +193,7 @@ class TestSystemStatus(APITestCase):
         self.assertEqual(response.data["tasks"]["index_status"], "ERROR")
         self.assertIsNotNone(response.data["tasks"]["index_error"])
 
-    @override_settings(DATA_DIR="/tmp/does_not_exist/data/")
+    @override_settings(DATA_DIR=Path("/tmp/does_not_exist/data/"))
     def test_system_status_classifier_ok(self):
         """
         GIVEN:
@@ -222,7 +222,7 @@ class TestSystemStatus(APITestCase):
         THEN:
             - The response contains an WARNING classifier status
         """
-        with override_settings(MODEL_FILE="does_not_exist"):
+        with override_settings(MODEL_FILE=Path("does_not_exist")):
             Document.objects.create(
                 title="Test Document",
             )
@@ -233,7 +233,11 @@ class TestSystemStatus(APITestCase):
             self.assertEqual(response.data["tasks"]["classifier_status"], "WARNING")
             self.assertIsNotNone(response.data["tasks"]["classifier_error"])
 
-    def test_system_status_classifier_error(self):
+    @mock.patch(
+        "documents.classifier.load_classifier",
+        side_effect=ClassifierModelCorruptError(),
+    )
+    def test_system_status_classifier_error(self, mock_load_classifier):
         """
         GIVEN:
             - The classifier does exist but is corrupt
@@ -248,25 +252,23 @@ class TestSystemStatus(APITestCase):
                 dir="/tmp",
                 delete=False,
             ) as does_exist,
-            override_settings(MODEL_FILE=does_exist),
+            override_settings(MODEL_FILE=Path(does_exist.name)),
         ):
-            with mock.patch("documents.classifier.load_classifier") as mock_load:
-                mock_load.side_effect = ClassifierModelCorruptError()
-                Document.objects.create(
-                    title="Test Document",
-                )
-                Tag.objects.create(
-                    name="Test Tag",
-                    matching_algorithm=Tag.MATCH_AUTO,
-                )
-                self.client.force_login(self.user)
-                response = self.client.get(self.ENDPOINT)
-                self.assertEqual(response.status_code, status.HTTP_200_OK)
-                self.assertEqual(
-                    response.data["tasks"]["classifier_status"],
-                    "ERROR",
-                )
-                self.assertIsNotNone(response.data["tasks"]["classifier_error"])
+            Document.objects.create(
+                title="Test Document",
+            )
+            Tag.objects.create(
+                name="Test Tag",
+                matching_algorithm=Tag.MATCH_AUTO,
+            )
+            self.client.force_login(self.user)
+            response = self.client.get(self.ENDPOINT)
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+            self.assertEqual(
+                response.data["tasks"]["classifier_status"],
+                "ERROR",
+            )
+            self.assertIsNotNone(response.data["tasks"]["classifier_error"])
 
     def test_system_status_classifier_ok_no_objects(self):
         """
@@ -278,7 +280,7 @@ class TestSystemStatus(APITestCase):
         THEN:
             - The response contains an OK classifier status
         """
-        with override_settings(MODEL_FILE="does_not_exist"):
+        with override_settings(MODEL_FILE=Path("does_not_exist")):
             self.client.force_login(self.user)
             response = self.client.get(self.ENDPOINT)
             self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/src/documents/tests/test_classifier.py b/src/documents/tests/test_classifier.py
index cb1c5c8a3df95ae35d6e7fc7ba04e9bf6521d637..f90a880505f82ddb0515ac84b5ea7e2e856d9f16 100644 (file)
@@ -650,7 +650,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
         Path(settings.MODEL_FILE).touch()
         self.assertTrue(os.path.exists(settings.MODEL_FILE))
 
-        load.side_effect = IncompatibleClassifierVersionError("Dummey Error")
+        load.side_effect = IncompatibleClassifierVersionError("Dummy Error")
         self.assertIsNone(load_classifier())
         self.assertFalse(os.path.exists(settings.MODEL_FILE))
 
@@ -673,3 +673,25 @@ class TestClassifier(DirectoriesMixin, TestCase):
         ):
             classifier = load_classifier()
             self.assertIsNone(classifier)
+
+    @mock.patch("documents.classifier.DocumentClassifier.load")
+    def test_load_classifier_raise_exception(self, mock_load):
+        Path(settings.MODEL_FILE).touch()
+        mock_load.side_effect = IncompatibleClassifierVersionError("Dummy Error")
+        with self.assertRaises(IncompatibleClassifierVersionError):
+            load_classifier(raise_exception=True)
+
+        Path(settings.MODEL_FILE).touch()
+        mock_load.side_effect = ClassifierModelCorruptError()
+        with self.assertRaises(ClassifierModelCorruptError):
+            load_classifier(raise_exception=True)
+
+        Path(settings.MODEL_FILE).touch()
+        mock_load.side_effect = OSError()
+        with self.assertRaises(OSError):
+            load_classifier(raise_exception=True)
+
+        Path(settings.MODEL_FILE).touch()
+        mock_load.side_effect = Exception()
+        with self.assertRaises(Exception):
+            load_classifier(raise_exception=True)
diff --git a/src/documents/tests/test_management.py b/src/documents/tests/test_management.py
index 2f21627a71f656f15f54dc6753bb906bf7e756b1..2726fd02e2e80c8c0aad66b7689bdc574a949527 100644 (file)
@@ -108,18 +108,18 @@ class TestArchiver(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
 class TestDecryptDocuments(FileSystemAssertsMixin, TestCase):
     @override_settings(
-        ORIGINALS_DIR=os.path.join(os.path.dirname(__file__), "samples", "originals"),
-        THUMBNAIL_DIR=os.path.join(os.path.dirname(__file__), "samples", "thumb"),
+        ORIGINALS_DIR=(Path(__file__).parent / "samples" / "originals"),
+        THUMBNAIL_DIR=(Path(__file__).parent / "samples" / "thumb"),
         PASSPHRASE="test",
         FILENAME_FORMAT=None,
     )
     @mock.patch("documents.management.commands.decrypt_documents.input")
     def test_decrypt(self, m):
         media_dir = tempfile.mkdtemp()
-        originals_dir = os.path.join(media_dir, "documents", "originals")
-        thumb_dir = os.path.join(media_dir, "documents", "thumbnails")
-        os.makedirs(originals_dir, exist_ok=True)
-        os.makedirs(thumb_dir, exist_ok=True)
+        originals_dir = Path(media_dir) / "documents" / "originals"
+        thumb_dir = Path(media_dir) / "documents" / "thumbnails"
+        originals_dir.mkdir(parents=True, exist_ok=True)
+        thumb_dir.mkdir(parents=True, exist_ok=True)
 
         override_settings(
             ORIGINALS_DIR=originals_dir,
@@ -143,7 +143,7 @@ class TestDecryptDocuments(FileSystemAssertsMixin, TestCase):
                 "originals",
                 "0000004.pdf.gpg",
             ),
-            os.path.join(originals_dir, "0000004.pdf.gpg"),
+            originals_dir / "0000004.pdf.gpg",
         )
         shutil.copy(
             os.path.join(
@@ -153,7 +153,7 @@ class TestDecryptDocuments(FileSystemAssertsMixin, TestCase):
                 "thumbnails",
                 "0000004.webp.gpg",
             ),
-            os.path.join(thumb_dir, f"{doc.id:07}.webp.gpg"),
+            thumb_dir / f"{doc.id:07}.webp.gpg",
         )
 
         call_command("decrypt_documents")
diff --git a/src/documents/views.py b/src/documents/views.py
index fe149798be918d8df30a831293ae459c331e929e..be2343b90b05a04d785ecab8cfd6e2078c4036f5 100644 (file)
@@ -2139,7 +2139,7 @@ class SystemStatusView(PassUserMixin):
         classifier_error = None
         classifier_status = None
         try:
-            classifier = load_classifier()
+            classifier = load_classifier(raise_exception=True)
             if classifier is None:
                 # Make sure classifier should exist
                 docs_queryset = Document.objects.exclude(
@@ -2159,7 +2159,7 @@ class SystemStatusView(PassUserMixin):
                             matching_algorithm=Tag.MATCH_AUTO,
                         ).exists()
                     )
-                    and not os.path.isfile(settings.MODEL_FILE)
+                    and not settings.MODEL_FILE.exists()
                 ):
                     # if classifier file doesn't exist just classify as a warning
                     classifier_error = "Classifier file does not exist (yet). Re-training may be pending."