]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Added task id to pre/post consume script as env
authorAndré Heuer <mail@andre-heuer.de>
Sun, 20 Aug 2023 18:55:18 +0000 (20:55 +0200)
committershamoon <4887959+shamoon@users.noreply.github.com>
Wed, 30 Aug 2023 06:09:47 +0000 (23:09 -0700)
src/documents/consumer.py
src/documents/tasks.py
src/documents/tests/test_consumer.py

index 0ec6090c2ead434f576c8c08d98cdcb1262b9f94..fa756a3ce4f5c7f56ae4d6e996e66629fbc7f06b 100644 (file)
@@ -186,7 +186,7 @@ class Consumer(LoggingMixin):
                 f"Not consuming {self.filename}: Given ASN already exists!",
             )
 
-    def run_pre_consume_script(self):
+    def run_pre_consume_script(self, task_id):
         """
         If one is configured and exists, run the pre-consume script and
         handle its output and/or errors
@@ -209,6 +209,7 @@ class Consumer(LoggingMixin):
         script_env = os.environ.copy()
         script_env["DOCUMENT_SOURCE_PATH"] = original_file_path
         script_env["DOCUMENT_WORKING_PATH"] = working_file_path
+        script_env["TASK_ID"] = task_id
 
         try:
             completed_proc = run(
@@ -233,7 +234,7 @@ class Consumer(LoggingMixin):
                 exception=e,
             )
 
-    def run_post_consume_script(self, document: Document):
+    def run_post_consume_script(self, document: Document, task_id):
         """
         If one is configured and exists, run the pre-consume script and
         handle its output and/or errors
@@ -279,6 +280,7 @@ class Consumer(LoggingMixin):
             ",".join(document.tags.all().values_list("name", flat=True)),
         )
         script_env["DOCUMENT_ORIGINAL_FILENAME"] = str(document.original_filename)
+        script_env["TASK_ID"] = task_id
 
         try:
             completed_proc = run(
@@ -388,7 +390,7 @@ class Consumer(LoggingMixin):
             logging_group=self.logging_group,
         )
 
-        self.run_pre_consume_script()
+        self.run_pre_consume_script(task_id=self.task_id)
 
         def progress_callback(current_progress, max_progress):  # pragma: no cover
             # recalculate progress to be within 20 and 80
@@ -553,7 +555,7 @@ class Consumer(LoggingMixin):
             document_parser.cleanup()
             tempdir.cleanup()
 
-        self.run_post_consume_script(document)
+        self.run_post_consume_script(document, task_id=self.task_id)
 
         self.log.info(f"Document {document} consumption finished")
 
index 2dbc9d6eb8485e164e48885bdb1f32e373ae76df..7d73f852af6bc1942943f2475782b7985e10388e 100644 (file)
@@ -91,8 +91,9 @@ def train_classifier():
         logger.warning("Classifier error: " + str(e))
 
 
-@shared_task
+@shared_task(bind=True)
 def consume_file(
+    self,
     input_doc: ConsumableDocument,
     overrides: Optional[DocumentMetadataOverrides] = None,
 ):
@@ -163,6 +164,7 @@ def consume_file(
         override_created=overrides.created,
         override_asn=overrides.asn,
         override_owner_id=overrides.owner_id,
+        task_id=self.request.id,
     )
 
     if document:
index a8f427c37f36df7e9bbb3d759eeb9c063d2cd866..3285104b38fdb24a53d1381f38d3d59cc7af9aa8 100644 (file)
@@ -4,6 +4,7 @@ import re
 import shutil
 import stat
 import tempfile
+import uuid
 from unittest import mock
 from unittest.mock import MagicMock
 
@@ -802,7 +803,7 @@ class PreConsumeTestCase(TestCase):
     def test_no_pre_consume_script(self, m):
         c = Consumer()
         c.path = "path-to-file"
-        c.run_pre_consume_script()
+        c.run_pre_consume_script(str(uuid.uuid4()))
         m.assert_not_called()
 
     @mock.patch("documents.consumer.run")
@@ -812,7 +813,7 @@ class PreConsumeTestCase(TestCase):
         c = Consumer()
         c.filename = "somefile.pdf"
         c.path = "path-to-file"
-        self.assertRaises(ConsumerError, c.run_pre_consume_script)
+        self.assertRaises(ConsumerError, c.run_pre_consume_script, str(uuid.uuid4()))
 
     @mock.patch("documents.consumer.run")
     def test_pre_consume_script(self, m):
@@ -821,7 +822,8 @@ class PreConsumeTestCase(TestCase):
                 c = Consumer()
                 c.original_path = "path-to-file"
                 c.path = "/tmp/somewhere/path-to-file"
-                c.run_pre_consume_script()
+                task_id = str(uuid.uuid4())
+                c.run_pre_consume_script(task_id)
 
                 m.assert_called_once()
 
@@ -836,6 +838,7 @@ class PreConsumeTestCase(TestCase):
                 subset = {
                     "DOCUMENT_SOURCE_PATH": c.original_path,
                     "DOCUMENT_WORKING_PATH": c.path,
+                    "TASK_ID": task_id,
                 }
                 self.assertDictEqual(environment, {**environment, **subset})
 
@@ -864,7 +867,7 @@ class PreConsumeTestCase(TestCase):
                     c = Consumer()
                     c.path = "path-to-file"
 
-                    c.run_pre_consume_script()
+                    c.run_pre_consume_script(str(uuid.uuid4()))
                     self.assertIn(
                         "INFO:paperless.consumer:This message goes to stdout",
                         cm.output,
@@ -896,7 +899,11 @@ class PreConsumeTestCase(TestCase):
             with override_settings(PRE_CONSUME_SCRIPT=script.name):
                 c = Consumer()
                 c.path = "path-to-file"
-                self.assertRaises(ConsumerError, c.run_pre_consume_script)
+                self.assertRaises(
+                    ConsumerError,
+                    c.run_pre_consume_script,
+                    str(uuid.uuid4()),
+                )
 
 
 class PostConsumeTestCase(TestCase):
@@ -917,7 +924,7 @@ class PostConsumeTestCase(TestCase):
         doc.tags.add(tag1)
         doc.tags.add(tag2)
 
-        Consumer().run_post_consume_script(doc)
+        Consumer().run_post_consume_script(doc, str(uuid.uuid4()))
 
         m.assert_not_called()
 
@@ -927,7 +934,12 @@ class PostConsumeTestCase(TestCase):
         doc = Document.objects.create(title="Test", mime_type="application/pdf")
         c = Consumer()
         c.filename = "somefile.pdf"
-        self.assertRaises(ConsumerError, c.run_post_consume_script, doc)
+        self.assertRaises(
+            ConsumerError,
+            c.run_post_consume_script,
+            doc,
+            str(uuid.uuid4()),
+        )
 
     @mock.patch("documents.consumer.run")
     def test_post_consume_script_simple(self, m):
@@ -935,7 +947,7 @@ class PostConsumeTestCase(TestCase):
             with override_settings(POST_CONSUME_SCRIPT=script.name):
                 doc = Document.objects.create(title="Test", mime_type="application/pdf")
 
-                Consumer().run_post_consume_script(doc)
+                Consumer().run_post_consume_script(doc, str(uuid.uuid4()))
 
                 m.assert_called_once()
 
@@ -953,8 +965,9 @@ class PostConsumeTestCase(TestCase):
                 tag2 = Tag.objects.create(name="b")
                 doc.tags.add(tag1)
                 doc.tags.add(tag2)
+                task_id = str(uuid.uuid4())
 
-                Consumer().run_post_consume_script(doc)
+                Consumer().run_post_consume_script(doc, task_id)
 
                 m.assert_called_once()
 
@@ -976,6 +989,7 @@ class PostConsumeTestCase(TestCase):
                     "DOCUMENT_THUMBNAIL_URL": f"/api/documents/{doc.pk}/thumb/",
                     "DOCUMENT_CORRESPONDENT": "my_bank",
                     "DOCUMENT_TAGS": "a,b",
+                    "TASK_ID": task_id,
                 }
 
                 self.assertDictEqual(environment, {**environment, **subset})
@@ -1004,4 +1018,4 @@ class PostConsumeTestCase(TestCase):
                 doc = Document.objects.create(title="Test", mime_type="application/pdf")
                 c.path = "path-to-file"
                 with self.assertRaises(ConsumerError):
-                    c.run_post_consume_script(doc)
+                    c.run_post_consume_script(doc, str(uuid.uuid4()))