git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Chore: Convert the consumer to a plugin (#6361)
authorTrenton H <797416+stumpylog@users.noreply.github.com>
Thu, 18 Apr 2024 02:59:14 +0000 (19:59 -0700)
committerGitHub <noreply@github.com>
Thu, 18 Apr 2024 02:59:14 +0000 (02:59 +0000)
15 files changed:
paperless-ngx.code-workspace [new file with mode: 0644]
src-ui/src/app/services/consumer-status.service.ts
src/documents/consumer.py
src/documents/loggers.py
src/documents/parsers.py
src/documents/plugins/base.py
src/documents/plugins/helpers.py
src/documents/tasks.py
src/documents/tests/test_barcodes.py
src/documents/tests/test_consumer.py
src/documents/tests/test_double_sided.py
src/documents/tests/test_workflows.py
src/documents/tests/utils.py
src/paperless/settings.py
src/paperless_mail/mail.py

diff --git a/paperless-ngx.code-workspace b/paperless-ngx.code-workspace
new file mode 100644 (file)
index 0000000..7030ae6
--- /dev/null
@@ -0,0 +1,36 @@
+{
+       "folders": [
+               {
+                       "path": "."
+               },
+               {
+                       "path": "./src",
+                       "name": "Backend"
+               },
+               {
+                       "path": "./src-ui",
+                       "name": "Frontend"
+               },
+               {
+                       "path": "./.github",
+                       "name": "CI/CD"
+               },
+               {
+                       "path": "./docs",
+                       "name": "Documentation"
+               }
+
+       ],
+       "settings": {
+               "files.exclude": {
+                       "**/__pycache__": true,
+                       "**/.mypy_cache": true,
+                       "**/.ruff_cache": true,
+                       "**/.pytest_cache": true,
+                       "**/.idea": true,
+                       "**/.venv": true,
+                       "**/.coverage": true,
+                       "**/coverage.json": true
+               }
+       }
+}
index 246ddad694c7cf643ee4fa56ecca86e00cda0504..d8e8ffe28c26e658eaa2392a7ee818ab290f500a 100644 (file)
@@ -4,7 +4,7 @@ import { environment } from 'src/environments/environment'
 import { WebsocketConsumerStatusMessage } from '../data/websocket-consumer-status-message'
 import { SettingsService } from './settings.service'
 
-// see ConsumerFilePhase in src/documents/consumer.py
+// see ProgressStatusOptions in src/documents/plugins/helpers.py
 export enum FileStatusPhase {
   STARTED = 0,
   UPLOADING = 1,
index c735ed4c814b2fca4e617e1c33cb1ae0724d64e2..0d5514e2c78d26beab9ad9a0d530a4aa5ed655cb 100644 (file)
@@ -2,15 +2,13 @@ import datetime
 import hashlib
 import os
 import tempfile
-import uuid
 from enum import Enum
 from pathlib import Path
 from typing import TYPE_CHECKING
 from typing import Optional
+from typing import Union
 
 import magic
-from asgiref.sync import async_to_sync
-from channels.layers import get_channel_layer
 from django.conf import settings
 from django.contrib.auth.models import User
 from django.db import transaction
@@ -20,6 +18,7 @@ from filelock import FileLock
 from rest_framework.reverse import reverse
 
 from documents.classifier import load_classifier
+from documents.data_models import ConsumableDocument
 from documents.data_models import DocumentMetadataOverrides
 from documents.file_handling import create_source_path_directory
 from documents.file_handling import generate_unique_filename
@@ -45,6 +44,8 @@ from documents.plugins.base import AlwaysRunPluginMixin
 from documents.plugins.base import ConsumeTaskPlugin
 from documents.plugins.base import NoCleanupPluginMixin
 from documents.plugins.base import NoSetupPluginMixin
+from documents.plugins.helpers import ProgressManager
+from documents.plugins.helpers import ProgressStatusOptions
 from documents.signals import document_consumption_finished
 from documents.signals import document_consumption_started
 from documents.utils import copy_basic_file_stats
@@ -247,88 +248,81 @@ class ConsumerStatusShortMessage(str, Enum):
     FAILED = "failed"
 
 
-class ConsumerFilePhase(str, Enum):
-    STARTED = "STARTED"
-    WORKING = "WORKING"
-    SUCCESS = "SUCCESS"
-    FAILED = "FAILED"
+class ConsumerPlugin(
+    AlwaysRunPluginMixin,
+    NoSetupPluginMixin,
+    NoCleanupPluginMixin,
+    LoggingMixin,
+    ConsumeTaskPlugin,
+):
+    logging_name = "paperless.consumer"
 
+    def __init__(
+        self,
+        input_doc: ConsumableDocument,
+        metadata: DocumentMetadataOverrides,
+        status_mgr: ProgressManager,
+        base_tmp_dir: Path,
+        task_id: str,
+    ) -> None:
+        super().__init__(input_doc, metadata, status_mgr, base_tmp_dir, task_id)
 
-class Consumer(LoggingMixin):
-    logging_name = "paperless.consumer"
+        self.renew_logging_group()
+
+        self.filename = self.metadata.filename or self.input_doc.original_file.name
 
     def _send_progress(
         self,
         current_progress: int,
         max_progress: int,
-        status: ConsumerFilePhase,
-        message: Optional[ConsumerStatusShortMessage] = None,
+        status: ProgressStatusOptions,
+        message: Optional[Union[ConsumerStatusShortMessage, str]] = None,
         document_id=None,
     ):  # pragma: no cover
-        payload = {
-            "filename": os.path.basename(self.filename) if self.filename else None,
-            "task_id": self.task_id,
-            "current_progress": current_progress,
-            "max_progress": max_progress,
-            "status": status,
-            "message": message,
-            "document_id": document_id,
-            "owner_id": self.override_owner_id if self.override_owner_id else None,
-        }
-        async_to_sync(self.channel_layer.group_send)(
-            "status_updates",
-            {"type": "status_update", "data": payload},
+        self.status_mgr.send_progress(
+            status,
+            message,
+            current_progress,
+            max_progress,
+            extra_args={
+                "document_id": document_id,
+                "owner_id": self.metadata.owner_id if self.metadata.owner_id else None,
+            },
         )
 
     def _fail(
         self,
-        message: ConsumerStatusShortMessage,
+        message: Union[ConsumerStatusShortMessage, str],
         log_message: Optional[str] = None,
         exc_info=None,
         exception: Optional[Exception] = None,
     ):
-        self._send_progress(100, 100, ConsumerFilePhase.FAILED, message)
+        self._send_progress(100, 100, ProgressStatusOptions.FAILED, message)
         self.log.error(log_message or message, exc_info=exc_info)
         raise ConsumerError(f"{self.filename}: {log_message or message}") from exception
 
-    def __init__(self):
-        super().__init__()
-        self.path: Optional[Path] = None
-        self.original_path: Optional[Path] = None
-        self.filename = None
-        self.override_title = None
-        self.override_correspondent_id = None
-        self.override_tag_ids = None
-        self.override_document_type_id = None
-        self.override_asn = None
-        self.task_id = None
-        self.override_owner_id = None
-        self.override_custom_field_ids = None
-
-        self.channel_layer = get_channel_layer()
-
     def pre_check_file_exists(self):
         """
         Confirm the input file still exists where it should
         """
-        if not os.path.isfile(self.original_path):
+        if not os.path.isfile(self.input_doc.original_file):
             self._fail(
                 ConsumerStatusShortMessage.FILE_NOT_FOUND,
-                f"Cannot consume {self.original_path}: File not found.",
+                f"Cannot consume {self.input_doc.original_file}: File not found.",
             )
 
     def pre_check_duplicate(self):
         """
         Using the MD5 of the file, check this exact file doesn't already exist
         """
-        with open(self.original_path, "rb") as f:
+        with open(self.input_doc.original_file, "rb") as f:
             checksum = hashlib.md5(f.read()).hexdigest()
         existing_doc = Document.objects.filter(
             Q(checksum=checksum) | Q(archive_checksum=checksum),
         )
         if existing_doc.exists():
             if settings.CONSUMER_DELETE_DUPLICATES:
-                os.unlink(self.original_path)
+                os.unlink(self.input_doc.original_file)
             self._fail(
                 ConsumerStatusShortMessage.DOCUMENT_ALREADY_EXISTS,
                 f"Not consuming {self.filename}: It is a duplicate of"
@@ -348,26 +342,26 @@ class Consumer(LoggingMixin):
         """
         Check that if override_asn is given, it is unique and within a valid range
         """
-        if not self.override_asn:
+        if not self.metadata.asn:
             # check not necessary in case no ASN gets set
             return
         # Validate the range is above zero and less than uint32_t max
         # otherwise, Whoosh can't handle it in the index
         if (
-            self.override_asn < Document.ARCHIVE_SERIAL_NUMBER_MIN
-            or self.override_asn > Document.ARCHIVE_SERIAL_NUMBER_MAX
+            self.metadata.asn < Document.ARCHIVE_SERIAL_NUMBER_MIN
+            or self.metadata.asn > Document.ARCHIVE_SERIAL_NUMBER_MAX
         ):
             self._fail(
                 ConsumerStatusShortMessage.ASN_RANGE,
                 f"Not consuming {self.filename}: "
-                f"Given ASN {self.override_asn} is out of range "
+                f"Given ASN {self.metadata.asn} is out of range "
                 f"[{Document.ARCHIVE_SERIAL_NUMBER_MIN:,}, "
                 f"{Document.ARCHIVE_SERIAL_NUMBER_MAX:,}]",
             )
-        if Document.objects.filter(archive_serial_number=self.override_asn).exists():
+        if Document.objects.filter(archive_serial_number=self.metadata.asn).exists():
             self._fail(
                 ConsumerStatusShortMessage.ASN_ALREADY_EXISTS,
-                f"Not consuming {self.filename}: Given ASN {self.override_asn} already exists!",
+                f"Not consuming {self.filename}: Given ASN {self.metadata.asn} already exists!",
             )
 
     def run_pre_consume_script(self):
@@ -388,7 +382,7 @@ class Consumer(LoggingMixin):
         self.log.info(f"Executing pre-consume script {settings.PRE_CONSUME_SCRIPT}")
 
         working_file_path = str(self.working_copy)
-        original_file_path = str(self.original_path)
+        original_file_path = str(self.input_doc.original_file)
 
         script_env = os.environ.copy()
         script_env["DOCUMENT_SOURCE_PATH"] = original_file_path
@@ -486,50 +480,15 @@ class Consumer(LoggingMixin):
                 exception=e,
             )
 
-    def try_consume_file(
-        self,
-        path: Path,
-        override_filename=None,
-        override_title=None,
-        override_correspondent_id=None,
-        override_document_type_id=None,
-        override_tag_ids=None,
-        override_storage_path_id=None,
-        task_id=None,
-        override_created=None,
-        override_asn=None,
-        override_owner_id=None,
-        override_view_users=None,
-        override_view_groups=None,
-        override_change_users=None,
-        override_change_groups=None,
-        override_custom_field_ids=None,
-    ) -> Document:
+    def run(self) -> str:
         """
         Return the document object if it was successfully created.
         """
 
-        self.original_path = Path(path).resolve()
-        self.filename = override_filename or self.original_path.name
-        self.override_title = override_title
-        self.override_correspondent_id = override_correspondent_id
-        self.override_document_type_id = override_document_type_id
-        self.override_tag_ids = override_tag_ids
-        self.override_storage_path_id = override_storage_path_id
-        self.task_id = task_id or str(uuid.uuid4())
-        self.override_created = override_created
-        self.override_asn = override_asn
-        self.override_owner_id = override_owner_id
-        self.override_view_users = override_view_users
-        self.override_view_groups = override_view_groups
-        self.override_change_users = override_change_users
-        self.override_change_groups = override_change_groups
-        self.override_custom_field_ids = override_custom_field_ids
-
         self._send_progress(
             0,
             100,
-            ConsumerFilePhase.STARTED,
+            ProgressStatusOptions.STARTED,
             ConsumerStatusShortMessage.NEW_FILE,
         )
 
@@ -548,7 +507,7 @@ class Consumer(LoggingMixin):
             dir=settings.SCRATCH_DIR,
         )
         self.working_copy = Path(tempdir.name) / Path(self.filename)
-        copy_file_with_basic_stats(self.original_path, self.working_copy)
+        copy_file_with_basic_stats(self.input_doc.original_file, self.working_copy)
 
         # Determine the parser class.
 
@@ -580,7 +539,7 @@ class Consumer(LoggingMixin):
         def progress_callback(current_progress, max_progress):  # pragma: no cover
             # recalculate progress to be within 20 and 80
             p = int((current_progress / max_progress) * 50 + 20)
-            self._send_progress(p, 100, ConsumerFilePhase.WORKING)
+            self._send_progress(p, 100, ProgressStatusOptions.WORKING)
 
         # This doesn't parse the document yet, but gives us a parser.
 
@@ -591,9 +550,6 @@ class Consumer(LoggingMixin):
 
         self.log.debug(f"Parser: {type(document_parser).__name__}")
 
-        # However, this already created working directories which we have to
-        # clean up.
-
         # Parse the document. This may take some time.
 
         text = None
@@ -605,7 +561,7 @@ class Consumer(LoggingMixin):
             self._send_progress(
                 20,
                 100,
-                ConsumerFilePhase.WORKING,
+                ProgressStatusOptions.WORKING,
                 ConsumerStatusShortMessage.PARSING_DOCUMENT,
             )
             self.log.debug(f"Parsing {self.filename}...")
@@ -615,7 +571,7 @@ class Consumer(LoggingMixin):
             self._send_progress(
                 70,
                 100,
-                ConsumerFilePhase.WORKING,
+                ProgressStatusOptions.WORKING,
                 ConsumerStatusShortMessage.GENERATING_THUMBNAIL,
             )
             thumbnail = document_parser.get_thumbnail(
@@ -630,7 +586,7 @@ class Consumer(LoggingMixin):
                 self._send_progress(
                     90,
                     100,
-                    ConsumerFilePhase.WORKING,
+                    ProgressStatusOptions.WORKING,
                     ConsumerStatusShortMessage.PARSE_DATE,
                 )
                 date = parse_date(self.filename, text)
@@ -664,7 +620,7 @@ class Consumer(LoggingMixin):
         self._send_progress(
             95,
             100,
-            ConsumerFilePhase.WORKING,
+            ProgressStatusOptions.WORKING,
             ConsumerStatusShortMessage.SAVE_DOCUMENT,
         )
         # now that everything is done, we can start to store the document
@@ -726,13 +682,13 @@ class Consumer(LoggingMixin):
 
                 # Delete the file only if it was successfully consumed
                 self.log.debug(f"Deleting file {self.working_copy}")
-                self.original_path.unlink()
+                self.input_doc.original_file.unlink()
                 self.working_copy.unlink()
 
                 # https://github.com/jonaswinkler/paperless-ng/discussions/1037
                 shadow_file = os.path.join(
-                    os.path.dirname(self.original_path),
-                    "._" + os.path.basename(self.original_path),
+                    os.path.dirname(self.input_doc.original_file),
+                    "._" + os.path.basename(self.input_doc.original_file),
                 )
 
                 if os.path.isfile(shadow_file):
@@ -758,7 +714,7 @@ class Consumer(LoggingMixin):
         self._send_progress(
             100,
             100,
-            ConsumerFilePhase.SUCCESS,
+            ProgressStatusOptions.SUCCESS,
             ConsumerStatusShortMessage.FINISHED,
             document.id,
         )
@@ -766,24 +722,24 @@ class Consumer(LoggingMixin):
         # Return the most up to date fields
         document.refresh_from_db()
 
-        return document
+        return f"Success. New document id {document.pk} created"
 
     def _parse_title_placeholders(self, title: str) -> str:
         local_added = timezone.localtime(timezone.now())
 
         correspondent_name = (
-            Correspondent.objects.get(pk=self.override_correspondent_id).name
-            if self.override_correspondent_id is not None
+            Correspondent.objects.get(pk=self.metadata.correspondent_id).name
+            if self.metadata.correspondent_id is not None
             else None
         )
         doc_type_name = (
-            DocumentType.objects.get(pk=self.override_document_type_id).name
-            if self.override_document_type_id is not None
+            DocumentType.objects.get(pk=self.metadata.document_type_id).name
+            if self.metadata.document_type_id is not None
             else None
         )
         owner_username = (
-            User.objects.get(pk=self.override_owner_id).username
-            if self.override_owner_id is not None
+            User.objects.get(pk=self.metadata.owner_id).username
+            if self.metadata.owner_id is not None
             else None
         )
 
@@ -808,8 +764,8 @@ class Consumer(LoggingMixin):
 
         self.log.debug("Saving record to database")
 
-        if self.override_created is not None:
-            create_date = self.override_created
+        if self.metadata.created is not None:
+            create_date = self.metadata.created
             self.log.debug(
                 f"Creation date from post_documents parameter: {create_date}",
             )
@@ -820,7 +776,7 @@ class Consumer(LoggingMixin):
             create_date = date
             self.log.debug(f"Creation date from parse_date: {create_date}")
         else:
-            stats = os.stat(self.original_path)
+            stats = os.stat(self.input_doc.original_file)
             create_date = timezone.make_aware(
                 datetime.datetime.fromtimestamp(stats.st_mtime),
             )
@@ -829,12 +785,12 @@ class Consumer(LoggingMixin):
         storage_type = Document.STORAGE_TYPE_UNENCRYPTED
 
         title = file_info.title
-        if self.override_title is not None:
+        if self.metadata.title is not None:
             try:
-                title = self._parse_title_placeholders(self.override_title)
+                title = self._parse_title_placeholders(self.metadata.title)
             except Exception as e:
                 self.log.error(
-                    f"Error occurred parsing title override '{self.override_title}', falling back to original. Exception: {e}",
+                    f"Error occurred parsing title override '{self.metadata.title}', falling back to original. Exception: {e}",
                 )
 
         document = Document.objects.create(
@@ -855,53 +811,53 @@ class Consumer(LoggingMixin):
         return document
 
     def apply_overrides(self, document):
-        if self.override_correspondent_id:
+        if self.metadata.correspondent_id:
             document.correspondent = Correspondent.objects.get(
-                pk=self.override_correspondent_id,
+                pk=self.metadata.correspondent_id,
             )
 
-        if self.override_document_type_id:
+        if self.metadata.document_type_id:
             document.document_type = DocumentType.objects.get(
-                pk=self.override_document_type_id,
+                pk=self.metadata.document_type_id,
             )
 
-        if self.override_tag_ids:
-            for tag_id in self.override_tag_ids:
+        if self.metadata.tag_ids:
+            for tag_id in self.metadata.tag_ids:
                 document.tags.add(Tag.objects.get(pk=tag_id))
 
-        if self.override_storage_path_id:
+        if self.metadata.storage_path_id:
             document.storage_path = StoragePath.objects.get(
-                pk=self.override_storage_path_id,
+                pk=self.metadata.storage_path_id,
             )
 
-        if self.override_asn:
-            document.archive_serial_number = self.override_asn
+        if self.metadata.asn:
+            document.archive_serial_number = self.metadata.asn
 
-        if self.override_owner_id:
+        if self.metadata.owner_id:
             document.owner = User.objects.get(
-                pk=self.override_owner_id,
+                pk=self.metadata.owner_id,
             )
 
         if (
-            self.override_view_users is not None
-            or self.override_view_groups is not None
-            or self.override_change_users is not None
-            or self.override_change_users is not None
+            self.metadata.view_users is not None
+            or self.metadata.view_groups is not None
+            or self.metadata.change_users is not None
+            or self.metadata.change_users is not None
         ):
             permissions = {
                 "view": {
-                    "users": self.override_view_users or [],
-                    "groups": self.override_view_groups or [],
+                    "users": self.metadata.view_users or [],
+                    "groups": self.metadata.view_groups or [],
                 },
                 "change": {
-                    "users": self.override_change_users or [],
-                    "groups": self.override_change_groups or [],
+                    "users": self.metadata.change_users or [],
+                    "groups": self.metadata.change_groups or [],
                 },
             }
             set_permissions_for_object(permissions=permissions, object=document)
 
-        if self.override_custom_field_ids:
-            for field_id in self.override_custom_field_ids:
+        if self.metadata.custom_field_ids:
+            for field_id in self.metadata.custom_field_ids:
                 field = CustomField.objects.get(pk=field_id)
                 CustomFieldInstance.objects.create(
                     field=field,
index 0fc5cc162abacb9ec8b1a9882209b31ab443c376..87ee58868d21c469e99c54ee6d4dcbde10286931 100644 (file)
@@ -3,9 +3,6 @@ import uuid
 
 
 class LoggingMixin:
-    def __init__(self) -> None:
-        self.renew_logging_group()
-
     def renew_logging_group(self):
         """
         Creates a new UUID to group subsequent log calls together with
index d781ddb9f2c5a0757512623feab1ca0ea05e1075..ed70f653d3fdcf8d0743100a0f8081b107544bc7 100644 (file)
@@ -328,6 +328,7 @@ class DocumentParser(LoggingMixin):
 
     def __init__(self, logging_group, progress_callback=None):
         super().__init__()
+        self.renew_logging_group()
         self.logging_group = logging_group
         self.settings = self.get_settings()
         settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
index aec4887be994179a366abd63cb7652890dff5028..14d6ea69600cd21eab21d8c47f710d3c8ba91d20 100644 (file)
@@ -67,7 +67,8 @@ class ConsumeTaskPlugin(abc.ABC):
         self.status_mgr = status_mgr
         self.task_id: Final = task_id
 
-    @abc.abstractproperty
+    @property
+    @abc.abstractmethod
     def able_to_run(self) -> bool:
         """
         Return True if the conditions are met for the plugin to run, False otherwise
index 27d03f30f97c267a7efda0e70fbcf53172bef488..2d3686db48fecee11cc06f5c2a748b64a14ec4e5 100644 (file)
@@ -57,7 +57,7 @@ class ProgressManager:
         message: str,
         current_progress: int,
         max_progress: int,
-        extra_args: Optional[dict[str, Union[str, int]]] = None,
+        extra_args: Optional[dict[str, Union[str, int, None]]] = None,
     ) -> None:
         # Ensure the layer is open
         self.open()
index 0ab55ac4564f2bef70f55271e7591e7e30f459db..1bc812bfd1733bf25346b1e14ef2dcad34635141 100644 (file)
@@ -21,8 +21,7 @@ from documents.barcodes import BarcodePlugin
 from documents.caching import clear_document_caches
 from documents.classifier import DocumentClassifier
 from documents.classifier import load_classifier
-from documents.consumer import Consumer
-from documents.consumer import ConsumerError
+from documents.consumer import ConsumerPlugin
 from documents.consumer import WorkflowTriggerPlugin
 from documents.data_models import ConsumableDocument
 from documents.data_models import DocumentMetadataOverrides
@@ -115,6 +114,7 @@ def consume_file(
         CollatePlugin,
         BarcodePlugin,
         WorkflowTriggerPlugin,
+        ConsumerPlugin,
     ]
 
     with ProgressManager(
@@ -162,33 +162,7 @@ def consume_file(
             finally:
                 plugin.cleanup()
 
-    # continue with consumption if no barcode was found
-    document = Consumer().try_consume_file(
-        input_doc.original_file,
-        override_filename=overrides.filename,
-        override_title=overrides.title,
-        override_correspondent_id=overrides.correspondent_id,
-        override_document_type_id=overrides.document_type_id,
-        override_tag_ids=overrides.tag_ids,
-        override_storage_path_id=overrides.storage_path_id,
-        override_created=overrides.created,
-        override_asn=overrides.asn,
-        override_owner_id=overrides.owner_id,
-        override_view_users=overrides.view_users,
-        override_view_groups=overrides.view_groups,
-        override_change_users=overrides.change_users,
-        override_change_groups=overrides.change_groups,
-        override_custom_field_ids=overrides.custom_field_ids,
-        task_id=self.request.id,
-    )
-
-    if document:
-        return f"Success. New document id {document.pk} created"
-    else:
-        raise ConsumerError(
-            "Unknown error: Returned document was null, but "
-            "no error message was given.",
-        )
+    return msg
 
 
 @shared_task
index 2f4f5cd39f55d9d4f5c663127d8ad63a001114ff..b0c42963a45c433780c3aa3b2047a603a7a058f2 100644 (file)
@@ -14,6 +14,7 @@ from documents.barcodes import BarcodePlugin
 from documents.data_models import ConsumableDocument
 from documents.data_models import DocumentMetadataOverrides
 from documents.data_models import DocumentSource
+from documents.models import Document
 from documents.models import Tag
 from documents.plugins.base import StopConsumeTaskError
 from documents.tests.utils import DirectoriesMixin
@@ -674,9 +675,7 @@ class TestAsnBarcode(DirectoriesMixin, SampleDirMixin, GetReaderPluginMixin, Tes
         dst = settings.SCRATCH_DIR / "barcode-39-asn-123.pdf"
         shutil.copy(test_file, dst)
 
-        with mock.patch(
-            "documents.consumer.Consumer.try_consume_file",
-        ) as mocked_consumer:
+        with mock.patch("documents.tasks.ProgressManager", DummyProgressManager):
             tasks.consume_file(
                 ConsumableDocument(
                     source=DocumentSource.ConsumeFolder,
@@ -684,10 +683,10 @@ class TestAsnBarcode(DirectoriesMixin, SampleDirMixin, GetReaderPluginMixin, Tes
                 ),
                 None,
             )
-            mocked_consumer.assert_called_once()
-            args, kwargs = mocked_consumer.call_args
 
-            self.assertEqual(kwargs["override_asn"], 123)
+            document = Document.objects.first()
+
+            self.assertEqual(document.archive_serial_number, 123)
 
     @override_settings(CONSUMER_BARCODE_SCANNER="PYZBAR")
     def test_scan_file_for_qrcode_without_upscale(self):
index 6c23df8aab737467de10dd7b40a3dbdbf27a58ec..3874ebac6282f3fa56cf548378ab9e6e400a517e 100644 (file)
@@ -4,8 +4,9 @@ import re
 import shutil
 import stat
 import tempfile
-import uuid
 import zoneinfo
+from pathlib import Path
+from unittest import TestCase as UnittestTestCase
 from unittest import mock
 from unittest.mock import MagicMock
 
@@ -18,9 +19,8 @@ from django.test import override_settings
 from django.utils import timezone
 from guardian.core import ObjectPermissionChecker
 
-from documents.consumer import Consumer
 from documents.consumer import ConsumerError
-from documents.consumer import ConsumerFilePhase
+from documents.data_models import DocumentMetadataOverrides
 from documents.models import Correspondent
 from documents.models import CustomField
 from documents.models import Document
@@ -30,12 +30,14 @@ from documents.models import StoragePath
 from documents.models import Tag
 from documents.parsers import DocumentParser
 from documents.parsers import ParseError
+from documents.plugins.helpers import ProgressStatusOptions
 from documents.tasks import sanity_check
 from documents.tests.utils import DirectoriesMixin
 from documents.tests.utils import FileSystemAssertsMixin
+from documents.tests.utils import GetConsumerMixin
 
 
-class TestAttributes(TestCase):
+class TestAttributes(UnittestTestCase):
     TAGS = ("tag1", "tag2", "tag3")
 
     def _test_guess_attributes_from_name(self, filename, sender, title, tags):
@@ -246,29 +248,33 @@ def fake_magic_from_file(file, mime=False):
 
 
 @mock.patch("documents.consumer.magic.from_file", fake_magic_from_file)
-class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
+class TestConsumer(
+    DirectoriesMixin,
+    FileSystemAssertsMixin,
+    GetConsumerMixin,
+    TestCase,
+):
     def _assert_first_last_send_progress(
         self,
-        first_status=ConsumerFilePhase.STARTED,
-        last_status=ConsumerFilePhase.SUCCESS,
+        first_status=ProgressStatusOptions.STARTED,
+        last_status=ProgressStatusOptions.SUCCESS,
         first_progress=0,
         first_progress_max=100,
         last_progress=100,
         last_progress_max=100,
     ):
-        self._send_progress.assert_called()
+        self.assertGreaterEqual(len(self.status.payloads), 2)
 
-        args, kwargs = self._send_progress.call_args_list[0]
-        self.assertEqual(args[0], first_progress)
-        self.assertEqual(args[1], first_progress_max)
-        self.assertEqual(args[2], first_status)
+        payload = self.status.payloads[0]
+        self.assertEqual(payload["data"]["current_progress"], first_progress)
+        self.assertEqual(payload["data"]["max_progress"], first_progress_max)
+        self.assertEqual(payload["data"]["status"], first_status)
 
-        args, kwargs = self._send_progress.call_args_list[
-            len(self._send_progress.call_args_list) - 1
-        ]
-        self.assertEqual(args[0], last_progress)
-        self.assertEqual(args[1], last_progress_max)
-        self.assertEqual(args[2], last_status)
+        payload = self.status.payloads[-1]
+
+        self.assertEqual(payload["data"]["current_progress"], last_progress)
+        self.assertEqual(payload["data"]["max_progress"], last_progress_max)
+        self.assertEqual(payload["data"]["status"], last_status)
 
     def make_dummy_parser(self, logging_group, progress_callback=None):
         return DummyParser(
@@ -304,34 +310,23 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         ]
         self.addCleanup(patcher.stop)
 
-        # this prevents websocket message reports during testing.
-        patcher = mock.patch("documents.consumer.Consumer._send_progress")
-        self._send_progress = patcher.start()
-        self.addCleanup(patcher.stop)
-
-        self.consumer = Consumer()
-
     def get_test_file(self):
-        src = os.path.join(
-            os.path.dirname(__file__),
-            "samples",
-            "documents",
-            "originals",
-            "0000001.pdf",
+        src = (
+            Path(__file__).parent
+            / "samples"
+            / "documents"
+            / "originals"
+            / "0000001.pdf"
         )
-        dst = os.path.join(self.dirs.scratch_dir, "sample.pdf")
+        dst = self.dirs.scratch_dir / "sample.pdf"
         shutil.copy(src, dst)
         return dst
 
     def get_test_archive_file(self):
-        src = os.path.join(
-            os.path.dirname(__file__),
-            "samples",
-            "documents",
-            "archive",
-            "0000001.pdf",
+        src = (
+            Path(__file__).parent / "samples" / "documents" / "archive" / "0000001.pdf"
         )
-        dst = os.path.join(self.dirs.scratch_dir, "sample_archive.pdf")
+        dst = self.dirs.scratch_dir / "sample_archive.pdf"
         shutil.copy(src, dst)
         return dst
 
@@ -343,8 +338,12 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         # Roughly equal to file modification time
         rough_create_date_local = timezone.localtime(timezone.now())
 
-        # Consume the file
-        document = self.consumer.try_consume_file(filename)
+        with self.get_consumer(filename) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
+
+        self.assertIsNotNone(document)
 
         self.assertEqual(document.content, "The Text")
         self.assertEqual(
@@ -395,7 +394,12 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         self.assertIsFile(shadow_file)
 
-        document = self.consumer.try_consume_file(filename)
+        with self.get_consumer(filename) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
+
+        self.assertIsNotNone(document)
 
         self.assertIsFile(document.source_path)
 
@@ -406,29 +410,48 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         filename = self.get_test_file()
         override_filename = "Statement for November.pdf"
 
-        document = self.consumer.try_consume_file(
+        with self.get_consumer(
             filename,
-            override_filename=override_filename,
-        )
+            DocumentMetadataOverrides(filename=override_filename),
+        ) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
+
+        self.assertIsNotNone(document)
 
         self.assertEqual(document.title, "Statement for November")
 
         self._assert_first_last_send_progress()
 
     def testOverrideTitle(self):
-        document = self.consumer.try_consume_file(
+
+        with self.get_consumer(
             self.get_test_file(),
-            override_title="Override Title",
-        )
+            DocumentMetadataOverrides(title="Override Title"),
+        ) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
+
+        self.assertIsNotNone(document)
+
         self.assertEqual(document.title, "Override Title")
         self._assert_first_last_send_progress()
 
     def testOverrideTitleInvalidPlaceholders(self):
         with self.assertLogs("paperless.consumer", level="ERROR") as cm:
-            document = self.consumer.try_consume_file(
+
+            with self.get_consumer(
                 self.get_test_file(),
-                override_title="Override {correspondent]",
-            )
+                DocumentMetadataOverrides(title="Override {correspondent]"),
+            ) as consumer:
+                consumer.run()
+
+                document = Document.objects.first()
+
+            self.assertIsNotNone(document)
+
             self.assertEqual(document.title, "sample")
             expected_str = "Error occurred parsing title override 'Override {correspondent]', falling back to original"
             self.assertIn(expected_str, cm.output[0])
@@ -436,30 +459,44 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
     def testOverrideCorrespondent(self):
         c = Correspondent.objects.create(name="test")
 
-        document = self.consumer.try_consume_file(
+        with self.get_consumer(
             self.get_test_file(),
-            override_correspondent_id=c.pk,
-        )
+            DocumentMetadataOverrides(correspondent_id=c.pk),
+        ) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
+
+        self.assertIsNotNone(document)
+
         self.assertEqual(document.correspondent.id, c.id)
         self._assert_first_last_send_progress()
 
     def testOverrideDocumentType(self):
         dt = DocumentType.objects.create(name="test")
 
-        document = self.consumer.try_consume_file(
+        with self.get_consumer(
             self.get_test_file(),
-            override_document_type_id=dt.pk,
-        )
+            DocumentMetadataOverrides(document_type_id=dt.pk),
+        ) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
+
         self.assertEqual(document.document_type.id, dt.id)
         self._assert_first_last_send_progress()
 
     def testOverrideStoragePath(self):
         sp = StoragePath.objects.create(name="test")
 
-        document = self.consumer.try_consume_file(
+        with self.get_consumer(
             self.get_test_file(),
-            override_storage_path_id=sp.pk,
-        )
+            DocumentMetadataOverrides(storage_path_id=sp.pk),
+        ) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
+
         self.assertEqual(document.storage_path.id, sp.id)
         self._assert_first_last_send_progress()
 
@@ -467,10 +504,14 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         t1 = Tag.objects.create(name="t1")
         t2 = Tag.objects.create(name="t2")
         t3 = Tag.objects.create(name="t3")
-        document = self.consumer.try_consume_file(
+
+        with self.get_consumer(
             self.get_test_file(),
-            override_tag_ids=[t1.id, t3.id],
-        )
+            DocumentMetadataOverrides(tag_ids=[t1.id, t3.id]),
+        ) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
 
         self.assertIn(t1, document.tags.all())
         self.assertNotIn(t2, document.tags.all())
@@ -487,10 +528,14 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
             name="Custom Field 3",
             data_type="url",
         )
-        document = self.consumer.try_consume_file(
+
+        with self.get_consumer(
             self.get_test_file(),
-            override_custom_field_ids=[cf1.id, cf3.id],
-        )
+            DocumentMetadataOverrides(custom_field_ids=[cf1.id, cf3.id]),
+        ) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
 
         fields_used = [
             field_instance.field for field_instance in document.custom_fields.all()
@@ -501,10 +546,15 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         self._assert_first_last_send_progress()
 
     def testOverrideAsn(self):
-        document = self.consumer.try_consume_file(
+
+        with self.get_consumer(
             self.get_test_file(),
-            override_asn=123,
-        )
+            DocumentMetadataOverrides(asn=123),
+        ) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
+
         self.assertEqual(document.archive_serial_number, 123)
         self._assert_first_last_send_progress()
 
@@ -512,33 +562,51 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         c = Correspondent.objects.create(name="Correspondent Name")
         dt = DocumentType.objects.create(name="DocType Name")
 
-        document = self.consumer.try_consume_file(
+        with self.get_consumer(
             self.get_test_file(),
-            override_correspondent_id=c.pk,
-            override_document_type_id=dt.pk,
-            override_title="{correspondent}{document_type} {added_month}-{added_year_short}",
-        )
+            DocumentMetadataOverrides(
+                correspondent_id=c.pk,
+                document_type_id=dt.pk,
+                title="{correspondent}{document_type} {added_month}-{added_year_short}",
+            ),
+        ) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
+
         now = timezone.now()
         self.assertEqual(document.title, f"{c.name}{dt.name} {now.strftime('%m-%y')}")
         self._assert_first_last_send_progress()
 
     def testOverrideOwner(self):
         testuser = User.objects.create(username="testuser")
-        document = self.consumer.try_consume_file(
+
+        with self.get_consumer(
             self.get_test_file(),
-            override_owner_id=testuser.pk,
-        )
+            DocumentMetadataOverrides(owner_id=testuser.pk),
+        ) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
+
         self.assertEqual(document.owner, testuser)
         self._assert_first_last_send_progress()
 
     def testOverridePermissions(self):
         testuser = User.objects.create(username="testuser")
         testgroup = Group.objects.create(name="testgroup")
-        document = self.consumer.try_consume_file(
+
+        with self.get_consumer(
             self.get_test_file(),
-            override_view_users=[testuser.pk],
-            override_view_groups=[testgroup.pk],
-        )
+            DocumentMetadataOverrides(
+                view_users=[testuser.pk],
+                view_groups=[testgroup.pk],
+            ),
+        ) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
+
         user_checker = ObjectPermissionChecker(testuser)
         self.assertTrue(user_checker.has_perm("view_document", document))
         group_checker = ObjectPermissionChecker(testgroup)
@@ -546,53 +614,48 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         self._assert_first_last_send_progress()
 
     def testNotAFile(self):
-        self.assertRaisesMessage(
-            ConsumerError,
-            "File not found",
-            self.consumer.try_consume_file,
-            "non-existing-file",
-        )
 
+        with self.get_consumer(Path("non-existing-file")) as consumer:
+            with self.assertRaisesMessage(ConsumerError, "File not found"):
+                consumer.run()
         self._assert_first_last_send_progress(last_status="FAILED")
 
     def testDuplicates1(self):
-        self.consumer.try_consume_file(self.get_test_file())
+        with self.get_consumer(self.get_test_file()) as consumer:
+            consumer.run()
 
-        self.assertRaisesMessage(
-            ConsumerError,
-            "It is a duplicate",
-            self.consumer.try_consume_file,
-            self.get_test_file(),
-        )
+        with self.get_consumer(self.get_test_file()) as consumer:
+            with self.assertRaisesMessage(ConsumerError, "It is a duplicate"):
+                consumer.run()
 
         self._assert_first_last_send_progress(last_status="FAILED")
 
     def testDuplicates2(self):
-        self.consumer.try_consume_file(self.get_test_file())
+        with self.get_consumer(self.get_test_file()) as consumer:
+            consumer.run()
 
-        self.assertRaisesMessage(
-            ConsumerError,
-            "It is a duplicate",
-            self.consumer.try_consume_file,
-            self.get_test_archive_file(),
-        )
+        with self.get_consumer(self.get_test_archive_file()) as consumer:
+            with self.assertRaisesMessage(ConsumerError, "It is a duplicate"):
+                consumer.run()
 
         self._assert_first_last_send_progress(last_status="FAILED")
 
     def testDuplicates3(self):
-        self.consumer.try_consume_file(self.get_test_archive_file())
-        self.consumer.try_consume_file(self.get_test_file())
+        with self.get_consumer(self.get_test_archive_file()) as consumer:
+            consumer.run()
+        with self.get_consumer(self.get_test_file()) as consumer:
+            consumer.run()
 
     @mock.patch("documents.parsers.document_consumer_declaration.send")
     def testNoParsers(self, m):
         m.return_value = []
 
-        self.assertRaisesMessage(
-            ConsumerError,
-            "sample.pdf: Unsupported mime type application/pdf",
-            self.consumer.try_consume_file,
-            self.get_test_file(),
-        )
+        with self.get_consumer(self.get_test_file()) as consumer:
+            with self.assertRaisesMessage(
+                ConsumerError,
+                "sample.pdf: Unsupported mime type application/pdf",
+            ):
+                consumer.run()
 
         self._assert_first_last_send_progress(last_status="FAILED")
 
@@ -609,12 +672,12 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
             ),
         ]
 
-        self.assertRaisesMessage(
-            ConsumerError,
-            "sample.pdf: Error occurred while consuming document sample.pdf: Does not compute.",
-            self.consumer.try_consume_file,
-            self.get_test_file(),
-        )
+        with self.get_consumer(self.get_test_file()) as consumer:
+            with self.assertRaisesMessage(
+                ConsumerError,
+                "sample.pdf: Error occurred while consuming document sample.pdf: Does not compute.",
+            ):
+                consumer.run()
 
         self._assert_first_last_send_progress(last_status="FAILED")
 
@@ -631,26 +694,26 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
             ),
         ]
 
-        self.assertRaisesMessage(
-            ConsumerError,
-            "sample.pdf: Unexpected error while consuming document sample.pdf: Generic exception.",
-            self.consumer.try_consume_file,
-            self.get_test_file(),
-        )
+        with self.get_consumer(self.get_test_file()) as consumer:
+            with self.assertRaisesMessage(
+                ConsumerError,
+                "sample.pdf: Unexpected error while consuming document sample.pdf: Generic exception.",
+            ):
+                consumer.run()
 
         self._assert_first_last_send_progress(last_status="FAILED")
 
-    @mock.patch("documents.consumer.Consumer._write")
+    @mock.patch("documents.consumer.ConsumerPlugin._write")
     def testPostSaveError(self, m):
         filename = self.get_test_file()
         m.side_effect = OSError("NO.")
 
-        self.assertRaisesMessage(
-            ConsumerError,
-            "sample.pdf: The following error occurred while storing document sample.pdf after parsing: NO.",
-            self.consumer.try_consume_file,
-            filename,
-        )
+        with self.get_consumer(self.get_test_file()) as consumer:
+            with self.assertRaisesMessage(
+                ConsumerError,
+                "sample.pdf: The following error occurred while storing document sample.pdf after parsing: NO.",
+            ):
+                consumer.run()
 
         self._assert_first_last_send_progress(last_status="FAILED")
 
@@ -658,13 +721,18 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         self.assertIsFile(filename)
 
         # Database empty
-        self.assertEqual(len(Document.objects.all()), 0)
+        self.assertEqual(Document.objects.all().count(), 0)
 
     @override_settings(FILENAME_FORMAT="{correspondent}/{title}")
     def testFilenameHandling(self):
-        filename = self.get_test_file()
 
-        document = self.consumer.try_consume_file(filename, override_title="new docs")
+        with self.get_consumer(
+            self.get_test_file(),
+            DocumentMetadataOverrides(title="new docs"),
+        ) as consumer:
+            consumer.run()
+
+        document = Document.objects.first()
 
         self.assertEqual(document.title, "new docs")
         self.assertEqual(document.filename, "none/new docs.pdf")
@@ -684,11 +752,15 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         m.side_effect = lambda f, archive_filename=False: get_filename()
 
-        filename = self.get_test_file()
-
         Tag.objects.create(name="test", is_inbox_tag=True)
 
-        document = self.consumer.try_consume_file(filename, override_title="new docs")
+        with self.get_consumer(
+            self.get_test_file(),
+            DocumentMetadataOverrides(title="new docs"),
+        ) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
 
         self.assertEqual(document.title, "new docs")
         self.assertIsNotNone(document.title)
@@ -715,7 +787,10 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         m.return_value.predict_document_type.return_value = dtype.pk
         m.return_value.predict_tags.return_value = [t1.pk]
 
-        document = self.consumer.try_consume_file(self.get_test_file())
+        with self.get_consumer(self.get_test_file()) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
 
         self.assertEqual(document.correspondent, correspondent)
         self.assertEqual(document.document_type, dtype)
@@ -728,18 +803,24 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
     def test_delete_duplicate(self):
         dst = self.get_test_file()
         self.assertIsFile(dst)
-        doc = self.consumer.try_consume_file(dst)
+
+        with self.get_consumer(dst) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
 
         self._assert_first_last_send_progress()
 
         self.assertIsNotFile(dst)
-        self.assertIsNotNone(doc)
-
-        self._send_progress.reset_mock()
+        self.assertIsNotNone(document)
 
         dst = self.get_test_file()
         self.assertIsFile(dst)
-        self.assertRaises(ConsumerError, self.consumer.try_consume_file, dst)
+
+        with self.get_consumer(dst) as consumer:
+            with self.assertRaises(ConsumerError):
+                consumer.run()
+
         self.assertIsNotFile(dst)
         self._assert_first_last_send_progress(last_status="FAILED")
 
@@ -747,32 +828,44 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
     def test_no_delete_duplicate(self):
         dst = self.get_test_file()
         self.assertIsFile(dst)
-        doc = self.consumer.try_consume_file(dst)
+
+        with self.get_consumer(dst) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
+
+        self._assert_first_last_send_progress()
 
         self.assertIsNotFile(dst)
-        self.assertIsNotNone(doc)
+        self.assertIsNotNone(document)
 
         dst = self.get_test_file()
         self.assertIsFile(dst)
-        self.assertRaises(ConsumerError, self.consumer.try_consume_file, dst)
-        self.assertIsFile(dst)
 
+        with self.get_consumer(dst) as consumer:
+            with self.assertRaisesRegex(
+                ConsumerError,
+                r"sample\.pdf: Not consuming sample\.pdf: It is a duplicate of sample \(#\d+\)",
+            ):
+                consumer.run()
+
+        self.assertIsFile(dst)
         self._assert_first_last_send_progress(last_status="FAILED")
 
     @override_settings(FILENAME_FORMAT="{title}")
     @mock.patch("documents.parsers.document_consumer_declaration.send")
     def test_similar_filenames(self, m):
         shutil.copy(
-            os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
-            os.path.join(settings.CONSUMPTION_DIR, "simple.pdf"),
+            Path(__file__).parent / "samples" / "simple.pdf",
+            settings.CONSUMPTION_DIR / "simple.pdf",
         )
         shutil.copy(
-            os.path.join(os.path.dirname(__file__), "samples", "simple.png"),
-            os.path.join(settings.CONSUMPTION_DIR, "simple.png"),
+            Path(__file__).parent / "samples" / "simple.png",
+            settings.CONSUMPTION_DIR / "simple.png",
         )
         shutil.copy(
-            os.path.join(os.path.dirname(__file__), "samples", "simple-noalpha.png"),
-            os.path.join(settings.CONSUMPTION_DIR, "simple.png.pdf"),
+            Path(__file__).parent / "samples" / "simple-noalpha.png",
+            settings.CONSUMPTION_DIR / "simple.png.pdf",
         )
         m.return_value = [
             (
@@ -784,20 +877,28 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
                 },
             ),
         ]
-        doc1 = self.consumer.try_consume_file(
-            os.path.join(settings.CONSUMPTION_DIR, "simple.png"),
-        )
-        doc2 = self.consumer.try_consume_file(
-            os.path.join(settings.CONSUMPTION_DIR, "simple.pdf"),
-        )
-        doc3 = self.consumer.try_consume_file(
-            os.path.join(settings.CONSUMPTION_DIR, "simple.png.pdf"),
-        )
+
+        with self.get_consumer(settings.CONSUMPTION_DIR / "simple.png") as consumer:
+            consumer.run()
+
+            doc1 = Document.objects.filter(pk=1).first()
+
+        with self.get_consumer(settings.CONSUMPTION_DIR / "simple.pdf") as consumer:
+            consumer.run()
+
+            doc2 = Document.objects.filter(pk=2).first()
+
+        with self.get_consumer(settings.CONSUMPTION_DIR / "simple.png.pdf") as consumer:
+            consumer.run()
+
+            doc3 = Document.objects.filter(pk=3).first()
 
         self.assertEqual(doc1.filename, "simple.png")
         self.assertEqual(doc1.archive_filename, "simple.pdf")
+
         self.assertEqual(doc2.filename, "simple.pdf")
         self.assertEqual(doc2.archive_filename, "simple_01.pdf")
+
         self.assertEqual(doc3.filename, "simple.png.pdf")
         self.assertEqual(doc3.archive_filename, "simple.png.pdf")
 
@@ -805,17 +906,10 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
 
 @mock.patch("documents.consumer.magic.from_file", fake_magic_from_file)
-class TestConsumerCreatedDate(DirectoriesMixin, TestCase):
+class TestConsumerCreatedDate(DirectoriesMixin, GetConsumerMixin, TestCase):
     def setUp(self):
         super().setUp()
 
-        # this prevents websocket message reports during testing.
-        patcher = mock.patch("documents.consumer.Consumer._send_progress")
-        self._send_progress = patcher.start()
-        self.addCleanup(patcher.stop)
-
-        self.consumer = Consumer()
-
     def test_consume_date_from_content(self):
         """
         GIVEN:
@@ -824,17 +918,20 @@ class TestConsumerCreatedDate(DirectoriesMixin, TestCase):
         THEN:
             - Should parse the date from the file content
         """
-        src = os.path.join(
-            os.path.dirname(__file__),
-            "samples",
-            "documents",
-            "originals",
-            "0000005.pdf",
+        src = (
+            Path(__file__).parent
+            / "samples"
+            / "documents"
+            / "originals"
+            / "0000005.pdf"
         )
-        dst = os.path.join(self.dirs.scratch_dir, "sample.pdf")
+        dst = self.dirs.scratch_dir / "sample.pdf"
         shutil.copy(src, dst)
 
-        document = self.consumer.try_consume_file(dst)
+        with self.get_consumer(dst) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
 
         self.assertEqual(
             document.created,
@@ -851,17 +948,20 @@ class TestConsumerCreatedDate(DirectoriesMixin, TestCase):
         THEN:
             - Should parse the date from the filename
         """
-        src = os.path.join(
-            os.path.dirname(__file__),
-            "samples",
-            "documents",
-            "originals",
-            "0000005.pdf",
+        src = (
+            Path(__file__).parent
+            / "samples"
+            / "documents"
+            / "originals"
+            / "0000005.pdf"
         )
-        dst = os.path.join(self.dirs.scratch_dir, "Scan - 2022-02-01.pdf")
+        dst = self.dirs.scratch_dir / "Scan - 2022-02-01.pdf"
         shutil.copy(src, dst)
 
-        document = self.consumer.try_consume_file(dst)
+        with self.get_consumer(dst) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
 
         self.assertEqual(
             document.created,
@@ -878,17 +978,20 @@ class TestConsumerCreatedDate(DirectoriesMixin, TestCase):
         THEN:
             - Should parse the date from the content
         """
-        src = os.path.join(
-            os.path.dirname(__file__),
-            "samples",
-            "documents",
-            "originals",
-            "0000005.pdf",
+        src = (
+            Path(__file__).parent
+            / "samples"
+            / "documents"
+            / "originals"
+            / "0000005.pdf"
         )
-        dst = os.path.join(self.dirs.scratch_dir, "Scan - 2022-02-01.pdf")
+        dst = self.dirs.scratch_dir / "Scan - 2022-02-01.pdf"
         shutil.copy(src, dst)
 
-        document = self.consumer.try_consume_file(dst)
+        with self.get_consumer(dst) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
 
         self.assertEqual(
             document.created,
@@ -907,17 +1010,20 @@ class TestConsumerCreatedDate(DirectoriesMixin, TestCase):
         THEN:
             - Should parse the date from the filename
         """
-        src = os.path.join(
-            os.path.dirname(__file__),
-            "samples",
-            "documents",
-            "originals",
-            "0000006.pdf",
+        src = (
+            Path(__file__).parent
+            / "samples"
+            / "documents"
+            / "originals"
+            / "0000006.pdf"
         )
-        dst = os.path.join(self.dirs.scratch_dir, "0000006.pdf")
+        dst = self.dirs.scratch_dir / "0000006.pdf"
         shutil.copy(src, dst)
 
-        document = self.consumer.try_consume_file(dst)
+        with self.get_consumer(dst) as consumer:
+            consumer.run()
+
+            document = Document.objects.first()
 
         self.assertEqual(
             document.created,
@@ -925,58 +1031,57 @@ class TestConsumerCreatedDate(DirectoriesMixin, TestCase):
         )
 
 
-class PreConsumeTestCase(TestCase):
+class PreConsumeTestCase(DirectoriesMixin, GetConsumerMixin, TestCase):
     def setUp(self) -> None:
-        # this prevents websocket message reports during testing.
-        patcher = mock.patch("documents.consumer.Consumer._send_progress")
-        self._send_progress = patcher.start()
-        self.addCleanup(patcher.stop)
-
-        return super().setUp()
+        super().setUp()
+        src = (
+            Path(__file__).parent
+            / "samples"
+            / "documents"
+            / "originals"
+            / "0000005.pdf"
+        )
+        self.test_file = self.dirs.scratch_dir / "sample.pdf"
+        shutil.copy(src, self.test_file)
 
     @mock.patch("documents.consumer.run_subprocess")
     @override_settings(PRE_CONSUME_SCRIPT=None)
     def test_no_pre_consume_script(self, m):
-        c = Consumer()
-        c.working_copy = "path-to-file"
-        c.run_pre_consume_script()
-        m.assert_not_called()
+        with self.get_consumer(self.test_file) as c:
+            c.run()
+            m.assert_not_called()
 
     @mock.patch("documents.consumer.run_subprocess")
-    @mock.patch("documents.consumer.Consumer._send_progress")
     @override_settings(PRE_CONSUME_SCRIPT="does-not-exist")
-    def test_pre_consume_script_not_found(self, m, m2):
-        c = Consumer()
-        c.filename = "somefile.pdf"
-        c.working_copy = "path-to-file"
-        self.assertRaises(ConsumerError, c.run_pre_consume_script)
+    def test_pre_consume_script_not_found(self, m):
+        with self.get_consumer(self.test_file) as c:
+
+            self.assertRaises(ConsumerError, c.run)
+            m.assert_not_called()
 
     @mock.patch("documents.consumer.run_subprocess")
     def test_pre_consume_script(self, m):
         with tempfile.NamedTemporaryFile() as script:
             with override_settings(PRE_CONSUME_SCRIPT=script.name):
-                c = Consumer()
-                c.original_path = "path-to-file"
-                c.working_copy = "/tmp/somewhere/path-to-file"
-                c.task_id = str(uuid.uuid4())
-                c.run_pre_consume_script()
+                with self.get_consumer(self.test_file) as c:
+                    c.run()
 
-                m.assert_called_once()
+                    m.assert_called_once()
 
-                args, _ = m.call_args
+                    args, _ = m.call_args
 
-                command = args[0]
-                environment = args[1]
+                    command = args[0]
+                    environment = args[1]
 
-                self.assertEqual(command[0], script.name)
-                self.assertEqual(command[1], "path-to-file")
+                    self.assertEqual(command[0], script.name)
+                    self.assertEqual(command[1], str(self.test_file))
 
-                subset = {
-                    "DOCUMENT_SOURCE_PATH": c.original_path,
-                    "DOCUMENT_WORKING_PATH": c.working_copy,
-                    "TASK_ID": c.task_id,
-                }
-                self.assertDictEqual(environment, {**environment, **subset})
+                    subset = {
+                        "DOCUMENT_SOURCE_PATH": str(c.input_doc.original_file),
+                        "DOCUMENT_WORKING_PATH": str(c.working_copy),
+                        "TASK_ID": c.task_id,
+                    }
+                    self.assertDictEqual(environment, {**environment, **subset})
 
     def test_script_with_output(self):
         """
@@ -1000,10 +1105,8 @@ class PreConsumeTestCase(TestCase):
 
             with override_settings(PRE_CONSUME_SCRIPT=script.name):
                 with self.assertLogs("paperless.consumer", level="INFO") as cm:
-                    c = Consumer()
-                    c.working_copy = "path-to-file"
-
-                    c.run_pre_consume_script()
+                    with self.get_consumer(self.test_file) as c:
+                        c.run()
                     self.assertIn(
                         "INFO:paperless.consumer:This message goes to stdout",
                         cm.output,
@@ -1033,22 +1136,25 @@ class PreConsumeTestCase(TestCase):
             os.chmod(script.name, st.st_mode | stat.S_IEXEC)
 
             with override_settings(PRE_CONSUME_SCRIPT=script.name):
-                c = Consumer()
-                c.working_copy = "path-to-file"
-                self.assertRaises(
-                    ConsumerError,
-                    c.run_pre_consume_script,
-                )
+                with self.get_consumer(self.test_file) as c:
+                    self.assertRaises(
+                        ConsumerError,
+                        c.run,
+                    )
 
 
-class PostConsumeTestCase(TestCase):
+class PostConsumeTestCase(DirectoriesMixin, GetConsumerMixin, TestCase):
     def setUp(self) -> None:
-        # this prevents websocket message reports during testing.
-        patcher = mock.patch("documents.consumer.Consumer._send_progress")
-        self._send_progress = patcher.start()
-        self.addCleanup(patcher.stop)
-
-        return super().setUp()
+        super().setUp()
+        src = (
+            Path(__file__).parent
+            / "samples"
+            / "documents"
+            / "originals"
+            / "0000005.pdf"
+        )
+        self.test_file = self.dirs.scratch_dir / "sample.pdf"
+        shutil.copy(src, self.test_file)
 
     @mock.patch("documents.consumer.run_subprocess")
     @override_settings(POST_CONSUME_SCRIPT=None)
@@ -1059,21 +1165,20 @@ class PostConsumeTestCase(TestCase):
         doc.tags.add(tag1)
         doc.tags.add(tag2)
 
-        Consumer().run_post_consume_script(doc)
-
+        with self.get_consumer(self.test_file) as consumer:
+            consumer.run_post_consume_script(doc)
         m.assert_not_called()
 
     @override_settings(POST_CONSUME_SCRIPT="does-not-exist")
-    @mock.patch("documents.consumer.Consumer._send_progress")
-    def test_post_consume_script_not_found(self, m):
+    def test_post_consume_script_not_found(self):
         doc = Document.objects.create(title="Test", mime_type="application/pdf")
-        c = Consumer()
-        c.filename = "somefile.pdf"
-        self.assertRaises(
-            ConsumerError,
-            c.run_post_consume_script,
-            doc,
-        )
+
+        with self.get_consumer(self.test_file) as consumer:
+            with self.assertRaisesMessage(
+                ConsumerError,
+                "sample.pdf: Configured post-consume script does-not-exist does not exist",
+            ):
+                consumer.run_post_consume_script(doc)
 
     @mock.patch("documents.consumer.run_subprocess")
     def test_post_consume_script_simple(self, m):
@@ -1081,7 +1186,8 @@ class PostConsumeTestCase(TestCase):
             with override_settings(POST_CONSUME_SCRIPT=script.name):
                 doc = Document.objects.create(title="Test", mime_type="application/pdf")
 
-                Consumer().run_post_consume_script(doc)
+                with self.get_consumer(self.test_file) as consumer:
+                    consumer.run_post_consume_script(doc)
 
                 m.assert_called_once()
 
@@ -1100,9 +1206,8 @@ class PostConsumeTestCase(TestCase):
                 doc.tags.add(tag1)
                 doc.tags.add(tag2)
 
-                consumer = Consumer()
-                consumer.task_id = str(uuid.uuid4())
-                consumer.run_post_consume_script(doc)
+                with self.get_consumer(self.test_file) as consumer:
+                    consumer.run_post_consume_script(doc)
 
                 m.assert_called_once()
 
@@ -1149,8 +1254,11 @@ class PostConsumeTestCase(TestCase):
             os.chmod(script.name, st.st_mode | stat.S_IEXEC)
 
             with override_settings(POST_CONSUME_SCRIPT=script.name):
-                c = Consumer()
+
                 doc = Document.objects.create(title="Test", mime_type="application/pdf")
-                c.path = "path-to-file"
-                with self.assertRaises(ConsumerError):
-                    c.run_post_consume_script(doc)
+                with self.get_consumer(self.test_file) as consumer:
+                    with self.assertRaisesRegex(
+                        ConsumerError,
+                        r"sample\.pdf: Error while executing post-consume script: Command '\[.*\]' returned non-zero exit status \d+\.",
+                    ):
+                        consumer.run_post_consume_script(doc)
index c665944910d04dbd736e163c555c8243f0614af3..64cd7be4847b7ac07d3947e06b60b6fcdb3a2d41 100644 (file)
@@ -46,7 +46,7 @@ class TestDoubleSided(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         with mock.patch(
             "documents.tasks.ProgressManager",
             DummyProgressManager,
-        ), mock.patch("documents.consumer.async_to_sync"):
+        ):
             msg = tasks.consume_file(
                 ConsumableDocument(
                     source=DocumentSource.ConsumeFolder,
index 509a8e54dca7cd0f26aaf5988667afe05c5e14de..1dfb9e47a716f04ea83bf6df9e333e9b7e25dd2f 100644 (file)
@@ -1,3 +1,4 @@
+import shutil
 from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -88,8 +89,7 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
 
         return super().setUp()
 
-    @mock.patch("documents.consumer.Consumer.try_consume_file")
-    def test_workflow_match(self, m):
+    def test_workflow_match(self):
         """
         GIVEN:
             - Existing workflow
@@ -102,7 +102,7 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
             type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
             sources=f"{DocumentSource.ApiUpload},{DocumentSource.ConsumeFolder},{DocumentSource.MailFetch}",
             filter_filename="*simple*",
-            filter_path="*/samples/*",
+            filter_path=f"*/{self.dirs.scratch_dir.parts[-1]}/*",
         )
         action = WorkflowAction.objects.create(
             assign_title="Doc from {correspondent}",
@@ -133,7 +133,10 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         self.assertEqual(trigger.__str__(), "WorkflowTrigger 1")
         self.assertEqual(action.__str__(), "WorkflowAction 1")
 
-        test_file = self.SAMPLE_DIR / "simple.pdf"
+        test_file = shutil.copy(
+            self.SAMPLE_DIR / "simple.pdf",
+            self.dirs.scratch_dir / "simple.pdf",
+        )
 
         with mock.patch("documents.tasks.ProgressManager", DummyProgressManager):
             with self.assertLogs("paperless.matching", level="INFO") as cm:
@@ -144,26 +147,53 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
                     ),
                     None,
                 )
-                m.assert_called_once()
-                _, overrides = m.call_args
-                self.assertEqual(overrides["override_correspondent_id"], self.c.pk)
-                self.assertEqual(overrides["override_document_type_id"], self.dt.pk)
+
+                document = Document.objects.first()
+                self.assertEqual(document.correspondent, self.c)
+                self.assertEqual(document.document_type, self.dt)
+                self.assertEqual(list(document.tags.all()), [self.t1, self.t2, self.t3])
+                self.assertEqual(document.storage_path, self.sp)
+                self.assertEqual(document.owner, self.user2)
                 self.assertEqual(
-                    overrides["override_tag_ids"],
-                    [self.t1.pk, self.t2.pk, self.t3.pk],
+                    list(
+                        get_users_with_perms(
+                            document,
+                            only_with_perms_in=["view_document"],
+                        ),
+                    ),
+                    [self.user3],
                 )
-                self.assertEqual(overrides["override_storage_path_id"], self.sp.pk)
-                self.assertEqual(overrides["override_owner_id"], self.user2.pk)
-                self.assertEqual(overrides["override_view_users"], [self.user3.pk])
-                self.assertEqual(overrides["override_view_groups"], [self.group1.pk])
-                self.assertEqual(overrides["override_change_users"], [self.user3.pk])
-                self.assertEqual(overrides["override_change_groups"], [self.group1.pk])
                 self.assertEqual(
-                    overrides["override_title"],
-                    "Doc from {correspondent}",
+                    list(
+                        get_groups_with_perms(
+                            document,
+                        ),
+                    ),
+                    [self.group1],
                 )
                 self.assertEqual(
-                    overrides["override_custom_field_ids"],
+                    list(
+                        get_users_with_perms(
+                            document,
+                            only_with_perms_in=["change_document"],
+                        ),
+                    ),
+                    [self.user3],
+                )
+                self.assertEqual(
+                    list(
+                        get_groups_with_perms(
+                            document,
+                        ),
+                    ),
+                    [self.group1],
+                )
+                self.assertEqual(
+                    document.title,
+                    f"Doc from {self.c.name}",
+                )
+                self.assertEqual(
+                    list(document.custom_fields.all().values_list("field", flat=True)),
                     [self.cf1.pk, self.cf2.pk],
                 )
 
@@ -171,8 +201,7 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         expected_str = f"Document matched {trigger} from {w}"
         self.assertIn(expected_str, info)
 
-    @mock.patch("documents.consumer.Consumer.try_consume_file")
-    def test_workflow_match_mailrule(self, m):
+    def test_workflow_match_mailrule(self):
         """
         GIVEN:
             - Existing workflow
@@ -211,7 +240,11 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         w.actions.add(action)
         w.save()
 
-        test_file = self.SAMPLE_DIR / "simple.pdf"
+        test_file = shutil.copy(
+            self.SAMPLE_DIR / "simple.pdf",
+            self.dirs.scratch_dir / "simple.pdf",
+        )
+
         with mock.patch("documents.tasks.ProgressManager", DummyProgressManager):
             with self.assertLogs("paperless.matching", level="INFO") as cm:
                 tasks.consume_file(
@@ -222,31 +255,55 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
                     ),
                     None,
                 )
-                m.assert_called_once()
-                _, overrides = m.call_args
-                self.assertEqual(overrides["override_correspondent_id"], self.c.pk)
-                self.assertEqual(overrides["override_document_type_id"], self.dt.pk)
+                document = Document.objects.first()
+                self.assertEqual(document.correspondent, self.c)
+                self.assertEqual(document.document_type, self.dt)
+                self.assertEqual(list(document.tags.all()), [self.t1, self.t2, self.t3])
+                self.assertEqual(document.storage_path, self.sp)
+                self.assertEqual(document.owner, self.user2)
+                self.assertEqual(
+                    list(
+                        get_users_with_perms(
+                            document,
+                            only_with_perms_in=["view_document"],
+                        ),
+                    ),
+                    [self.user3],
+                )
                 self.assertEqual(
-                    overrides["override_tag_ids"],
-                    [self.t1.pk, self.t2.pk, self.t3.pk],
+                    list(
+                        get_groups_with_perms(
+                            document,
+                        ),
+                    ),
+                    [self.group1],
                 )
-                self.assertEqual(overrides["override_storage_path_id"], self.sp.pk)
-                self.assertEqual(overrides["override_owner_id"], self.user2.pk)
-                self.assertEqual(overrides["override_view_users"], [self.user3.pk])
-                self.assertEqual(overrides["override_view_groups"], [self.group1.pk])
-                self.assertEqual(overrides["override_change_users"], [self.user3.pk])
-                self.assertEqual(overrides["override_change_groups"], [self.group1.pk])
                 self.assertEqual(
-                    overrides["override_title"],
-                    "Doc from {correspondent}",
+                    list(
+                        get_users_with_perms(
+                            document,
+                            only_with_perms_in=["change_document"],
+                        ),
+                    ),
+                    [self.user3],
+                )
+                self.assertEqual(
+                    list(
+                        get_groups_with_perms(
+                            document,
+                        ),
+                    ),
+                    [self.group1],
+                )
+                self.assertEqual(
+                    document.title,
+                    f"Doc from {self.c.name}",
                 )
-
         info = cm.output[0]
         expected_str = f"Document matched {trigger} from {w}"
         self.assertIn(expected_str, info)
 
-    @mock.patch("documents.consumer.Consumer.try_consume_file")
-    def test_workflow_match_multiple(self, m):
+    def test_workflow_match_multiple(self):
         """
         GIVEN:
             - Multiple existing workflow
@@ -259,7 +316,7 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         trigger1 = WorkflowTrigger.objects.create(
             type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
             sources=f"{DocumentSource.ApiUpload},{DocumentSource.ConsumeFolder},{DocumentSource.MailFetch}",
-            filter_path="*/samples/*",
+            filter_path=f"*/{self.dirs.scratch_dir.parts[-1]}/*",
         )
         action1 = WorkflowAction.objects.create(
             assign_title="Doc from {correspondent}",
@@ -301,7 +358,10 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         w2.actions.add(action2)
         w2.save()
 
-        test_file = self.SAMPLE_DIR / "simple.pdf"
+        test_file = shutil.copy(
+            self.SAMPLE_DIR / "simple.pdf",
+            self.dirs.scratch_dir / "simple.pdf",
+        )
 
         with mock.patch("documents.tasks.ProgressManager", DummyProgressManager):
             with self.assertLogs("paperless.matching", level="INFO") as cm:
@@ -312,21 +372,25 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
                     ),
                     None,
                 )
-                m.assert_called_once()
-                _, overrides = m.call_args
+                document = Document.objects.first()
                 # template 1
-                self.assertEqual(overrides["override_document_type_id"], self.dt.pk)
+                self.assertEqual(document.document_type, self.dt)
                 # template 2
-                self.assertEqual(overrides["override_correspondent_id"], self.c2.pk)
-                self.assertEqual(overrides["override_storage_path_id"], self.sp.pk)
+                self.assertEqual(document.correspondent, self.c2)
+                self.assertEqual(document.storage_path, self.sp)
                 # template 1 & 2
                 self.assertEqual(
-                    overrides["override_tag_ids"],
-                    [self.t1.pk, self.t2.pk, self.t3.pk],
+                    list(document.tags.all()),
+                    [self.t1, self.t2, self.t3],
                 )
                 self.assertEqual(
-                    overrides["override_view_users"],
-                    [self.user2.pk, self.user3.pk],
+                    list(
+                        get_users_with_perms(
+                            document,
+                            only_with_perms_in=["view_document"],
+                        ),
+                    ),
+                    [self.user2, self.user3],
                 )
 
         expected_str = f"Document matched {trigger1} from {w1}"
@@ -334,8 +398,7 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         expected_str = f"Document matched {trigger2} from {w2}"
         self.assertIn(expected_str, cm.output[1])
 
-    @mock.patch("documents.consumer.Consumer.try_consume_file")
-    def test_workflow_fnmatch_path(self, m):
+    def test_workflow_fnmatch_path(self):
         """
         GIVEN:
             - Existing workflow
@@ -348,7 +411,7 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         trigger = WorkflowTrigger.objects.create(
             type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
             sources=f"{DocumentSource.ApiUpload},{DocumentSource.ConsumeFolder},{DocumentSource.MailFetch}",
-            filter_path="*sample*",
+            filter_path=f"*{self.dirs.scratch_dir.parts[-1]}*",
         )
         action = WorkflowAction.objects.create(
             assign_title="Doc fnmatch title",
@@ -363,7 +426,10 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         w.actions.add(action)
         w.save()
 
-        test_file = self.SAMPLE_DIR / "simple.pdf"
+        test_file = shutil.copy(
+            self.SAMPLE_DIR / "simple.pdf",
+            self.dirs.scratch_dir / "simple.pdf",
+        )
 
         with mock.patch("documents.tasks.ProgressManager", DummyProgressManager):
             with self.assertLogs("paperless.matching", level="DEBUG") as cm:
@@ -374,15 +440,13 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
                     ),
                     None,
                 )
-                m.assert_called_once()
-                _, overrides = m.call_args
-                self.assertEqual(overrides["override_title"], "Doc fnmatch title")
+                document = Document.objects.first()
+                self.assertEqual(document.title, "Doc fnmatch title")
 
         expected_str = f"Document matched {trigger} from {w}"
         self.assertIn(expected_str, cm.output[0])
 
-    @mock.patch("documents.consumer.Consumer.try_consume_file")
-    def test_workflow_no_match_filename(self, m):
+    def test_workflow_no_match_filename(self):
         """
         GIVEN:
             - Existing workflow
@@ -414,7 +478,10 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         w.actions.add(action)
         w.save()
 
-        test_file = self.SAMPLE_DIR / "simple.pdf"
+        test_file = shutil.copy(
+            self.SAMPLE_DIR / "simple.pdf",
+            self.dirs.scratch_dir / "simple.pdf",
+        )
 
         with mock.patch("documents.tasks.ProgressManager", DummyProgressManager):
             with self.assertLogs("paperless.matching", level="DEBUG") as cm:
@@ -425,26 +492,36 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
                     ),
                     None,
                 )
-                m.assert_called_once()
-                _, overrides = m.call_args
-                self.assertIsNone(overrides["override_correspondent_id"])
-                self.assertIsNone(overrides["override_document_type_id"])
-                self.assertIsNone(overrides["override_tag_ids"])
-                self.assertIsNone(overrides["override_storage_path_id"])
-                self.assertIsNone(overrides["override_owner_id"])
-                self.assertIsNone(overrides["override_view_users"])
-                self.assertIsNone(overrides["override_view_groups"])
-                self.assertIsNone(overrides["override_change_users"])
-                self.assertIsNone(overrides["override_change_groups"])
-                self.assertIsNone(overrides["override_title"])
+                document = Document.objects.first()
+                self.assertIsNone(document.correspondent)
+                self.assertIsNone(document.document_type)
+                self.assertEqual(document.tags.all().count(), 0)
+                self.assertIsNone(document.storage_path)
+                self.assertIsNone(document.owner)
+                self.assertEqual(
+                    get_users_with_perms(
+                        document,
+                        only_with_perms_in=["view_document"],
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(get_groups_with_perms(document).count(), 0)
+                self.assertEqual(
+                    get_users_with_perms(
+                        document,
+                        only_with_perms_in=["change_document"],
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(get_groups_with_perms(document).count(), 0)
+                self.assertEqual(document.title, "simple")
 
         expected_str = f"Document did not match {w}"
         self.assertIn(expected_str, cm.output[0])
         expected_str = f"Document filename {test_file.name} does not match"
         self.assertIn(expected_str, cm.output[1])
 
-    @mock.patch("documents.consumer.Consumer.try_consume_file")
-    def test_workflow_no_match_path(self, m):
+    def test_workflow_no_match_path(self):
         """
         GIVEN:
             - Existing workflow
@@ -475,7 +552,10 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         w.actions.add(action)
         w.save()
 
-        test_file = self.SAMPLE_DIR / "simple.pdf"
+        test_file = shutil.copy(
+            self.SAMPLE_DIR / "simple.pdf",
+            self.dirs.scratch_dir / "simple.pdf",
+        )
 
         with mock.patch("documents.tasks.ProgressManager", DummyProgressManager):
             with self.assertLogs("paperless.matching", level="DEBUG") as cm:
@@ -486,26 +566,46 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
                     ),
                     None,
                 )
-                m.assert_called_once()
-                _, overrides = m.call_args
-                self.assertIsNone(overrides["override_correspondent_id"])
-                self.assertIsNone(overrides["override_document_type_id"])
-                self.assertIsNone(overrides["override_tag_ids"])
-                self.assertIsNone(overrides["override_storage_path_id"])
-                self.assertIsNone(overrides["override_owner_id"])
-                self.assertIsNone(overrides["override_view_users"])
-                self.assertIsNone(overrides["override_view_groups"])
-                self.assertIsNone(overrides["override_change_users"])
-                self.assertIsNone(overrides["override_change_groups"])
-                self.assertIsNone(overrides["override_title"])
+                document = Document.objects.first()
+                self.assertIsNone(document.correspondent)
+                self.assertIsNone(document.document_type)
+                self.assertEqual(document.tags.all().count(), 0)
+                self.assertIsNone(document.storage_path)
+                self.assertIsNone(document.owner)
+                self.assertEqual(
+                    get_users_with_perms(
+                        document,
+                        only_with_perms_in=["view_document"],
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(
+                    get_groups_with_perms(
+                        document,
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(
+                    get_users_with_perms(
+                        document,
+                        only_with_perms_in=["change_document"],
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(
+                    get_groups_with_perms(
+                        document,
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(document.title, "simple")
 
         expected_str = f"Document did not match {w}"
         self.assertIn(expected_str, cm.output[0])
         expected_str = f"Document path {test_file} does not match"
         self.assertIn(expected_str, cm.output[1])
 
-    @mock.patch("documents.consumer.Consumer.try_consume_file")
-    def test_workflow_no_match_mail_rule(self, m):
+    def test_workflow_no_match_mail_rule(self):
         """
         GIVEN:
             - Existing workflow
@@ -536,7 +636,10 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         w.actions.add(action)
         w.save()
 
-        test_file = self.SAMPLE_DIR / "simple.pdf"
+        test_file = shutil.copy(
+            self.SAMPLE_DIR / "simple.pdf",
+            self.dirs.scratch_dir / "simple.pdf",
+        )
 
         with mock.patch("documents.tasks.ProgressManager", DummyProgressManager):
             with self.assertLogs("paperless.matching", level="DEBUG") as cm:
@@ -548,26 +651,46 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
                     ),
                     None,
                 )
-                m.assert_called_once()
-                _, overrides = m.call_args
-                self.assertIsNone(overrides["override_correspondent_id"])
-                self.assertIsNone(overrides["override_document_type_id"])
-                self.assertIsNone(overrides["override_tag_ids"])
-                self.assertIsNone(overrides["override_storage_path_id"])
-                self.assertIsNone(overrides["override_owner_id"])
-                self.assertIsNone(overrides["override_view_users"])
-                self.assertIsNone(overrides["override_view_groups"])
-                self.assertIsNone(overrides["override_change_users"])
-                self.assertIsNone(overrides["override_change_groups"])
-                self.assertIsNone(overrides["override_title"])
+                document = Document.objects.first()
+                self.assertIsNone(document.correspondent)
+                self.assertIsNone(document.document_type)
+                self.assertEqual(document.tags.all().count(), 0)
+                self.assertIsNone(document.storage_path)
+                self.assertIsNone(document.owner)
+                self.assertEqual(
+                    get_users_with_perms(
+                        document,
+                        only_with_perms_in=["view_document"],
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(
+                    get_groups_with_perms(
+                        document,
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(
+                    get_users_with_perms(
+                        document,
+                        only_with_perms_in=["change_document"],
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(
+                    get_groups_with_perms(
+                        document,
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(document.title, "simple")
 
         expected_str = f"Document did not match {w}"
         self.assertIn(expected_str, cm.output[0])
         expected_str = "Document mail rule 99 !="
         self.assertIn(expected_str, cm.output[1])
 
-    @mock.patch("documents.consumer.Consumer.try_consume_file")
-    def test_workflow_no_match_source(self, m):
+    def test_workflow_no_match_source(self):
         """
         GIVEN:
             - Existing workflow
@@ -598,7 +721,10 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         w.actions.add(action)
         w.save()
 
-        test_file = self.SAMPLE_DIR / "simple.pdf"
+        test_file = shutil.copy(
+            self.SAMPLE_DIR / "simple.pdf",
+            self.dirs.scratch_dir / "simple.pdf",
+        )
 
         with mock.patch("documents.tasks.ProgressManager", DummyProgressManager):
             with self.assertLogs("paperless.matching", level="DEBUG") as cm:
@@ -609,18 +735,39 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
                     ),
                     None,
                 )
-                m.assert_called_once()
-                _, overrides = m.call_args
-                self.assertIsNone(overrides["override_correspondent_id"])
-                self.assertIsNone(overrides["override_document_type_id"])
-                self.assertIsNone(overrides["override_tag_ids"])
-                self.assertIsNone(overrides["override_storage_path_id"])
-                self.assertIsNone(overrides["override_owner_id"])
-                self.assertIsNone(overrides["override_view_users"])
-                self.assertIsNone(overrides["override_view_groups"])
-                self.assertIsNone(overrides["override_change_users"])
-                self.assertIsNone(overrides["override_change_groups"])
-                self.assertIsNone(overrides["override_title"])
+                document = Document.objects.first()
+                self.assertIsNone(document.correspondent)
+                self.assertIsNone(document.document_type)
+                self.assertEqual(document.tags.all().count(), 0)
+                self.assertIsNone(document.storage_path)
+                self.assertIsNone(document.owner)
+                self.assertEqual(
+                    get_users_with_perms(
+                        document,
+                        only_with_perms_in=["view_document"],
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(
+                    get_groups_with_perms(
+                        document,
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(
+                    get_users_with_perms(
+                        document,
+                        only_with_perms_in=["change_document"],
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(
+                    get_groups_with_perms(
+                        document,
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(document.title, "simple")
 
         expected_str = f"Document did not match {w}"
         self.assertIn(expected_str, cm.output[0])
@@ -662,8 +809,7 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
             expected_str = f"No matching triggers with type {WorkflowTrigger.WorkflowTriggerType.DOCUMENT_ADDED} found"
             self.assertIn(expected_str, cm.output[1])
 
-    @mock.patch("documents.consumer.Consumer.try_consume_file")
-    def test_workflow_repeat_custom_fields(self, m):
+    def test_workflow_repeat_custom_fields(self):
         """
         GIVEN:
             - Existing workflows which assign the same custom field
@@ -693,7 +839,10 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         w.actions.add(action1, action2)
         w.save()
 
-        test_file = self.SAMPLE_DIR / "simple.pdf"
+        test_file = shutil.copy(
+            self.SAMPLE_DIR / "simple.pdf",
+            self.dirs.scratch_dir / "simple.pdf",
+        )
 
         with mock.patch("documents.tasks.ProgressManager", DummyProgressManager):
             with self.assertLogs("paperless.matching", level="INFO") as cm:
@@ -704,10 +853,9 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
                     ),
                     None,
                 )
-                m.assert_called_once()
-                _, overrides = m.call_args
+                document = Document.objects.first()
                 self.assertEqual(
-                    overrides["override_custom_field_ids"],
+                    list(document.custom_fields.all().values_list("field", flat=True)),
                     [self.cf1.pk],
                 )
 
@@ -1369,8 +1517,7 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         group_perms: QuerySet = get_groups_with_perms(doc)
         self.assertNotIn(self.group1, group_perms)
 
-    @mock.patch("documents.consumer.Consumer.try_consume_file")
-    def test_removal_action_document_consumed(self, m):
+    def test_removal_action_document_consumed(self):
         """
         GIVEN:
             - Workflow with assignment and removal actions
@@ -1429,7 +1576,10 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         w.actions.add(action2)
         w.save()
 
-        test_file = self.SAMPLE_DIR / "simple.pdf"
+        test_file = shutil.copy(
+            self.SAMPLE_DIR / "simple.pdf",
+            self.dirs.scratch_dir / "simple.pdf",
+        )
 
         with mock.patch("documents.tasks.ProgressManager", DummyProgressManager):
             with self.assertLogs("paperless.matching", level="INFO") as cm:
@@ -1440,26 +1590,57 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
                     ),
                     None,
                 )
-                m.assert_called_once()
-                _, overrides = m.call_args
-                self.assertIsNone(overrides["override_correspondent_id"])
-                self.assertIsNone(overrides["override_document_type_id"])
+
+                document = Document.objects.first()
+
+                self.assertIsNone(document.correspondent)
+                self.assertIsNone(document.document_type)
+                self.assertEqual(
+                    list(document.tags.all()),
+                    [self.t2, self.t3],
+                )
+                self.assertIsNone(document.storage_path)
+                self.assertIsNone(document.owner)
                 self.assertEqual(
-                    overrides["override_tag_ids"],
-                    [self.t2.pk, self.t3.pk],
+                    list(
+                        get_users_with_perms(
+                            document,
+                            only_with_perms_in=["view_document"],
+                        ),
+                    ),
+                    [self.user2],
                 )
-                self.assertIsNone(overrides["override_storage_path_id"])
-                self.assertIsNone(overrides["override_owner_id"])
-                self.assertEqual(overrides["override_view_users"], [self.user2.pk])
-                self.assertEqual(overrides["override_view_groups"], [self.group2.pk])
-                self.assertEqual(overrides["override_change_users"], [self.user2.pk])
-                self.assertEqual(overrides["override_change_groups"], [self.group2.pk])
                 self.assertEqual(
-                    overrides["override_title"],
-                    "Doc from {correspondent}",
+                    list(
+                        get_groups_with_perms(
+                            document,
+                        ),
+                    ),
+                    [self.group2],
+                )
+                self.assertEqual(
+                    list(
+                        get_users_with_perms(
+                            document,
+                            only_with_perms_in=["change_document"],
+                        ),
+                    ),
+                    [self.user2],
+                )
+                self.assertEqual(
+                    list(
+                        get_groups_with_perms(
+                            document,
+                        ),
+                    ),
+                    [self.group2],
                 )
                 self.assertEqual(
-                    overrides["override_custom_field_ids"],
+                    document.title,
+                    "Doc from None",
+                )
+                self.assertEqual(
+                    list(document.custom_fields.all().values_list("field", flat=True)),
                     [self.cf2.pk],
                 )
 
@@ -1467,8 +1648,7 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         expected_str = f"Document matched {trigger} from {w}"
         self.assertIn(expected_str, info)
 
-    @mock.patch("documents.consumer.Consumer.try_consume_file")
-    def test_removal_action_document_consumed_removeall(self, m):
+    def test_removal_action_document_consumed_remove_all(self):
         """
         GIVEN:
             - Workflow with assignment and removal actions with remove all fields set
@@ -1519,7 +1699,10 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
         w.actions.add(action2)
         w.save()
 
-        test_file = self.SAMPLE_DIR / "simple.pdf"
+        test_file = shutil.copy(
+            self.SAMPLE_DIR / "simple.pdf",
+            self.dirs.scratch_dir / "simple.pdf",
+        )
 
         with mock.patch("documents.tasks.ProgressManager", DummyProgressManager):
             with self.assertLogs("paperless.matching", level="INFO") as cm:
@@ -1530,23 +1713,46 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
                     ),
                     None,
                 )
-                m.assert_called_once()
-                _, overrides = m.call_args
-                self.assertIsNone(overrides["override_correspondent_id"])
-                self.assertIsNone(overrides["override_document_type_id"])
+                document = Document.objects.first()
+                self.assertIsNone(document.correspondent)
+                self.assertIsNone(document.document_type)
+                self.assertEqual(document.tags.all().count(), 0)
+
+                self.assertIsNone(document.storage_path)
+                self.assertIsNone(document.owner)
+                self.assertEqual(
+                    get_users_with_perms(
+                        document,
+                        only_with_perms_in=["view_document"],
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(
+                    get_groups_with_perms(
+                        document,
+                    ).count(),
+                    0,
+                )
+                self.assertEqual(
+                    get_users_with_perms(
+                        document,
+                        only_with_perms_in=["change_document"],
+                    ).count(),
+                    0,
+                )
                 self.assertEqual(
-                    overrides["override_tag_ids"],
-                    [],
+                    get_groups_with_perms(
+                        document,
+                    ).count(),
+                    0,
                 )
-                self.assertIsNone(overrides["override_storage_path_id"])
-                self.assertIsNone(overrides["override_owner_id"])
-                self.assertEqual(overrides["override_view_users"], [])
-                self.assertEqual(overrides["override_view_groups"], [])
-                self.assertEqual(overrides["override_change_users"], [])
-                self.assertEqual(overrides["override_change_groups"], [])
                 self.assertEqual(
-                    overrides["override_custom_field_ids"],
-                    [],
+                    document.custom_fields.all()
+                    .values_list(
+                        "field",
+                    )
+                    .count(),
+                    0,
                 )
 
         info = cm.output[0]
index abdd270179319fedc4182d6eb188b3b434451790..fb4fa9f07f07e35d4eb0f14629ae73e49f872830 100644 (file)
@@ -3,6 +3,7 @@ import tempfile
 import time
 import warnings
 from collections import namedtuple
+from collections.abc import Generator
 from collections.abc import Iterator
 from contextlib import contextmanager
 from os import PathLike
@@ -21,8 +22,10 @@ from django.db.migrations.executor import MigrationExecutor
 from django.test import TransactionTestCase
 from django.test import override_settings
 
+from documents.consumer import ConsumerPlugin
 from documents.data_models import ConsumableDocument
 from documents.data_models import DocumentMetadataOverrides
+from documents.data_models import DocumentSource
 from documents.parsers import ParseError
 from documents.plugins.helpers import ProgressStatusOptions
 
@@ -326,6 +329,30 @@ class SampleDirMixin:
     BARCODE_SAMPLE_DIR = SAMPLE_DIR / "barcodes"
 
 
+class GetConsumerMixin:
+    @contextmanager
+    def get_consumer(
+        self,
+        filepath: Path,
+        overrides: Union[DocumentMetadataOverrides, None] = None,
+        source: DocumentSource = DocumentSource.ConsumeFolder,
+    ) -> Generator[ConsumerPlugin, None, None]:
+        # Store this for verification
+        self.status = DummyProgressManager(filepath.name, None)
+        reader = ConsumerPlugin(
+            ConsumableDocument(source, original_file=filepath),
+            overrides or DocumentMetadataOverrides(),
+            self.status,  # type: ignore
+            self.dirs.scratch_dir,
+            "task-id",
+        )
+        reader.setup()
+        try:
+            yield reader
+        finally:
+            reader.cleanup()
+
+
 class DummyProgressManager:
     """
     A dummy handler for progress management that doesn't actually try to
index 64af7c9b7dd49459e6573e870d40309597849fd3..72d3321cd75a67e4185a56621f15f8327a8af9a2 100644 (file)
@@ -7,7 +7,6 @@ import re
 import tempfile
 from os import PathLike
 from pathlib import Path
-from platform import machine
 from typing import Final
 from typing import Optional
 from typing import Union
@@ -112,7 +111,7 @@ def __get_list(
         return []
 
 
-def _parse_redis_url(env_redis: Optional[str]) -> tuple[str]:
+def _parse_redis_url(env_redis: Optional[str]) -> tuple[str, str]:
     """
     Gets the Redis information from the environment or a default and handles
     converting from incompatible django_channels and celery formats.
@@ -371,10 +370,7 @@ ASGI_APPLICATION = "paperless.asgi.application"
 STATIC_URL = os.getenv("PAPERLESS_STATIC_URL", BASE_URL + "static/")
 WHITENOISE_STATIC_PREFIX = "/static/"
 
-if machine().lower() == "aarch64":  # pragma: no cover
-    _static_backend = "django.contrib.staticfiles.storage.StaticFilesStorage"
-else:
-    _static_backend = "whitenoise.storage.CompressedStaticFilesStorage"
+_static_backend = "django.contrib.staticfiles.storage.StaticFilesStorage"
 
 STORAGES = {
     "staticfiles": {
index a41b52a6ef5b65bacad50b62299266a5a4d234bc..7d41590ad257381de915fe9f94237eba3c0b40fd 100644 (file)
@@ -425,6 +425,10 @@ class MailAccountHandler(LoggingMixin):
 
     logging_name = "paperless_mail"
 
+    def __init__(self) -> None:
+        super().__init__()
+        self.renew_logging_group()
+
     def _correspondent_from_name(self, name: str) -> Optional[Correspondent]:
         try:
             return Correspondent.objects.get_or_create(name=name)[0]