[lint.per-file-ignores]
".github/scripts/*.py" = ["E501", "INP001", "SIM117"]
"docker/wait-for-redis.py" = ["INP001", "T201"]
-"src/documents/barcodes.py" = ["PTH"] # TODO Enable & remove
-"src/documents/classifier.py" = ["PTH"] # TODO Enable & remove
"src/documents/consumer.py" = ["PTH"] # TODO Enable & remove
"src/documents/file_handling.py" = ["PTH"] # TODO Enable & remove
-"src/documents/index.py" = ["PTH"] # TODO Enable & remove
-"src/documents/management/commands/decrypt_documents.py" = ["PTH"] # TODO Enable & remove
"src/documents/management/commands/document_consumer.py" = ["PTH"] # TODO Enable & remove
"src/documents/management/commands/document_exporter.py" = ["PTH"] # TODO Enable & remove
-"src/documents/management/commands/document_importer.py" = ["PTH"] # TODO Enable & remove
"src/documents/migrations/0012_auto_20160305_0040.py" = ["PTH"] # TODO Enable & remove
"src/documents/migrations/0014_document_checksum.py" = ["PTH"] # TODO Enable & remove
"src/documents/migrations/1003_mime_types.py" = ["PTH"] # TODO Enable & remove
"src/documents/migrations/1012_fix_archive_files.py" = ["PTH"] # TODO Enable & remove
-"src/documents/migrations/1037_webp_encrypted_thumbnail_conversion.py" = ["PTH"] # TODO Enable & remove
"src/documents/models.py" = ["SIM115", "PTH"] # TODO PTH Enable & remove
"src/documents/parsers.py" = ["PTH"] # TODO Enable & remove
"src/documents/signals/handlers.py" = ["PTH"] # TODO Enable & remove
import tempfile
from dataclasses import dataclass
from pathlib import Path
+from typing import TYPE_CHECKING
from django.conf import settings
from pdf2image import convert_from_path
from documents.utils import copy_file_with_basic_stats
from documents.utils import maybe_override_pixel_limit
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
logger = logging.getLogger("paperless.barcodes")
- Barcode support is enabled and the mime type is supported
"""
if settings.CONSUMER_BARCODE_TIFF_SUPPORT:
- supported_mimes = {"application/pdf", "image/tiff"}
+ supported_mimes: set[str] = {"application/pdf", "image/tiff"}
else:
supported_mimes = {"application/pdf"}
or settings.CONSUMER_ENABLE_TAG_BARCODE
) and self.input_doc.mime_type in supported_mimes
- def setup(self):
+ def setup(self) -> None:
self.temp_dir = tempfile.TemporaryDirectory(
dir=self.base_tmp_dir,
prefix="barcode",
)
- self.pdf_file = self.input_doc.original_file
+ self.pdf_file: Path = self.input_doc.original_file
self._tiff_conversion_done = False
self.barcodes: list[Barcode] = []
- def run(self) -> str | None:
+ def run(self) -> None:
# Some operations may use PIL, override pixel setting if needed
maybe_override_pixel_limit()
def cleanup(self) -> None:
self.temp_dir.cleanup()
- def convert_from_tiff_to_pdf(self):
+ def convert_from_tiff_to_pdf(self) -> None:
"""
May convert a TIFF image into a PDF, if the input is a TIFF and
the TIFF has not been made into a PDF
# Choose the library for reading
if settings.CONSUMER_BARCODE_SCANNER == "PYZBAR":
- reader = self.read_barcodes_pyzbar
+ reader: Callable[[Image.Image], list[str]] = self.read_barcodes_pyzbar
logger.debug("Scanning for barcodes using PYZBAR")
else:
reader = self.read_barcodes_zxing
logger.debug(f"PDF has {num_of_pages} pages")
# Get limit from configuration
- barcode_max_pages = (
+ barcode_max_pages: int = (
num_of_pages
if settings.CONSUMER_BARCODE_MAX_PAGES == 0
else settings.CONSUMER_BARCODE_MAX_PAGES
self.detect()
# get the first barcode that starts with CONSUMER_ASN_BARCODE_PREFIX
- asn_text = next(
+ asn_text: str | None = next(
(x.value for x in self.barcodes if x.is_asn),
None,
)
return asn
@property
- def tags(self) -> list[int] | None:
+ def tags(self) -> list[int]:
"""
Search the parsed barcodes for any tags.
Returns the detected tag ids (or empty list)
"""
- tags = []
+ tags: list[int] = []
# Ensure the barcodes have been read
self.detect()
for x in self.barcodes:
- tag_texts = x.value
+ tag_texts: str = x.value
for raw in tag_texts.split(","):
try:
- tag = None
+ tag_str: str | None = None
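+ # Apply the first CONSUMER_TAG_BARCODE_MAPPING pattern that matches this barcode value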
for regex in settings.CONSUMER_TAG_BARCODE_MAPPING:
if re.match(regex, raw, flags=re.IGNORECASE):
sub = settings.CONSUMER_TAG_BARCODE_MAPPING[regex]
- tag = (
+ tag_str = (
re.sub(regex, sub, raw, flags=re.IGNORECASE)
if sub
else raw
)
break
- if tag:
+ if tag_str:
tag, _ = Tag.objects.get_or_create(
- name__iexact=tag,
- defaults={"name": tag},
+ name__iexact=tag_str,
+ defaults={"name": tag_str},
)
logger.debug(
"""
document_paths = []
- fname = self.input_doc.original_file.stem
+ fname: str = self.input_doc.original_file.stem
with Pdf.open(self.pdf_file) as input_pdf:
# Start with an empty document
current_document: list[Page] = []
logger.debug(f"Starting new document at idx {idx}")
current_document = []
documents.append(current_document)
- keep_page = pages_to_split_on[idx]
+ keep_page: bool = pages_to_split_on[idx]
if keep_page:
# Keep the page
# (new document is started by asn barcode)
logger.debug(f"pdf no:{doc_idx} has {len(dst.pages)} pages")
savepath = Path(self.temp_dir.name) / output_filename
- with open(savepath, "wb") as out:
+ with savepath.open("wb") as out:
dst.save(out)
copy_basic_file_stats(self.input_doc.original_file, savepath)
import logging
-import os
import pickle
import re
import warnings
from collections.abc import Iterator
from hashlib import sha256
+from pathlib import Path
from typing import TYPE_CHECKING
from typing import Optional
if TYPE_CHECKING:
from datetime import datetime
- from pathlib import Path
+
+ from numpy import ndarray
+ from scipy.sparse import csr_matrix
from django.conf import settings
from django.core.cache import cache
class IncompatibleClassifierVersionError(Exception):
def __init__(self, message: str, *args: object) -> None:
- self.message = message
+ self.message: str = message
super().__init__(*args)
pass
-def load_classifier() -> Optional["DocumentClassifier"]:
- if not os.path.isfile(settings.MODEL_FILE):
+def load_classifier(*, raise_exception: bool = False) -> Optional["DocumentClassifier"]:
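+ # With raise_exception=True, errors hit while loading the model are re-raised to
+ # the caller instead of being logged and swallowed (the default keeps the previous
+ # behaviour of returning None).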
+ if not settings.MODEL_FILE.is_file():
logger.debug(
"Document classification model does not exist (yet), not "
"performing automatic matching.",
except IncompatibleClassifierVersionError as e:
logger.info(f"Classifier version incompatible: {e.message}, will re-train")
- os.unlink(settings.MODEL_FILE)
+ Path(settings.MODEL_FILE).unlink()
classifier = None
- except ClassifierModelCorruptError:
+ if raise_exception:
+ raise e
+ except ClassifierModelCorruptError as e:
# there's something wrong with the model file.
logger.exception(
"Unrecoverable error while loading document "
"classification model, deleting model file.",
)
- os.unlink(settings.MODEL_FILE)
+ Path(settings.MODEL_FILE).unlink()
classifier = None
- except OSError:
+ if raise_exception:
+ raise e
+ except OSError as e:
logger.exception("IO error while loading document classification model")
classifier = None
- except Exception: # pragma: no cover
+ if raise_exception:
+ raise e
+ except Exception as e: # pragma: no cover
logger.exception("Unknown error while loading document classification model")
classifier = None
+ if raise_exception:
+ raise e
return classifier
# v9 - Changed from hashing to time/ids for re-train check
FORMAT_VERSION = 9
- def __init__(self):
+ def __init__(self) -> None:
# last time a document changed and therefore training might be required
self.last_doc_change_time: datetime | None = None
# Hash of primary keys of AUTO matching values last used in training
def load(self) -> None:
# Catch warnings for processing
with warnings.catch_warnings(record=True) as w:
- with open(settings.MODEL_FILE, "rb") as f:
+ with Path(settings.MODEL_FILE).open("rb") as f:
schema_version = pickle.load(f)
if schema_version != self.FORMAT_VERSION:
):
raise IncompatibleClassifierVersionError("sklearn version update")
- def save(self):
+ def save(self) -> None:
target_file: Path = settings.MODEL_FILE
- target_file_temp = target_file.with_suffix(".pickle.part")
+ target_file_temp: Path = target_file.with_suffix(".pickle.part")
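+ # Write to a temporary .pickle.part file first; it is renamed over the real
+ # model file below, so an interrupted save leaves the existing model intact.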
- with open(target_file_temp, "wb") as f:
+ with target_file_temp.open("wb") as f:
pickle.dump(self.FORMAT_VERSION, f)
pickle.dump(self.last_doc_change_time, f)
target_file_temp.rename(target_file)
- def train(self):
+ def train(self) -> bool:
# Get non-inbox documents
docs_queryset = (
Document.objects.exclude(
hasher.update(y.to_bytes(4, "little", signed=True))
labels_correspondent.append(y)
- tags = sorted(
+ tags: list[int] = sorted(
tag.pk
for tag in doc.tags.filter(
matching_algorithm=MatchingModel.MATCH_AUTO,
# union with {-1} accounts for cases where all documents have
# correspondents and types assigned, so -1 isn't part of labels_x, which
# it usually is.
- num_correspondents = len(set(labels_correspondent) | {-1}) - 1
- num_document_types = len(set(labels_document_type) | {-1}) - 1
- num_storage_paths = len(set(labels_storage_path) | {-1}) - 1
+ num_correspondents: int = len(set(labels_correspondent) | {-1}) - 1
+ num_document_types: int = len(set(labels_document_type) | {-1}) - 1
+ num_storage_paths: int = len(set(labels_storage_path) | {-1}) - 1
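+ # e.g. labels_correspondent == [3, 3, 7] gives {3, 7, -1}, so len() - 1 == 2 correspondents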
logger.debug(
f"{docs_queryset.count()} documents, {num_tags} tag(s), {num_correspondents} correspondent(s), "
min_df=0.01,
)
- data_vectorized = self.data_vectorizer.fit_transform(content_generator())
+ data_vectorized: csr_matrix = self.data_vectorizer.fit_transform(
+ content_generator(),
+ )
# See the notes here:
# https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html
label[0] if len(label) == 1 else -1 for label in labels_tags
]
self.tags_binarizer = LabelBinarizer()
- labels_tags_vectorized = self.tags_binarizer.fit_transform(
+ labels_tags_vectorized: ndarray = self.tags_binarizer.fit_transform(
labels_tags,
).ravel()
else:
import logging
import math
-import os
from collections import Counter
from contextlib import contextmanager
from datetime import datetime
from datetime import timezone
from shutil import rmtree
+from typing import Literal
from django.conf import settings
from django.db.models import QuerySet
logger = logging.getLogger("paperless.index")
-def get_schema():
+def get_schema() -> Schema:
return Schema(
id=NUMERIC(stored=True, unique=True),
title=TEXT(sortable=True),
logger.exception("Error while opening the index, recreating.")
# create_in doesn't handle corrupted indexes very well, remove the directory entirely first
- if os.path.isdir(settings.INDEX_DIR):
+ if settings.INDEX_DIR.is_dir():
rmtree(settings.INDEX_DIR)
settings.INDEX_DIR.mkdir(parents=True, exist_ok=True)
searcher.close()
-def update_document(writer: AsyncWriter, doc: Document):
+def update_document(writer: AsyncWriter, doc: Document) -> None:
tags = ",".join([t.name for t in doc.tags.all()])
tags_ids = ",".join([str(t.id) for t in doc.tags.all()])
notes = ",".join([str(c.note) for c in Note.objects.filter(document=doc)])
custom_fields_ids = ",".join(
[str(f.field.id) for f in CustomFieldInstance.objects.filter(document=doc)],
)
- asn = doc.archive_serial_number
+ asn: int | None = doc.archive_serial_number
if asn is not None and (
asn < Document.ARCHIVE_SERIAL_NUMBER_MIN
or asn > Document.ARCHIVE_SERIAL_NUMBER_MAX
doc,
only_with_perms_in=["view_document"],
)
- viewer_ids = ",".join([str(u.id) for u in users_with_perms])
+ viewer_ids: str = ",".join([str(u.id) for u in users_with_perms])
writer.update_document(
id=doc.pk,
title=doc.title,
)
-def remove_document(writer: AsyncWriter, doc: Document):
+def remove_document(writer: AsyncWriter, doc: Document) -> None:
remove_document_by_id(writer, doc.pk)
-def remove_document_by_id(writer: AsyncWriter, doc_id):
+def remove_document_by_id(writer: AsyncWriter, doc_id) -> None:
writer.delete_by_term("id", doc_id)
-def add_or_update_document(document: Document):
+def add_or_update_document(document: Document) -> None:
with open_index_writer() as writer:
update_document(writer, document)
-def remove_document_from_index(document: Document):
+def remove_document_from_index(document: Document) -> None:
with open_index_writer() as writer:
remove_document(writer, document)
self.document_ids = BitSet(document_ids, size=max_id)
self.ixreader = ixreader
- def __contains__(self, docnum):
+ def __contains__(self, docnum) -> bool:
document_id = self.ixreader.stored_fields(docnum)["id"]
return document_id in self.document_ids
- def __bool__(self):
+ def __bool__(self) -> Literal[True]:
# searcher.search ignores a filter if it's "falsy".
# We use this hack so this DocIdSet, when used as a filter, is never ignored.
return True
def _get_query(self):
raise NotImplementedError # pragma: no cover
- def _get_query_sortedby(self):
+ def _get_query_sortedby(self) -> tuple[None, Literal[False]] | tuple[str, bool]:
if "ordering" not in self.query_params:
return None, False
field: str = self.query_params["ordering"]
- sort_fields_map = {
+ sort_fields_map: dict[str, str] = {
"created": "created",
"modified": "modified",
"added": "added",
query_params,
page_size,
filter_queryset: QuerySet,
- ):
+ ) -> None:
self.searcher = searcher
self.query_params = query_params
self.page_size = page_size
self.first_score = None
self.filter_queryset = filter_queryset
- def __len__(self):
+ def __len__(self) -> int:
page = self[0:1]
return len(page)
class DelayedFullTextQuery(DelayedQuery):
- def _get_query(self):
+ def _get_query(self) -> tuple:
q_str = self.query_params["query"]
qp = MultifieldParser(
[
class DelayedMoreLikeThisQuery(DelayedQuery):
- def _get_query(self):
+ def _get_query(self) -> tuple:
more_like_doc_id = int(self.query_params["more_like_id"])
content = Document.objects.get(id=more_like_doc_id).content
q = query.Or(
[query.Term("content", word, boost=weight) for word, weight in kts],
)
- mask = {docnum}
+ mask: set = {docnum}
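+ # The mask excludes the source document itself from its own "more like this" results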
return q, mask
term: str,
limit: int = 10,
user: User | None = None,
-):
+) -> list:
"""
Mimics whoosh.reading.IndexReader.most_distinctive_terms with permissions
and without scoring
# content field query instead and return bogus, not text data
qp.remove_plugin_class(FieldsPlugin)
q = qp.parse(f"{term.lower()}*")
- user_criterias = get_permissions_criterias(user)
+ user_criterias: list = get_permissions_criterias(user)
results = s.search(
q,
termCounts[match] += 1
terms = [t for t, _ in termCounts.most_common(limit)]
- term_encoded = term.encode("UTF-8")
+ term_encoded: bytes = term.encode("UTF-8")
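+ # If the exact term was among the matched terms, move it to the front of the suggestions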
if term_encoded in terms:
terms.insert(0, terms.pop(terms.index(term_encoded)))
return terms
-def get_permissions_criterias(user: User | None = None):
+def get_permissions_criterias(user: User | None = None) -> list:
user_criterias = [query.Term("has_owner", False)]
if user is not None:
if user.is_superuser: # superusers see all docs
-import os
+from pathlib import Path
from django.conf import settings
from django.core.management.base import BaseCommand
"state to an unencrypted one (or vice-versa)"
)
- def add_arguments(self, parser):
+ def add_arguments(self, parser) -> None:
parser.add_argument(
"--passphrase",
help=(
),
)
- def handle(self, *args, **options):
+ def handle(self, *args, **options) -> None:
try:
self.stdout.write(
self.style.WARNING(
self.__gpg_to_unencrypted(passphrase)
- def __gpg_to_unencrypted(self, passphrase: str):
+ def __gpg_to_unencrypted(self, passphrase: str) -> None:
encrypted_files = Document.objects.filter(
storage_type=Document.STORAGE_TYPE_GPG,
)
document.storage_type = Document.STORAGE_TYPE_UNENCRYPTED
- ext = os.path.splitext(document.filename)[1]
+ ext: str = Path(document.filename).suffix
if not ext == ".gpg":
raise CommandError(
f"end with .gpg",
)
- document.filename = os.path.splitext(document.filename)[0]
+ # with_suffix("") keeps any directory component of the filename; .stem would drop it
+ document.filename = str(Path(document.filename).with_suffix(""))
- with open(document.source_path, "wb") as f:
+ with document.source_path.open("wb") as f:
f.write(raw_document)
- with open(document.thumbnail_path, "wb") as f:
+ with document.thumbnail_path.open("wb") as f:
f.write(raw_thumb)
Document.objects.filter(id=document.id).update(
)
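+ # Remove the old encrypted files now that the decrypted versions have been written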
for path in old_paths:
- os.unlink(path)
+ path.unlink()
import json
import logging
import os
+from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
@contextmanager
-def disable_signal(sig, receiver, sender):
+def disable_signal(sig, receiver, sender) -> Generator:
try:
sig.disconnect(receiver=receiver, sender=sender)
yield
"documents it refers to."
)
- def add_arguments(self, parser):
+ def add_arguments(self, parser) -> None:
parser.add_argument("source")
parser.add_argument(
- Are there existing users or documents in the database?
"""
- def pre_check_maybe_not_empty():
+ def pre_check_maybe_not_empty() -> None:
# Skip this check if operating only on the database
# We can expect data to exist in that case
if not self.data_only:
),
)
- def pre_check_manifest_exists():
+ def pre_check_manifest_exists() -> None:
if not (self.source / "manifest.json").exists():
raise CommandError(
"That directory doesn't appear to contain a manifest.json file.",
"""
Loads manifest data from the various JSON files for parsing and loading the database
"""
- main_manifest_path = self.source / "manifest.json"
+ main_manifest_path: Path = self.source / "manifest.json"
with main_manifest_path.open() as infile:
self.manifest = json.load(infile)
Must account for the old style of export as well, with just version.json
"""
- version_path = self.source / "version.json"
- metadata_path = self.source / "metadata.json"
+ version_path: Path = self.source / "version.json"
+ metadata_path: Path = self.source / "metadata.json"
if not version_path.exists() and not metadata_path.exists():
self.stdout.write(
self.style.NOTICE("No version.json or metadata.json file located"),
)
raise e
- def handle(self, *args, **options):
+ def handle(self, *args, **options) -> None:
logging.getLogger().handlers[0].level = logging.ERROR
self.source = Path(options["source"]).resolve()
no_progress_bar=self.no_progress_bar,
)
- def check_manifest_validity(self):
+ def check_manifest_validity(self) -> None:
"""
Attempts to verify the manifest is valid. Namely checking the files
referred to exist and the files can be read from
"""
- def check_document_validity(document_record: dict):
+ def check_document_validity(document_record: dict) -> None:
if EXPORTER_FILE_NAME not in document_record:
raise CommandError(
"The manifest file contains a record which does not "
if not self.data_only and record["model"] == "documents.document":
check_document_validity(record)
- def _import_files_from_manifest(self):
+ def _import_files_from_manifest(self) -> None:
settings.ORIGINALS_DIR.mkdir(parents=True, exist_ok=True)
settings.THUMBNAIL_DIR.mkdir(parents=True, exist_ok=True)
settings.ARCHIVE_DIR.mkdir(parents=True, exist_ok=True)
document = Document.objects.get(pk=record["pk"])
doc_file = record[EXPORTER_FILE_NAME]
- document_path = os.path.join(self.source, doc_file)
+ document_path = self.source / doc_file
if EXPORTER_THUMBNAIL_NAME in record:
thumb_file = record[EXPORTER_THUMBNAIL_NAME]
- thumbnail_path = Path(os.path.join(self.source, thumb_file)).resolve()
+ thumbnail_path = (self.source / thumb_file).resolve()
else:
thumbnail_path = None
if EXPORTER_ARCHIVE_NAME in record:
archive_file = record[EXPORTER_ARCHIVE_NAME]
- archive_path = os.path.join(self.source, archive_file)
+ archive_path = self.source / archive_file
else:
archive_path = None
document.storage_type = Document.STORAGE_TYPE_UNENCRYPTED
with FileLock(settings.MEDIA_LOCK):
- if os.path.isfile(document.source_path):
+ if Path(document.source_path).is_file():
raise FileExistsError(document.source_path)
create_source_path_directory(document.source_path)
had_at_least_one_record = False
for crypt_config in self.CRYPT_FIELDS:
- importer_model = crypt_config["model_name"]
- crypt_fields = crypt_config["fields"]
+ importer_model: str = crypt_config["model_name"]
+ crypt_fields: list[str] = crypt_config["fields"]
for record in filter(
lambda x: x["model"] == importer_model,
self.manifest,
logger = logging.getLogger("paperless.migrations")
-def _do_convert(work_package):
+def _do_convert(work_package) -> None:
(
existing_encrypted_thumbnail,
converted_encrypted_thumbnail,
# Decrypt png
decrypted_thumbnail = existing_encrypted_thumbnail.with_suffix("").resolve()
- with open(existing_encrypted_thumbnail, "rb") as existing_encrypted_file:
+ with existing_encrypted_thumbnail.open("rb") as existing_encrypted_file:
raw_thumb = gpg.decrypt_file(
existing_encrypted_file,
passphrase=passphrase,
always_trust=True,
).data
- with open(decrypted_thumbnail, "wb") as decrypted_file:
+ with Path(decrypted_thumbnail).open("wb") as decrypted_file:
decrypted_file.write(raw_thumb)
converted_decrypted_thumbnail = Path(
)
# Encrypt webp
- with open(converted_decrypted_thumbnail, "rb") as converted_decrypted_file:
+ with Path(converted_decrypted_thumbnail).open("rb") as converted_decrypted_file:
encrypted = gpg.encrypt_file(
fileobj_or_path=converted_decrypted_file,
recipients=None,
always_trust=True,
).data
- with open(converted_encrypted_thumbnail, "wb") as converted_encrypted_file:
+ with Path(converted_encrypted_thumbnail).open(
+ "wb",
+ ) as converted_encrypted_file:
converted_encrypted_file.write(encrypted)
# Copy newly created thumbnail to thumbnail directory
logger.error(f"Error converting thumbnail (existing file unchanged): {e}")
-def _convert_encrypted_thumbnails_to_webp(apps, schema_editor):
- start = time.time()
+def _convert_encrypted_thumbnails_to_webp(apps, schema_editor) -> None:
+ start: float = time.time()
with tempfile.TemporaryDirectory() as tempdir:
work_packages = []
)
for file in Path(settings.THUMBNAIL_DIR).glob("*.png.gpg"):
- existing_thumbnail = file.resolve()
+ existing_thumbnail: Path = file.resolve()
# Change the existing filename suffix from png to webp
- converted_thumbnail_name = Path(
+ converted_thumbnail_name: str = Path(
str(existing_thumbnail).replace(".png.gpg", ".webp.gpg"),
).name
# Create the expected output filename in the tempdir
- converted_thumbnail = (
+ converted_thumbnail: Path = (
Path(tempdir) / Path(converted_thumbnail_name)
).resolve()
) as pool:
pool.map(_do_convert, work_packages)
- end = time.time()
- duration = end - start
+ end: float = time.time()
+ duration: float = end - start
logger.info(f"Conversion completed in {duration:.3f}s")
self.assertEqual(response.data["tasks"]["index_status"], "OK")
self.assertIsNotNone(response.data["tasks"]["index_last_modified"])
- @override_settings(INDEX_DIR="/tmp/index/")
+ @override_settings(INDEX_DIR=Path("/tmp/index/"))
@mock.patch("documents.index.open_index", autospec=True)
def test_system_status_index_error(self, mock_open_index):
"""
self.assertEqual(response.data["tasks"]["index_status"], "ERROR")
self.assertIsNotNone(response.data["tasks"]["index_error"])
- @override_settings(DATA_DIR="/tmp/does_not_exist/data/")
+ @override_settings(DATA_DIR=Path("/tmp/does_not_exist/data/"))
def test_system_status_classifier_ok(self):
"""
GIVEN:
THEN:
- The response contains a WARNING classifier status
"""
- with override_settings(MODEL_FILE="does_not_exist"):
+ with override_settings(MODEL_FILE=Path("does_not_exist")):
Document.objects.create(
title="Test Document",
)
self.assertEqual(response.data["tasks"]["classifier_status"], "WARNING")
self.assertIsNotNone(response.data["tasks"]["classifier_error"])
- def test_system_status_classifier_error(self):
+ @mock.patch(
+ "documents.classifier.load_classifier",
+ side_effect=ClassifierModelCorruptError(),
+ )
+ def test_system_status_classifier_error(self, mock_load_classifier):
"""
GIVEN:
- The classifier does exist but is corrupt
dir="/tmp",
delete=False,
) as does_exist,
- override_settings(MODEL_FILE=does_exist),
+ override_settings(MODEL_FILE=Path(does_exist.name)),
):
- with mock.patch("documents.classifier.load_classifier") as mock_load:
- mock_load.side_effect = ClassifierModelCorruptError()
- Document.objects.create(
- title="Test Document",
- )
- Tag.objects.create(
- name="Test Tag",
- matching_algorithm=Tag.MATCH_AUTO,
- )
- self.client.force_login(self.user)
- response = self.client.get(self.ENDPOINT)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(
- response.data["tasks"]["classifier_status"],
- "ERROR",
- )
- self.assertIsNotNone(response.data["tasks"]["classifier_error"])
+ Document.objects.create(
+ title="Test Document",
+ )
+ Tag.objects.create(
+ name="Test Tag",
+ matching_algorithm=Tag.MATCH_AUTO,
+ )
+ self.client.force_login(self.user)
+ response = self.client.get(self.ENDPOINT)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(
+ response.data["tasks"]["classifier_status"],
+ "ERROR",
+ )
+ self.assertIsNotNone(response.data["tasks"]["classifier_error"])
def test_system_status_classifier_ok_no_objects(self):
"""
THEN:
- The response contains an OK classifier status
"""
- with override_settings(MODEL_FILE="does_not_exist"):
+ with override_settings(MODEL_FILE=Path("does_not_exist")):
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
Path(settings.MODEL_FILE).touch()
self.assertTrue(os.path.exists(settings.MODEL_FILE))
- load.side_effect = IncompatibleClassifierVersionError("Dummey Error")
+ load.side_effect = IncompatibleClassifierVersionError("Dummy Error")
self.assertIsNone(load_classifier())
self.assertFalse(os.path.exists(settings.MODEL_FILE))
):
classifier = load_classifier()
self.assertIsNone(classifier)
+
+ @mock.patch("documents.classifier.DocumentClassifier.load")
+ def test_load_classifier_raise_exception(self, mock_load):
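+ """
+ GIVEN:
+ - The classifier model file exists but cannot be loaded
+ WHEN:
+ - load_classifier is called with raise_exception=True
+ THEN:
+ - The original exception is re-raised instead of returning None
+ """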
+ Path(settings.MODEL_FILE).touch()
+ mock_load.side_effect = IncompatibleClassifierVersionError("Dummy Error")
+ with self.assertRaises(IncompatibleClassifierVersionError):
+ load_classifier(raise_exception=True)
+
+ Path(settings.MODEL_FILE).touch()
+ mock_load.side_effect = ClassifierModelCorruptError()
+ with self.assertRaises(ClassifierModelCorruptError):
+ load_classifier(raise_exception=True)
+
+ Path(settings.MODEL_FILE).touch()
+ mock_load.side_effect = OSError()
+ with self.assertRaises(OSError):
+ load_classifier(raise_exception=True)
+
+ Path(settings.MODEL_FILE).touch()
+ mock_load.side_effect = Exception()
+ with self.assertRaises(Exception):
+ load_classifier(raise_exception=True)
class TestDecryptDocuments(FileSystemAssertsMixin, TestCase):
@override_settings(
- ORIGINALS_DIR=os.path.join(os.path.dirname(__file__), "samples", "originals"),
- THUMBNAIL_DIR=os.path.join(os.path.dirname(__file__), "samples", "thumb"),
+ ORIGINALS_DIR=(Path(__file__).parent / "samples" / "originals"),
+ THUMBNAIL_DIR=(Path(__file__).parent / "samples" / "thumb"),
PASSPHRASE="test",
FILENAME_FORMAT=None,
)
@mock.patch("documents.management.commands.decrypt_documents.input")
def test_decrypt(self, m):
media_dir = tempfile.mkdtemp()
- originals_dir = os.path.join(media_dir, "documents", "originals")
- thumb_dir = os.path.join(media_dir, "documents", "thumbnails")
- os.makedirs(originals_dir, exist_ok=True)
- os.makedirs(thumb_dir, exist_ok=True)
+ originals_dir = Path(media_dir) / "documents" / "originals"
+ thumb_dir = Path(media_dir) / "documents" / "thumbnails"
+ originals_dir.mkdir(parents=True, exist_ok=True)
+ thumb_dir.mkdir(parents=True, exist_ok=True)
override_settings(
ORIGINALS_DIR=originals_dir,
"originals",
"0000004.pdf.gpg",
),
- os.path.join(originals_dir, "0000004.pdf.gpg"),
+ originals_dir / "0000004.pdf.gpg",
)
shutil.copy(
os.path.join(
"thumbnails",
"0000004.webp.gpg",
),
- os.path.join(thumb_dir, f"{doc.id:07}.webp.gpg"),
+ thumb_dir / f"{doc.id:07}.webp.gpg",
)
call_command("decrypt_documents")
classifier_error = None
classifier_status = None
try:
- classifier = load_classifier()
+ classifier = load_classifier(raise_exception=True)
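+ # Re-raise loading problems so the status check can report the specific error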
if classifier is None:
# Make sure classifier should exist
docs_queryset = Document.objects.exclude(
matching_algorithm=Tag.MATCH_AUTO,
).exists()
)
- and not os.path.isfile(settings.MODEL_FILE)
+ and not settings.MODEL_FILE.exists()
):
# if classifier file doesn't exist just classify as a warning
classifier_error = "Classifier file does not exist (yet). Re-training may be pending."