git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Updates how task_args and task_kwargs are parsed, adds testing to cover everything...
author Trenton H <holmes.trenton@gmail.com>
Wed, 28 Sep 2022 16:09:51 +0000 (09:09 -0700)
committer Trenton H <holmes.trenton@gmail.com>
Wed, 28 Sep 2022 17:40:55 +0000 (10:40 -0700)
src/documents/serialisers.py
src/documents/tests/test_api.py

index 172992de434347a3089efadb7da5568e1cd33f4e..cd99c43cd17ee69cfda10d173744acc4ba2615ac 100644 (file)
@@ -1,8 +1,14 @@
 import datetime
-import json
 import math
-import os
 import re
+from ast import literal_eval
+from asyncio.log import logger
+from pathlib import Path
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+from celery import states
 
 try:
     import zoneinfo
@@ -646,16 +652,21 @@ class TasksViewSerializer(serializers.ModelSerializer):
 
     def get_result(self, obj):
         result = ""
-        if hasattr(obj, "attempted_task") and obj.attempted_task:
+        if (
+            hasattr(obj, "attempted_task")
+            and obj.attempted_task
+            and obj.attempted_task.result
+        ):
             try:
-                result_json = json.loads(obj.attempted_task.result)
-            except Exception:
-                pass
-
-            if result_json and "exc_message" in result_json:
-                result = result_json["exc_message"]
-            else:
-                result = obj.attempted_task.result.strip('"')
+                result: str = obj.attempted_task.result
+                if "exc_message" in result:
+                    # This is a dict in this case
+                    result: Dict = literal_eval(result)
+                    # This is a list, grab the first item (most recent)
+                    result = result["exc_message"][0]
+            except Exception as e:  # pragma: no cover
+                # Extra security if something is malformed
+                logger.warn(f"Error getting task result: {e}", exc_info=True)
         return result
 
     status = serializers.SerializerMethodField()
@@ -704,26 +715,25 @@ class TasksViewSerializer(serializers.ModelSerializer):
         result = ""
         if hasattr(obj, "attempted_task") and obj.attempted_task:
             try:
-                # We have to make this a valid JSON object string
-                kwargs_json = json.loads(
-                    obj.attempted_task.task_kwargs.strip('"')
-                    .replace("'", '"')
-                    .replace("None", '""'),
-                )
-            except Exception:
-                pass
-
-            if kwargs_json and "override_filename" in kwargs_json:
-                result = kwargs_json["override_filename"]
-            else:
-                filepath = (
-                    obj.attempted_task.task_args.replace('"', "")
-                    .replace("'", "")
-                    .replace("(", "")
-                    .replace(")", "")
-                    .replace(",", "")
-                )
-                result = os.path.split(filepath)[1]
+                task_kwargs: Optional[str] = obj.attempted_task.task_kwargs
+                # Try the override filename first (this is a webui created task?)
+                if task_kwargs is not None:
+                    # It's a string, string of a dict.  Who knows why...
+                    kwargs = literal_eval(literal_eval(task_kwargs))
+                    if "override_filename" in kwargs:
+                        result = kwargs["override_filename"]
+
+                # Nothing was found, report the task first argument
+                if not len(result):
+                    # There are always some arguments to the consume
+                    task_args: Tuple = literal_eval(
+                        literal_eval(obj.attempted_task.task_args),
+                    )
+                    filepath = Path(task_args[0])
+                    result = filepath.name
+            except Exception as e:  # pragma: no cover
+                # Extra security if something is malformed
+                logger.warn(f"Error getting task result: {e}", exc_info=True)
 
         return result
 
@@ -735,7 +745,8 @@ class TasksViewSerializer(serializers.ModelSerializer):
         if (
             hasattr(obj, "attempted_task")
             and obj.attempted_task
-            and obj.attempted_task.status == "SUCCESS"
+            and obj.attempted_task.result
+            and obj.attempted_task.status == states.SUCCESS
         ):
             try:
                 result = re.search(regexp, obj.attempted_task.result).group(1)
index 89e34050195850b80b18008a0ed5480e2f4c823d..ec89a19e840226a69364cbc8d28eb64568965e88 100644 (file)
@@ -2831,6 +2831,14 @@ class TestTasks(APITestCase):
         self.assertEqual(returned_task2["task_name"], result2.task_name)
 
     def test_acknowledge_tasks(self):
+        """
+        GIVEN:
+            - Attempted celery tasks
+        WHEN:
+            - API call is made to mark task as acknowledged
+        THEN:
+            - Task is marked as acknowledged
+        """
         result1 = TaskResult.objects.create(
             task_id=str(uuid.uuid4()),
             task_name="documents.tasks.some_task",
@@ -2849,3 +2857,119 @@ class TestTasks(APITestCase):
 
         response = self.client.get(self.ENDPOINT)
         self.assertEqual(len(response.data), 0)
+
+    def test_task_result_no_error(self):
+        """
+        GIVEN:
+            - A celery task completed without error
+        WHEN:
+            - API call is made to get tasks
+        THEN:
+            - The returned data includes the task result
+        """
+        result1 = TaskResult.objects.create(
+            task_id=str(uuid.uuid4()),
+            task_name="documents.tasks.some_task",
+            status=celery.states.SUCCESS,
+            result="Success. New document id 1 created",
+        )
+        _ = PaperlessTask.objects.create(attempted_task=result1)
+
+        response = self.client.get(self.ENDPOINT)
+
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(len(response.data), 1)
+
+        returned_data = response.data[0]
+
+        self.assertEqual(returned_data["result"], "Success. New document id 1 created")
+        self.assertEqual(returned_data["related_document"], "1")
+
+    def test_task_result_with_error(self):
+        """
+        GIVEN:
+            - A celery task completed with an exception
+        WHEN:
+            - API call is made to get tasks
+        THEN:
+            - The returned result is the exception info
+        """
+        result1 = TaskResult.objects.create(
+            task_id=str(uuid.uuid4()),
+            task_name="documents.tasks.some_task",
+            status=celery.states.SUCCESS,
+            result={
+                "exc_type": "ConsumerError",
+                "exc_message": ["test.pdf: Not consuming test.pdf: It is a duplicate."],
+                "exc_module": "documents.consumer",
+            },
+        )
+        _ = PaperlessTask.objects.create(attempted_task=result1)
+
+        response = self.client.get(self.ENDPOINT)
+
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(len(response.data), 1)
+
+        returned_data = response.data[0]
+
+        self.assertEqual(
+            returned_data["result"],
+            "test.pdf: Not consuming test.pdf: It is a duplicate.",
+        )
+
+    def test_task_name_webui(self):
+        """
+        GIVEN:
+            - Attempted celery task
+            - Task was created through the webui
+        WHEN:
+            - API call is made to get tasks
+        THEN:
+            - Returned data include the filename
+        """
+        result1 = TaskResult.objects.create(
+            task_id=str(uuid.uuid4()),
+            task_name="documents.tasks.some_task",
+            status=celery.states.SUCCESS,
+            task_args="\"('/tmp/paperless/paperless-upload-5iq7skzc',)\"",
+            task_kwargs="\"{'override_filename': 'test.pdf', 'override_title': None, 'override_correspondent_id': None, 'override_document_type_id': None, 'override_tag_ids': None, 'task_id': '466e8fe7-7193-4698-9fff-72f0340e2082', 'override_created': None}\"",
+        )
+        _ = PaperlessTask.objects.create(attempted_task=result1)
+
+        response = self.client.get(self.ENDPOINT)
+
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(len(response.data), 1)
+
+        returned_data = response.data[0]
+
+        self.assertEqual(returned_data["name"], "test.pdf")
+
+    def test_task_name_consume_folder(self):
+        """
+        GIVEN:
+            - Attempted celery task
+            - Task was created through the consume folder
+        WHEN:
+            - API call is made to get tasks
+        THEN:
+            - Returned data include the filename
+        """
+        result1 = TaskResult.objects.create(
+            task_id=str(uuid.uuid4()),
+            task_name="documents.tasks.some_task",
+            status=celery.states.SUCCESS,
+            task_args="\"('/consume/anothertest.pdf',)\"",
+            task_kwargs="\"{'override_tag_ids': None}\"",
+        )
+        _ = PaperlessTask.objects.create(attempted_task=result1)
+
+        response = self.client.get(self.ENDPOINT)
+
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(len(response.data), 1)
+
+        returned_data = response.data[0]
+
+        self.assertEqual(returned_data["name"], "anothertest.pdf")