]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Creates a data model for the document consumption, allowing stronger typing of argume...
authorTrenton H <797416+stumpylog@users.noreply.github.com>
Mon, 23 Jan 2023 23:55:49 +0000 (15:55 -0800)
committerTrenton H <797416+stumpylog@users.noreply.github.com>
Sat, 1 Apr 2023 18:05:34 +0000 (11:05 -0700)
14 files changed:
src/documents/barcodes.py
src/documents/consumer.py
src/documents/data_models.py [new file with mode: 0644]
src/documents/management/commands/document_consumer.py
src/documents/signals/handlers.py
src/documents/tasks.py
src/documents/tests/test_api.py
src/documents/tests/test_barcodes.py
src/documents/tests/test_management_consumer.py
src/documents/tests/test_task_signals.py
src/documents/tests/utils.py
src/documents/views.py
src/paperless_mail/mail.py
src/paperless_mail/tests/test_mail.py

index 1f520c546eb276029b8db10aa7c71d1d639a2be5..82b81fecd0be3b8f14a7886a46b1086df27ff102 100644 (file)
@@ -11,7 +11,6 @@ from typing import List
 from typing import Optional
 
 import img2pdf
-import magic
 from django.conf import settings
 from pdf2image import convert_from_path
 from pdf2image.exceptions import PDFPageCountError
@@ -63,7 +62,7 @@ class DocumentBarcodeInfo:
 
 
 @lru_cache(maxsize=8)
-def supported_file_type(mime_type) -> bool:
+def supported_file_type(mime_type: str) -> bool:
     """
     Determines if the file is valid for barcode
     processing, based on MIME type and settings
@@ -115,33 +114,16 @@ def barcode_reader(image: Image) -> List[str]:
     return barcodes
 
 
-def get_file_mime_type(path: Path) -> str:
-    """
-    Determines the file type, based on MIME type.
-
-    Returns the MIME type.
-    """
-    mime_type = magic.from_file(path, mime=True)
-    logger.debug(f"Detected mime type: {mime_type}")
-    return mime_type
-
-
 def convert_from_tiff_to_pdf(filepath: Path) -> Path:
     """
     converts a given TIFF image file to pdf into a temporary directory.
 
     Returns the new pdf file.
     """
-    mime_type = get_file_mime_type(filepath)
     tempdir = tempfile.mkdtemp(prefix="paperless-", dir=settings.SCRATCH_DIR)
     # use old file name with pdf extension
-    if mime_type == "image/tiff":
-        newpath = Path(tempdir) / Path(filepath.name).with_suffix(".pdf")
-    else:
-        logger.warning(
-            f"Cannot convert mime type {mime_type} from {filepath} to pdf.",
-        )
-        return None
+    newpath = Path(tempdir) / Path(filepath.name).with_suffix(".pdf")
+
     with Image.open(filepath) as im:
         has_alpha_layer = im.mode in ("RGBA", "LA")
     if has_alpha_layer:
@@ -162,6 +144,7 @@ def convert_from_tiff_to_pdf(filepath: Path) -> Path:
 
 def scan_file_for_barcodes(
     filepath: Path,
+    mime_type: str,
 ) -> DocumentBarcodeInfo:
     """
     Scan the provided pdf file for any barcodes
@@ -186,7 +169,6 @@ def scan_file_for_barcodes(
         return detected_barcodes
 
     pdf_filepath = None
-    mime_type = get_file_mime_type(filepath)
     barcodes = []
 
     if supported_file_type(mime_type):
index 175c80876f0ae9b7686d4b040258134afec01f98..797345ba6bf02892783b5b8357bdf37665734310 100644 (file)
@@ -284,7 +284,7 @@ class Consumer(LoggingMixin):
 
     def try_consume_file(
         self,
-        path,
+        path: Path,
         override_filename=None,
         override_title=None,
         override_correspondent_id=None,
diff --git a/src/documents/data_models.py b/src/documents/data_models.py
new file mode 100644 (file)
index 0000000..f904743
--- /dev/null
@@ -0,0 +1,62 @@
+import dataclasses
+import datetime
+import enum
+from pathlib import Path
+from typing import List
+from typing import Optional
+
+import magic
+
+
+@dataclasses.dataclass
+class DocumentMetadataOverrides:
+    """
+    Manages overrides for document fields which normally would
+    be set from content or matching.  All fields default to None,
+    meaning no override is happening.
+
+    Because every field has a default, an instance created with no
+    arguments carries no overrides at all.
+    """
+
+    # Filename to use in place of the name of the file on disk
+    filename: Optional[str] = None
+    # Title to assign to the document
+    title: Optional[str] = None
+    # Id of the Correspondent to assign
+    correspondent_id: Optional[int] = None
+    # Id of the DocumentType to assign
+    document_type_id: Optional[int] = None
+    # Ids of the Tags to assign
+    tag_ids: Optional[List[int]] = None
+    # Creation date/time to record for the document
+    created: Optional[datetime.datetime] = None
+    # Archive serial number; consume_file may also fill this in from a
+    # barcode found in the document
+    asn: Optional[int] = None
+    # Presumably the id of the owning user - TODO confirm against callers
+    owner_id: Optional[int] = None
+
+
+class DocumentSource(enum.IntEnum):
+    """
+    The source of an incoming document.  May have other uses in the future.
+
+    Member values are generated by enum.auto(), so they are consecutive
+    integers assigned in declaration order.
+    """
+
+    # Document was picked up from the consume folder
+    ConsumeFolder = enum.auto()
+    # Document was uploaded through the API
+    ApiUpload = enum.auto()
+    # Document was fetched from mail
+    MailFetch = enum.auto()
+
+
+@dataclasses.dataclass
+class ConsumableDocument:
+    """
+    Encapsulates an incoming document, either from consume folder, API upload
+    or mail fetching and certain useful operations on it.
+
+    Only source and original_file are accepted by the constructor;
+    mime_type is derived from the file in __post_init__.
+    """
+
+    # Where the document came from
+    source: DocumentSource
+    # Location of the incoming file; a str is also accepted and is converted
+    # to an absolute Path in __post_init__
+    original_file: Path
+    # Detected MIME type of original_file.  init=False keeps it out of the
+    # generated __init__; the None default is replaced in __post_init__
+    mime_type: str = dataclasses.field(init=False, default=None)
+
+    def __post_init__(self):
+        """
+        After a dataclass is initialized, this is called to finalize some data
+        1. Make sure the original path is an absolute, fully qualified path
+        2. Get the mime type of the file
+
+        The file is inspected here, at construction time, so it presumably
+        has to exist already when the object is created - TODO confirm.
+        """
+        # Always fully qualify the path first thing
+        # Just in case, convert to a path if it's a str
+        self.original_file = Path(self.original_file).resolve()
+
+        # Get the file type once at init
+        # Note this function isn't called when the object is unpickled
+        # mime=True asks libmagic for a MIME type string
+        self.mime_type = magic.from_file(self.original_file, mime=True)
index d4ace3f1bf402f6a515490cd2e29945aee27c783..27749ea7c74232902fb24e163c8c829948bde26d 100644 (file)
@@ -13,6 +13,9 @@ from typing import Set
 from django.conf import settings
 from django.core.management.base import BaseCommand
 from django.core.management.base import CommandError
+from documents.data_models import ConsumableDocument
+from documents.data_models import DocumentMetadataOverrides
+from documents.data_models import DocumentSource
 from documents.models import Tag
 from documents.parsers import is_file_ext_supported
 from documents.tasks import consume_file
@@ -122,8 +125,11 @@ def _consume(filepath: str) -> None:
     try:
         logger.info(f"Adding {filepath} to the task queue.")
         consume_file.delay(
-            filepath,
-            override_tag_ids=list(tag_ids) if tag_ids else None,
+            ConsumableDocument(
+                source=DocumentSource.ConsumeFolder,
+                original_file=filepath,
+            ),
+            DocumentMetadataOverrides(tag_ids=tag_ids),
         )
     except Exception:
         # Catch all so that the consumer won't crash.
index 670ceae6467db3243a9564722efcc39c46d0e279..92f8e61597323f11e360340a995e85e0f616d7c4 100644 (file)
@@ -1,7 +1,6 @@
 import logging
 import os
 import shutil
-from pathlib import Path
 
 from celery import states
 from celery.signals import before_task_publish
@@ -533,17 +532,9 @@ def before_task_publish_handler(sender=None, headers=None, body=None, **kwargs):
 
     try:
         task_args = body[0]
-        task_kwargs = body[1]
+        input_doc, _ = task_args
 
-        task_file_name = ""
-        if "override_filename" in task_kwargs:
-            task_file_name = task_kwargs["override_filename"]
-
-        # Nothing was found, report the task first argument
-        if not len(task_file_name):
-            # There are always some arguments to the consume, first is always filename
-            filepath = Path(task_args[0])
-            task_file_name = filepath.name
+        task_file_name = input_doc.original_file.name
 
         PaperlessTask.objects.create(
             task_id=headers["id"],
index fbc754e52fb962fecf2754de1992de15b26a52cc..5c300bca288260eb11dca82bd5cd4e83a1a7e77b 100644 (file)
@@ -1,13 +1,10 @@
 import hashlib
 import logging
-import os
 import shutil
 import uuid
-from pathlib import Path
 from typing import Optional
 from typing import Type
 
-import dateutil.parser
 import tqdm
 from asgiref.sync import async_to_sync
 from celery import shared_task
@@ -22,6 +19,9 @@ from documents.classifier import DocumentClassifier
 from documents.classifier import load_classifier
 from documents.consumer import Consumer
 from documents.consumer import ConsumerError
+from documents.data_models import ConsumableDocument
+from documents.data_models import DocumentMetadataOverrides
+from documents.data_models import DocumentSource
 from documents.file_handling import create_source_path_directory
 from documents.file_handling import generate_unique_filename
 from documents.models import Correspondent
@@ -88,34 +88,20 @@ def train_classifier():
 
 @shared_task
 def consume_file(
-    path,
-    override_filename=None,
-    override_title=None,
-    override_correspondent_id=None,
-    override_document_type_id=None,
-    override_tag_ids=None,
-    task_id=None,
-    override_created=None,
-    override_owner_id=None,
-    override_archive_serial_num: Optional[int] = None,
+    input_doc: ConsumableDocument,
+    overrides: Optional[DocumentMetadataOverrides] = None,
 ):
 
-    path = Path(path).resolve()
-    asn = None
-
-    # Celery converts this to a string, but everything expects a datetime
-    # Long term solution is to not use JSON for the serializer but pickle instead
-    # TODO: This will be resolved in kombu 5.3, expected with celery 5.3
-    # More types will be retained through JSON encode/decode
-    if override_created is not None and isinstance(override_created, str):
-        try:
-            override_created = dateutil.parser.isoparse(override_created)
-        except Exception:
-            pass
+    # Default no overrides
+    if overrides is None:
+        overrides = DocumentMetadataOverrides()
 
     # read all barcodes in the current document
     if settings.CONSUMER_ENABLE_BARCODES or settings.CONSUMER_ENABLE_ASN_BARCODE:
-        doc_barcode_info = barcodes.scan_file_for_barcodes(path)
+        doc_barcode_info = barcodes.scan_file_for_barcodes(
+            input_doc.original_file,
+            input_doc.mime_type,
+        )
 
         # split document by separator pages, if enabled
         if settings.CONSUMER_ENABLE_BARCODES:
@@ -123,7 +109,7 @@ def consume_file(
 
             if len(separators) > 0:
                 logger.debug(
-                    f"Pages with separators found in: {str(path)}",
+                    f"Pages with separators found in: {input_doc.original_file}",
                 )
                 document_list = barcodes.separate_pages(
                     doc_barcode_info.pdf_path,
@@ -136,18 +122,20 @@ def consume_file(
                     # Move it to consume directory to be picked up
                     # Otherwise, use the current parent to keep possible tags
                     # from subdirectories
-                    try:
-                        # is_relative_to would be nicer, but new in 3.9
-                        _ = path.relative_to(settings.SCRATCH_DIR)
+                    if input_doc.source != DocumentSource.ConsumeFolder:
                         save_to_dir = settings.CONSUMPTION_DIR
-                    except ValueError:
-                        save_to_dir = path.parent
+                    else:
+                        # Note this uses the original file, because it's in the
+                        # consume folder already and may include additional path
+                        # components for tagging
+                        # the .path is somewhere in scratch in this case
+                        save_to_dir = input_doc.original_file.parent
 
                     for n, document in enumerate(document_list):
                         # save to consumption dir
                         # rename it to the original filename  with number prefix
-                        if override_filename:
-                            newname = f"{str(n)}_" + override_filename
+                        if overrides.filename is not None:
+                            newname = f"{str(n)}_{overrides.filename}"
                         else:
                             newname = None
 
@@ -158,24 +146,27 @@ def consume_file(
                         )
 
                         # Split file has been copied safely, remove it
-                        os.remove(document)
+                        document.unlink()
 
                     # And clean up the directory as well, now it's empty
-                    shutil.rmtree(os.path.dirname(document_list[0]))
+                    shutil.rmtree(document_list[0].parent)
 
-                    # Delete the PDF file which was split
-                    os.remove(doc_barcode_info.pdf_path)
+                    # This file has been split into multiple files without issue
+                    # remove the original and working copy
+                    input_doc.original_file.unlink()
 
-                    # If the original was a TIFF, remove the original file as well
-                    if str(doc_barcode_info.pdf_path) != str(path):
-                        logger.debug(f"Deleting file {path}")
-                        os.unlink(path)
+                    # If the original file was a TIFF, remove the PDF generated from it
+                    if input_doc.mime_type == "image/tiff":
+                        logger.debug(
+                            f"Deleting file {doc_barcode_info.pdf_path}",
+                        )
+                        doc_barcode_info.pdf_path.unlink()
 
                     # notify the sender, otherwise the progress bar
                     # in the UI stays stuck
                     payload = {
-                        "filename": override_filename or path.name,
-                        "task_id": task_id,
+                        "filename": overrides.filename or input_doc.original_file.name,
+                        "task_id": None,
                         "current_progress": 100,
                         "max_progress": 100,
                         "status": "SUCCESS",
@@ -194,22 +185,21 @@ def consume_file(
 
         # try reading the ASN from barcode
         if settings.CONSUMER_ENABLE_ASN_BARCODE:
-            asn = barcodes.get_asn_from_barcodes(doc_barcode_info.barcodes)
-            if asn:
-                logger.info(f"Found ASN in barcode: {asn}")
+            overrides.asn = barcodes.get_asn_from_barcodes(doc_barcode_info.barcodes)
+            if overrides.asn:
+                logger.info(f"Found ASN in barcode: {overrides.asn}")
 
     # continue with consumption if no barcode was found
     document = Consumer().try_consume_file(
-        path,
-        override_filename=override_filename,
-        override_title=override_title,
-        override_correspondent_id=override_correspondent_id,
-        override_document_type_id=override_document_type_id,
-        override_tag_ids=override_tag_ids,
-        task_id=task_id,
-        override_created=override_created,
-        override_asn=override_archive_serial_num or asn,
-        override_owner_id=override_owner_id,
+        input_doc.original_file,
+        override_filename=overrides.filename,
+        override_title=overrides.title,
+        override_correspondent_id=overrides.correspondent_id,
+        override_document_type_id=overrides.document_type_id,
+        override_tag_ids=overrides.tag_ids,
+        override_created=overrides.created,
+        override_asn=overrides.asn,
+        override_owner_id=overrides.owner_id,
     )
 
     if document:
index 958f0b3fe5950e6beb587decf4dc734ce22f5d74..da60ab3c4f90ee0d24d32b6dfc8af0a703b79da4 100644 (file)
@@ -32,6 +32,7 @@ from documents import bulk_edit
 from documents import index
 from documents.models import Correspondent
 from documents.models import Document
+from documents.tests.utils import DocumentConsumeDelayMixin
 from documents.models import DocumentType
 from documents.models import MatchingModel
 from documents.models import PaperlessTask
@@ -45,7 +46,7 @@ from rest_framework.test import APITestCase
 from whoosh.writing import AsyncWriter
 
 
-class TestDocumentApi(DirectoriesMixin, APITestCase):
+class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
     def setUp(self):
         super().setUp()
 
@@ -1085,10 +1086,11 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
         self.assertEqual(response.data["documents_inbox"], None)
         self.assertEqual(response.data["inbox_tag"], None)
 
-    @mock.patch("documents.views.consume_file.delay")
-    def test_upload(self, m):
+    def test_upload(self):
 
-        m.return_value = celery.result.AsyncResult(id=str(uuid.uuid4()))
+        self.consume_file_mock.return_value = celery.result.AsyncResult(
+            id=str(uuid.uuid4()),
+        )
 
         with open(
             os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
@@ -1101,21 +1103,22 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
 
         self.assertEqual(response.status_code, status.HTTP_200_OK)
 
-        m.assert_called_once()
+        self.consume_file_mock.assert_called_once()
 
-        args, kwargs = m.call_args
-        file_path = Path(args[0])
-        self.assertEqual(file_path.name, "simple.pdf")
-        self.assertIn(Path(settings.SCRATCH_DIR), file_path.parents)
-        self.assertIsNone(kwargs["override_title"])
-        self.assertIsNone(kwargs["override_correspondent_id"])
-        self.assertIsNone(kwargs["override_document_type_id"])
-        self.assertIsNone(kwargs["override_tag_ids"])
+        input_doc, overrides = self.get_last_consume_delay_call_args()
 
-    @mock.patch("documents.views.consume_file.delay")
-    def test_upload_empty_metadata(self, m):
+        self.assertEqual(input_doc.original_file.name, "simple.pdf")
+        self.assertIn(Path(settings.SCRATCH_DIR), input_doc.original_file.parents)
+        self.assertIsNone(overrides.title)
+        self.assertIsNone(overrides.correspondent_id)
+        self.assertIsNone(overrides.document_type_id)
+        self.assertIsNone(overrides.tag_ids)
 
-        m.return_value = celery.result.AsyncResult(id=str(uuid.uuid4()))
+    def test_upload_empty_metadata(self):
+
+        self.consume_file_mock.return_value = celery.result.AsyncResult(
+            id=str(uuid.uuid4()),
+        )
 
         with open(
             os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
@@ -1128,21 +1131,22 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
 
         self.assertEqual(response.status_code, status.HTTP_200_OK)
 
-        m.assert_called_once()
+        self.consume_file_mock.assert_called_once()
 
-        args, kwargs = m.call_args
-        file_path = Path(args[0])
-        self.assertEqual(file_path.name, "simple.pdf")
-        self.assertIn(Path(settings.SCRATCH_DIR), file_path.parents)
-        self.assertIsNone(kwargs["override_title"])
-        self.assertIsNone(kwargs["override_correspondent_id"])
-        self.assertIsNone(kwargs["override_document_type_id"])
-        self.assertIsNone(kwargs["override_tag_ids"])
+        input_doc, overrides = self.get_last_consume_delay_call_args()
+
+        self.assertEqual(input_doc.original_file.name, "simple.pdf")
+        self.assertIn(Path(settings.SCRATCH_DIR), input_doc.original_file.parents)
+        self.assertIsNone(overrides.title)
+        self.assertIsNone(overrides.correspondent_id)
+        self.assertIsNone(overrides.document_type_id)
+        self.assertIsNone(overrides.tag_ids)
 
-    @mock.patch("documents.views.consume_file.delay")
-    def test_upload_invalid_form(self, m):
+    def test_upload_invalid_form(self):
 
-        m.return_value = celery.result.AsyncResult(id=str(uuid.uuid4()))
+        self.consume_file_mock.return_value = celery.result.AsyncResult(
+            id=str(uuid.uuid4()),
+        )
 
         with open(
             os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
@@ -1153,12 +1157,13 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
                 {"documenst": f},
             )
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        m.assert_not_called()
+        self.consume_file_mock.assert_not_called()
 
-    @mock.patch("documents.views.consume_file.delay")
-    def test_upload_invalid_file(self, m):
+    def test_upload_invalid_file(self):
 
-        m.return_value = celery.result.AsyncResult(id=str(uuid.uuid4()))
+        self.consume_file_mock.return_value = celery.result.AsyncResult(
+            id=str(uuid.uuid4()),
+        )
 
         with open(
             os.path.join(os.path.dirname(__file__), "samples", "simple.zip"),
@@ -1169,12 +1174,13 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
                 {"document": f},
             )
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        m.assert_not_called()
+        self.consume_file_mock.assert_not_called()
 
-    @mock.patch("documents.views.consume_file.delay")
-    def test_upload_with_title(self, async_task):
+    def test_upload_with_title(self):
 
-        async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4()))
+        self.consume_file_mock.return_value = celery.result.AsyncResult(
+            id=str(uuid.uuid4()),
+        )
 
         with open(
             os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
@@ -1186,16 +1192,20 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
             )
         self.assertEqual(response.status_code, status.HTTP_200_OK)
 
-        async_task.assert_called_once()
+        self.consume_file_mock.assert_called_once()
 
-        args, kwargs = async_task.call_args
+        _, overrides = self.get_last_consume_delay_call_args()
 
-        self.assertEqual(kwargs["override_title"], "my custom title")
+        self.assertEqual(overrides.title, "my custom title")
+        self.assertIsNone(overrides.correspondent_id)
+        self.assertIsNone(overrides.document_type_id)
+        self.assertIsNone(overrides.tag_ids)
 
-    @mock.patch("documents.views.consume_file.delay")
-    def test_upload_with_correspondent(self, async_task):
+    def test_upload_with_correspondent(self):
 
-        async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4()))
+        self.consume_file_mock.return_value = celery.result.AsyncResult(
+            id=str(uuid.uuid4()),
+        )
 
         c = Correspondent.objects.create(name="test-corres")
         with open(
@@ -1208,16 +1218,20 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
             )
         self.assertEqual(response.status_code, status.HTTP_200_OK)
 
-        async_task.assert_called_once()
+        self.consume_file_mock.assert_called_once()
 
-        args, kwargs = async_task.call_args
+        _, overrides = self.get_last_consume_delay_call_args()
 
-        self.assertEqual(kwargs["override_correspondent_id"], c.id)
+        self.assertEqual(overrides.correspondent_id, c.id)
+        self.assertIsNone(overrides.title)
+        self.assertIsNone(overrides.document_type_id)
+        self.assertIsNone(overrides.tag_ids)
 
-    @mock.patch("documents.views.consume_file.delay")
-    def test_upload_with_invalid_correspondent(self, async_task):
+    def test_upload_with_invalid_correspondent(self):
 
-        async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4()))
+        self.consume_file_mock.return_value = celery.result.AsyncResult(
+            id=str(uuid.uuid4()),
+        )
 
         with open(
             os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
@@ -1229,12 +1243,13 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
             )
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
-        async_task.assert_not_called()
+        self.consume_file_mock.assert_not_called()
 
-    @mock.patch("documents.views.consume_file.delay")
-    def test_upload_with_document_type(self, async_task):
+    def test_upload_with_document_type(self):
 
-        async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4()))
+        self.consume_file_mock.return_value = celery.result.AsyncResult(
+            id=str(uuid.uuid4()),
+        )
 
         dt = DocumentType.objects.create(name="invoice")
         with open(
@@ -1247,16 +1262,20 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
             )
         self.assertEqual(response.status_code, status.HTTP_200_OK)
 
-        async_task.assert_called_once()
+        self.consume_file_mock.assert_called_once()
 
-        args, kwargs = async_task.call_args
+        _, overrides = self.get_last_consume_delay_call_args()
 
-        self.assertEqual(kwargs["override_document_type_id"], dt.id)
+        self.assertEqual(overrides.document_type_id, dt.id)
+        self.assertIsNone(overrides.correspondent_id)
+        self.assertIsNone(overrides.title)
+        self.assertIsNone(overrides.tag_ids)
 
-    @mock.patch("documents.views.consume_file.delay")
-    def test_upload_with_invalid_document_type(self, async_task):
+    def test_upload_with_invalid_document_type(self):
 
-        async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4()))
+        self.consume_file_mock.return_value = celery.result.AsyncResult(
+            id=str(uuid.uuid4()),
+        )
 
         with open(
             os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
@@ -1268,12 +1287,13 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
             )
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
-        async_task.assert_not_called()
+        self.consume_file_mock.assert_not_called()
 
-    @mock.patch("documents.views.consume_file.delay")
-    def test_upload_with_tags(self, async_task):
+    def test_upload_with_tags(self):
 
-        async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4()))
+        self.consume_file_mock.return_value = celery.result.AsyncResult(
+            id=str(uuid.uuid4()),
+        )
 
         t1 = Tag.objects.create(name="tag1")
         t2 = Tag.objects.create(name="tag2")
@@ -1287,16 +1307,20 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
             )
         self.assertEqual(response.status_code, status.HTTP_200_OK)
 
-        async_task.assert_called_once()
+        self.consume_file_mock.assert_called_once()
 
-        args, kwargs = async_task.call_args
+        _, overrides = self.get_last_consume_delay_call_args()
 
-        self.assertCountEqual(kwargs["override_tag_ids"], [t1.id, t2.id])
+        self.assertCountEqual(overrides.tag_ids, [t1.id, t2.id])
+        self.assertIsNone(overrides.document_type_id)
+        self.assertIsNone(overrides.correspondent_id)
+        self.assertIsNone(overrides.title)
 
-    @mock.patch("documents.views.consume_file.delay")
-    def test_upload_with_invalid_tags(self, async_task):
+    def test_upload_with_invalid_tags(self):
 
-        async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4()))
+        self.consume_file_mock.return_value = celery.result.AsyncResult(
+            id=str(uuid.uuid4()),
+        )
 
         t1 = Tag.objects.create(name="tag1")
         t2 = Tag.objects.create(name="tag2")
@@ -1310,12 +1334,13 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
             )
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
-        async_task.assert_not_called()
+        self.consume_file_mock.assert_not_called()
 
-    @mock.patch("documents.views.consume_file.delay")
-    def test_upload_with_created(self, async_task):
+    def test_upload_with_created(self):
 
-        async_task.return_value = celery.result.AsyncResult(id=str(uuid.uuid4()))
+        self.consume_file_mock.return_value = celery.result.AsyncResult(
+            id=str(uuid.uuid4()),
+        )
 
         created = datetime.datetime(
             2022,
@@ -1337,16 +1362,17 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
             )
         self.assertEqual(response.status_code, status.HTTP_200_OK)
 
-        async_task.assert_called_once()
+        self.consume_file_mock.assert_called_once()
 
-        args, kwargs = async_task.call_args
+        _, overrides = self.get_last_consume_delay_call_args()
 
-        self.assertEqual(kwargs["override_created"], created)
+        self.assertEqual(overrides.created, created)
 
-    @mock.patch("documents.views.consume_file.delay")
-    def test_upload_with_asn(self, m):
+    def test_upload_with_asn(self):
 
-        m.return_value = celery.result.AsyncResult(id=str(uuid.uuid4()))
+        self.consume_file_mock.return_value = celery.result.AsyncResult(
+            id=str(uuid.uuid4()),
+        )
 
         with open(
             os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
@@ -1359,17 +1385,16 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
 
         self.assertEqual(response.status_code, status.HTTP_200_OK)
 
-        m.assert_called_once()
+        self.consume_file_mock.assert_called_once()
 
-        args, kwargs = m.call_args
-        file_path = Path(args[0])
-        self.assertEqual(file_path.name, "simple.pdf")
-        self.assertIn(Path(settings.SCRATCH_DIR), file_path.parents)
-        self.assertIsNone(kwargs["override_title"])
-        self.assertIsNone(kwargs["override_correspondent_id"])
-        self.assertIsNone(kwargs["override_document_type_id"])
-        self.assertIsNone(kwargs["override_tag_ids"])
-        self.assertEqual(500, kwargs["override_archive_serial_num"])
+        input_doc, overrides = self.get_last_consume_delay_call_args()
+
+        self.assertEqual(input_doc.original_file.name, "simple.pdf")
+        self.assertEqual(overrides.filename, "simple.pdf")
+        self.assertIsNone(overrides.correspondent_id)
+        self.assertIsNone(overrides.document_type_id)
+        self.assertIsNone(overrides.tag_ids)
+        self.assertEqual(500, overrides.asn)
 
     def test_get_metadata(self):
         doc = Document.objects.create(
index a1e08c5cfafa6f0518681ac23a94a82acd377c5f..975a3cc1bb464599c942973e21a86a8a475fa423 100644 (file)
@@ -10,6 +10,9 @@ from django.test import TestCase
 from documents import barcodes
 from documents import tasks
 from documents.consumer import ConsumerError
+from documents.data_models import ConsumableDocument
+from documents.data_models import DocumentMetadataOverrides
+from documents.data_models import DocumentSource
 from documents.tests.utils import DirectoriesMixin
 from documents.tests.utils import FileSystemAssertsMixin
 from PIL import Image
@@ -183,46 +186,14 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         img = Image.open(test_file)
         self.assertEqual(barcodes.barcode_reader(img), ["CUSTOM BARCODE"])
 
-    def test_get_mime_type(self):
-        """
-        GIVEN:
-            -
-        WHEN:
-            -
-        THEN:
-            -
-        """
-        tiff_file = self.SAMPLE_DIR / "simple.tiff"
-
-        pdf_file = self.SAMPLE_DIR / "simple.pdf"
-
-        png_file = self.BARCODE_SAMPLE_DIR / "barcode-128-custom.png"
-
-        tiff_file_no_extension = settings.SCRATCH_DIR / "testfile1"
-        pdf_file_no_extension = settings.SCRATCH_DIR / "testfile2"
-        shutil.copy(tiff_file, tiff_file_no_extension)
-        shutil.copy(pdf_file, pdf_file_no_extension)
-
-        self.assertEqual(barcodes.get_file_mime_type(tiff_file), "image/tiff")
-        self.assertEqual(barcodes.get_file_mime_type(pdf_file), "application/pdf")
-        self.assertEqual(
-            barcodes.get_file_mime_type(tiff_file_no_extension),
-            "image/tiff",
-        )
-        self.assertEqual(
-            barcodes.get_file_mime_type(pdf_file_no_extension),
-            "application/pdf",
-        )
-        self.assertEqual(barcodes.get_file_mime_type(png_file), "image/png")
-
     def test_convert_from_tiff_to_pdf(self):
         """
         GIVEN:
-            -
+            - Multi-page TIFF image
         WHEN:
-            -
+            - Conversion to PDF
         THEN:
-            -
+            - The file converts without error
         """
         test_file = self.SAMPLE_DIR / "simple.tiff"
 
@@ -233,34 +204,20 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         self.assertIsFile(target_file)
         self.assertEqual(target_file.suffix, ".pdf")
 
-    def test_convert_error_from_pdf_to_pdf(self):
-        """
-        GIVEN:
-            -
-        WHEN:
-            -
-        THEN:
-            -
-        """
-        test_file = self.SAMPLE_DIR / "simple.pdf"
-
-        dst = settings.SCRATCH_DIR / "simple.pdf"
-        shutil.copy(test_file, dst)
-        self.assertIsNone(barcodes.convert_from_tiff_to_pdf(dst))
-
     def test_scan_file_for_separating_barcodes(self):
         """
         GIVEN:
-            -
+            - PDF containing barcodes
         WHEN:
-            -
+            - File is scanned for barcodes
         THEN:
-            -
+            - Correct page index located
         """
         test_file = self.BARCODE_SAMPLE_DIR / "patch-code-t.pdf"
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -272,15 +229,17 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
     def test_scan_file_for_separating_barcodes_none_present(self):
         """
         GIVEN:
-            -
+            - File with no barcodes
         WHEN:
-            -
+            - File is scanned
         THEN:
-            -
+            - No barcodes detected
+            - No pages to split on
         """
         test_file = self.SAMPLE_DIR / "simple.pdf"
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -302,6 +261,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -323,6 +283,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -345,6 +306,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -366,6 +328,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -388,6 +351,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -411,6 +375,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -435,6 +400,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -459,6 +425,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -482,6 +449,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -504,6 +472,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -636,6 +605,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -673,7 +643,16 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         shutil.copy(test_file, dst)
 
         with mock.patch("documents.tasks.async_to_sync"):
-            self.assertEqual(tasks.consume_file(dst), "File successfully split")
+            self.assertEqual(
+                tasks.consume_file(
+                    ConsumableDocument(
+                        source=DocumentSource.ConsumeFolder,
+                        original_file=dst,
+                    ),
+                    None,
+                ),
+                "File successfully split",
+            )
 
     @override_settings(
         CONSUMER_ENABLE_BARCODES=True,
@@ -694,7 +673,17 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         shutil.copy(test_file, dst)
 
         with mock.patch("documents.tasks.async_to_sync"):
-            self.assertEqual(tasks.consume_file(dst), "File successfully split")
+            self.assertEqual(
+                tasks.consume_file(
+                    ConsumableDocument(
+                        source=DocumentSource.ConsumeFolder,
+                        original_file=dst,
+                    ),
+                    None,
+                ),
+                "File successfully split",
+            )
+        self.assertFalse(dst.exists())
 
     @override_settings(
         CONSUMER_ENABLE_BARCODES=True,
@@ -717,7 +706,16 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         shutil.copy(test_file, dst)
 
         with self.assertLogs("paperless.barcodes", level="WARNING") as cm:
-            self.assertIn("Success", tasks.consume_file(dst))
+            self.assertIn(
+                "Success",
+                tasks.consume_file(
+                    ConsumableDocument(
+                        source=DocumentSource.ConsumeFolder,
+                        original_file=dst,
+                    ),
+                    None,
+                ),
+            )
 
         self.assertListEqual(
             cm.output,
@@ -754,7 +752,17 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         shutil.copy(test_file, dst)
 
         with mock.patch("documents.tasks.async_to_sync"):
-            self.assertEqual(tasks.consume_file(dst), "File successfully split")
+            self.assertEqual(
+                tasks.consume_file(
+                    ConsumableDocument(
+                        source=DocumentSource.ConsumeFolder,
+                        original_file=dst,
+                    ),
+                    None,
+                ),
+                "File successfully split",
+            )
+        self.assertFalse(dst.exists())
 
     def test_scan_file_for_separating_barcodes_password(self):
         """
@@ -769,6 +777,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         with self.assertLogs("paperless.barcodes", level="WARNING") as cm:
             doc_barcode_info = barcodes.scan_file_for_barcodes(
                 test_file,
+                "application/pdf",
             )
             warning = cm.output[0]
             expected_str = "WARNING:paperless.barcodes:File is likely password protected, not checking for barcodes"
@@ -798,6 +807,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -835,6 +845,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         separator_page_numbers = barcodes.get_separating_barcodes(
             doc_barcode_info.barcodes,
@@ -855,7 +866,7 @@ class TestBarcode(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         self.assertEqual(len(document_list), 5)
 
 
-class TestAsnBarcodes(DirectoriesMixin, TestCase):
+class TestAsnBarcode(DirectoriesMixin, TestCase):
 
     SAMPLE_DIR = Path(__file__).parent / "samples"
 
@@ -923,6 +934,7 @@ class TestAsnBarcodes(DirectoriesMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         asn = barcodes.get_asn_from_barcodes(doc_barcode_info.barcodes)
 
@@ -944,6 +956,7 @@ class TestAsnBarcodes(DirectoriesMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
 
         asn = barcodes.get_asn_from_barcodes(doc_barcode_info.barcodes)
@@ -970,7 +983,13 @@ class TestAsnBarcodes(DirectoriesMixin, TestCase):
         shutil.copy(test_file, dst)
 
         with mock.patch("documents.consumer.Consumer.try_consume_file") as mocked_call:
-            tasks.consume_file(dst)
+            tasks.consume_file(
+                ConsumableDocument(
+                    source=DocumentSource.ConsumeFolder,
+                    original_file=dst,
+                ),
+                None,
+            )
 
             args, kwargs = mocked_call.call_args
 
@@ -991,6 +1010,7 @@ class TestAsnBarcodes(DirectoriesMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         asn = barcodes.get_asn_from_barcodes(doc_barcode_info.barcodes)
 
@@ -1010,6 +1030,7 @@ class TestAsnBarcodes(DirectoriesMixin, TestCase):
 
         doc_barcode_info = barcodes.scan_file_for_barcodes(
             test_file,
+            "application/pdf",
         )
         asn = barcodes.get_asn_from_barcodes(doc_barcode_info.barcodes)
 
@@ -1032,12 +1053,17 @@ class TestAsnBarcodes(DirectoriesMixin, TestCase):
         dst = self.dirs.scratch_dir / "barcode-128-asn-too-large.pdf"
         shutil.copy(src, dst)
 
+        input_doc = ConsumableDocument(
+            source=DocumentSource.ConsumeFolder,
+            original_file=dst,
+        )
+
         with mock.patch("documents.consumer.Consumer._send_progress"):
             self.assertRaisesMessage(
                 ConsumerError,
                 "Given ASN 4294967296 is out of range [0, 4,294,967,295]",
                 tasks.consume_file,
-                dst,
+                input_doc,
             )
 
 
@@ -1055,5 +1081,5 @@ class TestBarcodeZxing(TestBarcode):
     reason="No zxingcpp",
 )
 @override_settings(CONSUMER_BARCODE_SCANNER="ZXING")
-class TestAsnBarcodesZxing(TestAsnBarcodes):
+class TestAsnBarcodesZxing(TestAsnBarcode):
     pass
index 3db8de0343e5686a018ac3ed68792c77033b25b7..637a8cb204170fef06a2e780501015f24a05386d 100644 (file)
@@ -1,6 +1,7 @@
 import filecmp
 import os
 import shutil
+from pathlib import Path
 from threading import Thread
 from time import sleep
 from unittest import mock
@@ -11,9 +12,12 @@ from django.core.management import CommandError
 from django.test import override_settings
 from django.test import TransactionTestCase
 from documents.consumer import ConsumerError
+from documents.data_models import ConsumableDocument
+from documents.data_models import DocumentMetadataOverrides
 from documents.management.commands import document_consumer
 from documents.models import Tag
 from documents.tests.utils import DirectoriesMixin
+from documents.tests.utils import DocumentConsumeDelayMixin
 
 
 class ConsumerThread(Thread):
@@ -35,18 +39,19 @@ def chunked(size, source):
         yield source[i : i + size]
 
 
-class ConsumerMixin:
+class ConsumerThreadMixin(DocumentConsumeDelayMixin):
+    """
+    Provides a thread which runs the consumer management command at setUp
+    and stops it at tearDown
+    """
 
-    sample_file = os.path.join(os.path.dirname(__file__), "samples", "simple.pdf")
+    sample_file: Path = (
+        Path(__file__).parent / Path("samples") / Path("simple.pdf")
+    ).resolve()
 
     def setUp(self) -> None:
         super().setUp()
         self.t = None
-        patcher = mock.patch(
-            "documents.tasks.consume_file.delay",
-        )
-        self.task_mock = patcher.start()
-        self.addCleanup(patcher.stop)
 
     def t_start(self):
         self.t = ConsumerThread()
@@ -67,7 +72,7 @@ class ConsumerMixin:
     def wait_for_task_mock_call(self, expected_call_count=1):
         n = 0
         while n < 50:
-            if self.task_mock.call_count >= expected_call_count:
+            if self.consume_file_mock.call_count >= expected_call_count:
                 # give task_mock some time to finish and raise errors
                 sleep(1)
                 return
@@ -76,8 +81,12 @@ class ConsumerMixin:
 
     # A bogus async_task that will simply check the file for
     # completeness and raise an exception otherwise.
-    def bogus_task(self, filename, **kwargs):
-        eq = filecmp.cmp(filename, self.sample_file, shallow=False)
+    def bogus_task(
+        self,
+        input_doc: ConsumableDocument,
+        overrides=None,
+    ):
+        eq = filecmp.cmp(input_doc.original_file, self.sample_file, shallow=False)
         if not eq:
             print("Consumed an INVALID file.")
             raise ConsumerError("Incomplete File READ FAILED")
@@ -103,19 +112,20 @@ class ConsumerMixin:
 @override_settings(
     CONSUMER_INOTIFY_DELAY=0.01,
 )
-class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
+class TestConsumer(DirectoriesMixin, ConsumerThreadMixin, TransactionTestCase):
     def test_consume_file(self):
         self.t_start()
 
-        f = os.path.join(self.dirs.consumption_dir, "my_file.pdf")
+        f = Path(os.path.join(self.dirs.consumption_dir, "my_file.pdf"))
         shutil.copy(self.sample_file, f)
 
         self.wait_for_task_mock_call()
 
-        self.task_mock.assert_called_once()
+        self.consume_file_mock.assert_called_once()
+
+        input_doc, _ = self.get_last_consume_delay_call_args()
 
-        args, kwargs = self.task_mock.call_args
-        self.assertEqual(args[0], f)
+        self.assertEqual(input_doc.original_file, f)
 
     def test_consume_file_invalid_ext(self):
         self.t_start()
@@ -125,26 +135,27 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
 
         self.wait_for_task_mock_call()
 
-        self.task_mock.assert_not_called()
+        self.consume_file_mock.assert_not_called()
 
     def test_consume_existing_file(self):
-        f = os.path.join(self.dirs.consumption_dir, "my_file.pdf")
+        f = Path(os.path.join(self.dirs.consumption_dir, "my_file.pdf"))
         shutil.copy(self.sample_file, f)
 
         self.t_start()
-        self.task_mock.assert_called_once()
+        self.consume_file_mock.assert_called_once()
+
+        input_doc, _ = self.get_last_consume_delay_call_args()
 
-        args, kwargs = self.task_mock.call_args
-        self.assertEqual(args[0], f)
+        self.assertEqual(input_doc.original_file, f)
 
     @mock.patch("documents.management.commands.document_consumer.logger.error")
     def test_slow_write_pdf(self, error_logger):
 
-        self.task_mock.side_effect = self.bogus_task
+        self.consume_file_mock.side_effect = self.bogus_task
 
         self.t_start()
 
-        fname = os.path.join(self.dirs.consumption_dir, "my_file.pdf")
+        fname = Path(os.path.join(self.dirs.consumption_dir, "my_file.pdf"))
 
         self.slow_write_file(fname)
 
@@ -152,48 +163,52 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
 
         error_logger.assert_not_called()
 
-        self.task_mock.assert_called_once()
+        self.consume_file_mock.assert_called_once()
+
+        input_doc, _ = self.get_last_consume_delay_call_args()
 
-        args, kwargs = self.task_mock.call_args
-        self.assertEqual(args[0], fname)
+        self.assertEqual(input_doc.original_file, fname)
 
     @mock.patch("documents.management.commands.document_consumer.logger.error")
     def test_slow_write_and_move(self, error_logger):
 
-        self.task_mock.side_effect = self.bogus_task
+        self.consume_file_mock.side_effect = self.bogus_task
 
         self.t_start()
 
-        fname = os.path.join(self.dirs.consumption_dir, "my_file.~df")
-        fname2 = os.path.join(self.dirs.consumption_dir, "my_file.pdf")
+        fname = Path(os.path.join(self.dirs.consumption_dir, "my_file.~df"))
+        fname2 = Path(os.path.join(self.dirs.consumption_dir, "my_file.pdf"))
 
         self.slow_write_file(fname)
         shutil.move(fname, fname2)
 
         self.wait_for_task_mock_call()
 
-        self.task_mock.assert_called_once()
+        self.consume_file_mock.assert_called_once()
 
-        args, kwargs = self.task_mock.call_args
-        self.assertEqual(args[0], fname2)
+        input_doc, _ = self.get_last_consume_delay_call_args()
+
+        self.assertEqual(input_doc.original_file, fname2)
 
         error_logger.assert_not_called()
 
     @mock.patch("documents.management.commands.document_consumer.logger.error")
     def test_slow_write_incomplete(self, error_logger):
 
-        self.task_mock.side_effect = self.bogus_task
+        self.consume_file_mock.side_effect = self.bogus_task
 
         self.t_start()
 
-        fname = os.path.join(self.dirs.consumption_dir, "my_file.pdf")
+        fname = Path(os.path.join(self.dirs.consumption_dir, "my_file.pdf"))
         self.slow_write_file(fname, incomplete=True)
 
         self.wait_for_task_mock_call()
 
-        self.task_mock.assert_called_once()
-        args, kwargs = self.task_mock.call_args
-        self.assertEqual(args[0], fname)
+        self.consume_file_mock.assert_called_once()
+
+        input_doc, _ = self.get_last_consume_delay_call_args()
+
+        self.assertEqual(input_doc.original_file, fname)
 
         # assert that we have an error logged with this invalid file.
         error_logger.assert_called_once()
@@ -209,7 +224,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
         self.assertRaises(CommandError, call_command, "document_consumer", "--oneshot")
 
     def test_mac_write(self):
-        self.task_mock.side_effect = self.bogus_task
+        self.consume_file_mock.side_effect = self.bogus_task
 
         self.t_start()
 
@@ -238,12 +253,13 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
 
         self.wait_for_task_mock_call(expected_call_count=2)
 
-        self.assertEqual(2, self.task_mock.call_count)
+        self.assertEqual(2, self.consume_file_mock.call_count)
 
-        fnames = [
-            os.path.basename(args[0]) for args, _ in self.task_mock.call_args_list
-        ]
-        self.assertCountEqual(fnames, ["my_file.pdf", "my_second_file.pdf"])
+        consumed_files = []
+        for input_doc, _ in self.get_all_consume_delay_call_args():
+            consumed_files.append(input_doc.original_file.name)
+
+        self.assertCountEqual(consumed_files, ["my_file.pdf", "my_second_file.pdf"])
 
     def test_is_ignored(self):
         test_paths = [
@@ -341,7 +357,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
 
         self.wait_for_task_mock_call()
 
-        self.task_mock.assert_not_called()
+        self.consume_file_mock.assert_not_called()
 
 
 @override_settings(
@@ -373,7 +389,7 @@ class TestConsumerRecursivePolling(TestConsumer):
     pass
 
 
-class TestConsumerTags(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
+class TestConsumerTags(DirectoriesMixin, ConsumerThreadMixin, TransactionTestCase):
     @override_settings(CONSUMER_RECURSIVE=True, CONSUMER_SUBDIRS_AS_TAGS=True)
     def test_consume_file_with_path_tags(self):
 
@@ -387,7 +403,7 @@ class TestConsumerTags(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
 
         path = os.path.join(self.dirs.consumption_dir, *tag_names)
         os.makedirs(path, exist_ok=True)
-        f = os.path.join(path, "my_file.pdf")
+        f = Path(os.path.join(path, "my_file.pdf"))
         # Wait at least inotify read_delay for recursive watchers
         # to be created for the new directories
         sleep(1)
@@ -395,18 +411,19 @@ class TestConsumerTags(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
 
         self.wait_for_task_mock_call()
 
-        self.task_mock.assert_called_once()
+        self.consume_file_mock.assert_called_once()
 
         # Add the pk of the Tag created by _consume()
         tag_ids.append(Tag.objects.get(name=tag_names[1]).pk)
 
-        args, kwargs = self.task_mock.call_args
-        self.assertEqual(args[0], f)
+        input_doc, overrides = self.get_last_consume_delay_call_args()
+
+        self.assertEqual(input_doc.original_file, f)
 
         # assertCountEqual has a bad name, but test that the first
         # sequence contains the same elements as second, regardless of
         # their order.
-        self.assertCountEqual(kwargs["override_tag_ids"], tag_ids)
+        self.assertCountEqual(overrides.tag_ids, tag_ids)
 
     @override_settings(
         CONSUMER_POLLING=1,
index e21879802a214990d06b7d07139dda22ab511311..18600d709a6d23863fb97c837488f3f2ffcc588d 100644 (file)
@@ -1,76 +1,21 @@
+import uuid
+from unittest import mock
+
 import celery
 from django.test import TestCase
+from documents.data_models import ConsumableDocument
+from documents.data_models import DocumentMetadataOverrides
+from documents.data_models import DocumentSource
 from documents.models import PaperlessTask
 from documents.signals.handlers import before_task_publish_handler
 from documents.signals.handlers import task_postrun_handler
 from documents.signals.handlers import task_prerun_handler
+from documents.tests.test_consumer import fake_magic_from_file
 from documents.tests.utils import DirectoriesMixin
 
 
+@mock.patch("documents.consumer.magic.from_file", fake_magic_from_file)
 class TestTaskSignalHandler(DirectoriesMixin, TestCase):
-
-    HEADERS_CONSUME = {
-        "lang": "py",
-        "task": "documents.tasks.consume_file",
-        "id": "52d31e24-9dcc-4c32-9e16-76007e9add5e",
-        "shadow": None,
-        "eta": None,
-        "expires": None,
-        "group": None,
-        "group_index": None,
-        "retries": 0,
-        "timelimit": [None, None],
-        "root_id": "52d31e24-9dcc-4c32-9e16-76007e9add5e",
-        "parent_id": None,
-        "argsrepr": "('/consume/hello-999.pdf',)",
-        "kwargsrepr": "{'override_tag_ids': None}",
-        "origin": "gen260@paperless-ngx-dev-webserver",
-        "ignore_result": False,
-    }
-
-    BODY_CONSUME = (
-        # args
-        ("/consume/hello-999.pdf",),
-        # kwargs
-        {"override_tag_ids": None},
-        {"callbacks": None, "errbacks": None, "chain": None, "chord": None},
-    )
-
-    HEADERS_WEB_UI = {
-        "lang": "py",
-        "task": "documents.tasks.consume_file",
-        "id": "6e88a41c-e5f8-4631-9972-68c314512498",
-        "shadow": None,
-        "eta": None,
-        "expires": None,
-        "group": None,
-        "group_index": None,
-        "retries": 0,
-        "timelimit": [None, None],
-        "root_id": "6e88a41c-e5f8-4631-9972-68c314512498",
-        "parent_id": None,
-        "argsrepr": "('/tmp/paperless/paperless-upload-st9lmbvx',)",
-        "kwargsrepr": "{'override_filename': 'statement.pdf', 'override_title': None, 'override_correspondent_id': None, 'override_document_type_id': None, 'override_tag_ids': None, 'task_id': 'f5622ca9-3707-4ed0-b418-9680b912572f', 'override_created': None}",
-        "origin": "gen342@paperless-ngx-dev-webserver",
-        "ignore_result": False,
-    }
-
-    BODY_WEB_UI = (
-        # args
-        ("/tmp/paperless/paperless-upload-st9lmbvx",),
-        # kwargs
-        {
-            "override_filename": "statement.pdf",
-            "override_title": None,
-            "override_correspondent_id": None,
-            "override_document_type_id": None,
-            "override_tag_ids": None,
-            "task_id": "f5622ca9-3707-4ed0-b418-9680b912572f",
-            "override_created": None,
-        },
-        {"callbacks": None, "errbacks": None, "chain": None, "chord": None},
-    )
-
     def util_call_before_task_publish_handler(self, headers_to_use, body_to_use):
         """
         Simple utility to call the pre-run handle and ensure it created a single task
@@ -91,38 +36,33 @@ class TestTaskSignalHandler(DirectoriesMixin, TestCase):
         THEN:
             - The task is created and marked as pending
         """
-        self.util_call_before_task_publish_handler(
-            headers_to_use=self.HEADERS_CONSUME,
-            body_to_use=self.BODY_CONSUME,
+        headers = {
+            "id": str(uuid.uuid4()),
+            "task": "documents.tasks.consume_file",
+        }
+        body = (
+            # args
+            (
+                ConsumableDocument(
+                    source=DocumentSource.ConsumeFolder,
+                    original_file="/consume/hello-999.pdf",
+                ),
+                None,
+            ),
+            # kwargs
+            {},
+            # celery stuff
+            {"callbacks": None, "errbacks": None, "chain": None, "chord": None},
         )
-
-        task = PaperlessTask.objects.get()
-        self.assertIsNotNone(task)
-        self.assertEqual(self.HEADERS_CONSUME["id"], task.task_id)
-        self.assertEqual("hello-999.pdf", task.task_file_name)
-        self.assertEqual("documents.tasks.consume_file", task.task_name)
-        self.assertEqual(celery.states.PENDING, task.status)
-
-    def test_before_task_publish_handler_webui(self):
-        """
-        GIVEN:
-            - A celery task is started via the web ui
-        WHEN:
-            - Task before publish handler is called
-        THEN:
-            - The task is created and marked as pending
-        """
         self.util_call_before_task_publish_handler(
-            headers_to_use=self.HEADERS_WEB_UI,
-            body_to_use=self.BODY_WEB_UI,
+            headers_to_use=headers,
+            body_to_use=body,
         )
 
         task = PaperlessTask.objects.get()
-
         self.assertIsNotNone(task)
-
-        self.assertEqual(self.HEADERS_WEB_UI["id"], task.task_id)
-        self.assertEqual("statement.pdf", task.task_file_name)
+        self.assertEqual(headers["id"], task.task_id)
+        self.assertEqual("hello-999.pdf", task.task_file_name)
         self.assertEqual("documents.tasks.consume_file", task.task_name)
         self.assertEqual(celery.states.PENDING, task.status)
 
@@ -135,12 +75,32 @@ class TestTaskSignalHandler(DirectoriesMixin, TestCase):
         THEN:
             - The task is marked as started
         """
+
+        headers = {
+            "id": str(uuid.uuid4()),
+            "task": "documents.tasks.consume_file",
+        }
+        body = (
+            # args
+            (
+                ConsumableDocument(
+                    source=DocumentSource.ConsumeFolder,
+                    original_file="/consume/hello-99.pdf",
+                ),
+                None,
+            ),
+            # kwargs
+            {},
+            # celery stuff
+            {"callbacks": None, "errbacks": None, "chain": None, "chord": None},
+        )
+
         self.util_call_before_task_publish_handler(
-            headers_to_use=self.HEADERS_CONSUME,
-            body_to_use=self.BODY_CONSUME,
+            headers_to_use=headers,
+            body_to_use=body,
         )
 
-        task_prerun_handler(task_id=self.HEADERS_CONSUME["id"])
+        task_prerun_handler(task_id=headers["id"])
 
         task = PaperlessTask.objects.get()
 
@@ -155,13 +115,31 @@ class TestTaskSignalHandler(DirectoriesMixin, TestCase):
         THEN:
             - The task is marked as started
         """
+        headers = {
+            "id": str(uuid.uuid4()),
+            "task": "documents.tasks.consume_file",
+        }
+        body = (
+            # args
+            (
+                ConsumableDocument(
+                    source=DocumentSource.ConsumeFolder,
+                    original_file="/consume/hello-9.pdf",
+                ),
+                None,
+            ),
+            # kwargs
+            {},
+            # celery stuff
+            {"callbacks": None, "errbacks": None, "chain": None, "chord": None},
+        )
         self.util_call_before_task_publish_handler(
-            headers_to_use=self.HEADERS_CONSUME,
-            body_to_use=self.BODY_CONSUME,
+            headers_to_use=headers,
+            body_to_use=body,
         )
 
         task_postrun_handler(
-            task_id=self.HEADERS_CONSUME["id"],
+            task_id=headers["id"],
             retval="Success. New document id 1 created",
             state=celery.states.SUCCESS,
         )
index 0a8da9ef952f258eb728f1b55887b0428d7be592..26760f7806bdcbee1e78a0c6a9707bedaa0a59e5 100644 (file)
@@ -4,6 +4,8 @@ from collections import namedtuple
 from contextlib import contextmanager
 from os import PathLike
 from pathlib import Path
+from typing import Iterator
+from typing import Tuple
 from typing import Union
 from unittest import mock
 
@@ -12,6 +14,8 @@ from django.db import connection
 from django.db.migrations.executor import MigrationExecutor
 from django.test import override_settings
 from django.test import TransactionTestCase
+from documents.data_models import ConsumableDocument
+from documents.data_models import DocumentMetadataOverrides
 
 
 def setup_directories():
@@ -116,6 +120,11 @@ class ConsumerProgressMixin:
 
 
 class DocumentConsumeDelayMixin:
+    """
+    Provides mocking of the consume_file asynchronous task and useful utilities
+    for decoding its arguments
+    """
+
     def setUp(self) -> None:
         self.consume_file_patcher = mock.patch("documents.tasks.consume_file.delay")
         self.consume_file_mock = self.consume_file_patcher.start()
@@ -125,6 +134,47 @@ class DocumentConsumeDelayMixin:
         super().tearDown()
         self.consume_file_patcher.stop()
 
+    def get_last_consume_delay_call_args(
+        self,
+    ) -> Tuple[ConsumableDocument, DocumentMetadataOverrides]:
+        """
+        Fetches the (document, overrides) pair most recently passed to the
+        mocked consume_file.delay, failing the test if it was never called
+        """
+        self.consume_file_mock.assert_called()
+
+        # call_args unpacks as (positional args, keyword args)
+        positional, _kwargs = self.consume_file_mock.call_args
+        document, metadata_overrides = positional
+        return (document, metadata_overrides)
+
+    def get_all_consume_delay_call_args(
+        self,
+    ) -> Iterator[Tuple[ConsumableDocument, DocumentMetadataOverrides]]:
+        """
+        Yields the (document, overrides) pair of every recorded call to the
+        mocked async task, in the order found in call_args_list
+        """
+        for recorded_call in self.consume_file_mock.call_args_list:
+            positional, _kwargs = recorded_call
+            document, metadata_overrides = positional
+            yield (document, metadata_overrides)
+
+    def get_specific_consume_delay_call_args(
+        self,
+        index: int,
+    ) -> Tuple[ConsumableDocument, DocumentMetadataOverrides]:
+        """
+        Returns the arguments of the call to the async task found at the
+        given index of the mock's call_args_list (IndexError if out of range)
+        """
+        # Must be at least 1 call
+        self.consume_file_mock.assert_called()
+
+        args, _ = self.consume_file_mock.call_args_list[index]
+        input_doc, overrides = args
+        return (input_doc, overrides)
+
 
 class TestMigrations(TransactionTestCase):
     @property
@@ -140,7 +190,7 @@ class TestMigrations(TransactionTestCase):
 
         assert (
             self.migrate_from and self.migrate_to
-        ), "TestCase '{}' must define migrate_from and migrate_to     properties".format(
+        ), "TestCase '{}' must define migrate_from and migrate_to properties".format(
             type(self).__name__,
         )
         self.migrate_from = [(self.app, self.migrate_from)]
index 1b30ec770aaa9e0d49777010beb251bdf23b9147..a50d9f7f40afb6784a0a00d27d89102cbbd1d84b 100644 (file)
@@ -5,7 +5,6 @@ import os
 import re
 import tempfile
 import urllib
-import uuid
 import zipfile
 from datetime import datetime
 from pathlib import Path
@@ -65,6 +64,9 @@ from .bulk_download import ArchiveOnlyStrategy
 from .bulk_download import OriginalAndArchiveStrategy
 from .bulk_download import OriginalsOnlyStrategy
 from .classifier import load_classifier
+from .data_models import ConsumableDocument
+from .data_models import DocumentMetadataOverrides
+from .data_models import DocumentSource
 from .filters import CorrespondentFilterSet
 from .filters import DocumentFilterSet
 from .filters import DocumentTypeFilterSet
@@ -692,19 +694,24 @@ class PostDocumentView(GenericAPIView):
 
         os.utime(temp_file_path, times=(t, t))
 
-        task_id = str(uuid.uuid4())
+        input_doc = ConsumableDocument(
+            source=DocumentSource.ApiUpload,
+            original_file=temp_file_path,
+        )
+        input_doc_overrides = DocumentMetadataOverrides(
+            filename=doc_name,
+            title=title,
+            correspondent_id=correspondent_id,
+            document_type_id=document_type_id,
+            tag_ids=tag_ids,
+            created=created,
+            asn=archive_serial_number,
+            owner_id=request.user.id,
+        )
 
         async_task = consume_file.delay(
-            # Paths are not JSON friendly
-            str(temp_file_path),
-            override_title=title,
-            override_correspondent_id=correspondent_id,
-            override_document_type_id=document_type_id,
-            override_tag_ids=tag_ids,
-            task_id=task_id,
-            override_created=created,
-            override_owner_id=request.user.id,
-            override_archive_serial_num=archive_serial_number,
+            input_doc,
+            input_doc_overrides,
         )
 
         return Response(async_task.id)
index 50a5785632d050619b54e936cdf978df55cc3529..06dd3ac6cb387fabc85dec3f99909dc94adba03a 100644 (file)
@@ -21,6 +21,9 @@ from django.conf import settings
 from django.db import DatabaseError
 from django.utils.timezone import is_naive
 from django.utils.timezone import make_aware
+from documents.data_models import ConsumableDocument
+from documents.data_models import DocumentMetadataOverrides
+from documents.data_models import DocumentSource
 from documents.loggers import LoggingMixin
 from documents.models import Correspondent
 from documents.parsers import is_mime_type_supported
@@ -694,18 +697,22 @@ class MailAccountHandler(LoggingMixin):
                     f"{message.subject} from {message.from_}",
                 )
 
+                input_doc = ConsumableDocument(
+                    source=DocumentSource.MailFetch,
+                    original_file=temp_filename,
+                )
+                doc_overrides = DocumentMetadataOverrides(
+                    title=title,
+                    filename=pathvalidate.sanitize_filename(att.filename),
+                    correspondent_id=correspondent.id if correspondent else None,
+                    document_type_id=doc_type.id if doc_type else None,
+                    tag_ids=tag_ids,
+                    owner_id=rule.owner.id if rule.owner else None,
+                )
+
                 consume_task = consume_file.s(
-                    path=temp_filename,
-                    override_filename=pathvalidate.sanitize_filename(
-                        att.filename,
-                    ),
-                    override_title=title,
-                    override_correspondent_id=correspondent.id
-                    if correspondent
-                    else None,
-                    override_document_type_id=doc_type.id if doc_type else None,
-                    override_tag_ids=tag_ids,
-                    override_owner_id=rule.owner.id if rule.owner else None,
+                    input_doc,
+                    doc_overrides,
                 )
 
                 consume_tasks.append(consume_task)
@@ -770,16 +777,22 @@ class MailAccountHandler(LoggingMixin):
             f"{message.subject} from {message.from_}",
         )
 
+        input_doc = ConsumableDocument(
+            source=DocumentSource.MailFetch,
+            original_file=temp_filename,
+        )
+        doc_overrides = DocumentMetadataOverrides(
+            title=message.subject,
+            filename=pathvalidate.sanitize_filename(f"{message.subject}.eml"),
+            correspondent_id=correspondent.id if correspondent else None,
+            document_type_id=doc_type.id if doc_type else None,
+            tag_ids=tag_ids,
+            owner_id=rule.owner.id if rule.owner else None,
+        )
+
         consume_task = consume_file.s(
-            path=temp_filename,
-            override_filename=pathvalidate.sanitize_filename(
-                message.subject + ".eml",
-            ),
-            override_title=message.subject,
-            override_correspondent_id=correspondent.id if correspondent else None,
-            override_document_type_id=doc_type.id if doc_type else None,
-            override_tag_ids=tag_ids,
-            override_owner_id=rule.owner.id if rule.owner else None,
+            input_doc,
+            doc_overrides,
         )
 
         queue_consumption_tasks(
index c0bfccba5e592d3d99b8e4c08dc4998712ac5fec..e08f0ad1866d8843dd2e26ae36dcf0a80f66f91f 100644 (file)
@@ -12,8 +12,11 @@ from unittest import mock
 from django.core.management import call_command
 from django.db import DatabaseError
 from django.test import TestCase
+from documents.data_models import ConsumableDocument
+from documents.data_models import DocumentMetadataOverrides
 from documents.models import Correspondent
 from documents.tests.utils import DirectoriesMixin
+from documents.tests.utils import DocumentConsumeDelayMixin
 from documents.tests.utils import FileSystemAssertsMixin
 from imap_tools import EmailAddress
 from imap_tools import FolderInfo
@@ -194,7 +197,11 @@ def fake_magic_from_buffer(buffer, mime=False):
 
 
 @mock.patch("paperless_mail.mail.magic.from_buffer", fake_magic_from_buffer)
-class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
+class TestMail(
+    DirectoriesMixin,
+    FileSystemAssertsMixin,
+    TestCase,
+):
     def setUp(self):
         self._used_uids = set()
 
@@ -409,6 +416,8 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         self.assertEqual(result, 2)
 
+        self._queue_consumption_tasks_mock.assert_called()
+
         self.assert_queue_consumption_tasks_call_args(
             [
                 [
@@ -426,7 +435,7 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         result = self.mail_account_handler._handle_message(message, rule)
 
-        self.assertFalse(self._queue_consumption_tasks_mock.called)
+        self._queue_consumption_tasks_mock.assert_not_called()
         self.assertEqual(result, 0)
 
     def test_handle_unknown_mime_type(self):
@@ -541,7 +550,6 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         for (pattern, matches) in tests:
             with self.subTest(msg=pattern):
-                print(f"PATTERN {pattern}")
                 self._queue_consumption_tasks_mock.reset_mock()
                 account = MailAccount(name=str(uuid.uuid4()))
                 account.save()
@@ -855,7 +863,7 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         self.mail_account_handler.handle_mail_account(account)
 
         self.bogus_mailbox.folder.list.assert_called_once()
-        self.assertEqual(self._queue_consumption_tasks_mock.call_count, 0)
+        self._queue_consumption_tasks_mock.assert_not_called()
 
     def test_error_folder_set_error_listing(self):
         """
@@ -888,7 +896,7 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         self.mail_account_handler.handle_mail_account(account)
 
         self.bogus_mailbox.folder.list.assert_called_once()
-        self.assertEqual(self._queue_consumption_tasks_mock.call_count, 0)
+        self._queue_consumption_tasks_mock.assert_not_called()
 
     @mock.patch("paperless_mail.mail.MailAccountHandler._get_correspondent")
     def test_error_skip_mail(self, m):
@@ -1002,7 +1010,7 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
                 self.reset_bogus_mailbox()
                 self._queue_consumption_tasks_mock.reset_mock()
 
-                self.assertEqual(self._queue_consumption_tasks_mock.call_count, 0)
+                self._queue_consumption_tasks_mock.assert_not_called()
                 self.assertEqual(len(self.bogus_mailbox.messages), 3)
 
                 self.mail_account_handler.handle_mail_account(account)
@@ -1041,7 +1049,7 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         )
 
         self.assertEqual(len(self.bogus_mailbox.messages), 3)
-        self.assertEqual(self._queue_consumption_tasks_mock.call_count, 0)
+        self._queue_consumption_tasks_mock.assert_not_called()
         self.assertEqual(len(self.bogus_mailbox.fetch("UNSEEN", False)), 2)
 
         self.mail_account_handler.handle_mail_account(account)
@@ -1148,13 +1156,21 @@ class TestMail(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
                 consume_tasks,
                 expected_signatures,
             ):
+                input_doc, overrides = consume_task.args
+
                 # assert the file exists
-                self.assertIsFile(consume_task.kwargs["path"])
+                self.assertIsFile(input_doc.original_file)
 
                 # assert all expected arguments are present in the signature
                 for key, value in expected_signature.items():
-                    self.assertIn(key, consume_task.kwargs)
-                    self.assertEqual(consume_task.kwargs[key], value)
+                    if key == "override_correspondent_id":
+                        self.assertEqual(overrides.correspondent_id, value)
+                    elif key == "override_filename":
+                        self.assertEqual(overrides.filename, value)
+                    elif key == "override_title":
+                        self.assertEqual(overrides.title, value)
+                    else:
+                        self.fail("No match for expected arg")
 
     def apply_mail_actions(self):
         """