]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Create paperlesstasks for sanity, classifier
author shamoon <4887959+shamoon@users.noreply.github.com>
Fri, 14 Feb 2025 01:46:05 +0000 (17:46 -0800)
committer shamoon <4887959+shamoon@users.noreply.github.com>
Mon, 17 Feb 2025 16:19:11 +0000 (08:19 -0800)
[ci skip]

12 files changed:
src-ui/src/app/services/tasks.service.spec.ts
src-ui/src/app/services/tasks.service.ts
src/documents/filters.py
src/documents/management/commands/document_create_classifier.py
src/documents/management/commands/document_sanity_checker.py
src/documents/migrations/1063_paperlesstask_type.py [new file with mode: 0644]
src/documents/models.py
src/documents/sanity_checker.py
src/documents/serialisers.py
src/documents/signals/handlers.py
src/documents/tasks.py
src/documents/views.py

index fa84c9a1928bb2710859280639bcb2d38a8f2240..e161c15dd0887ac9748239a4171b1c1b63535542 100644 (file)
@@ -33,7 +33,7 @@ describe('TasksService', () => {
   it('calls tasks api endpoint on reload', () => {
     tasksService.reload()
     const req = httpTestingController.expectOne(
-      `${environment.apiBaseUrl}tasks/`
+      `${environment.apiBaseUrl}tasks/?type=file`
     )
     expect(req.request.method).toEqual('GET')
   })
@@ -41,7 +41,9 @@ describe('TasksService', () => {
   it('does not call tasks api endpoint on reload if already loading', () => {
     tasksService.loading = true
     tasksService.reload()
-    httpTestingController.expectNone(`${environment.apiBaseUrl}tasks/`)
+    httpTestingController.expectNone(
+      `${environment.apiBaseUrl}tasks/?type=file`
+    )
   })
 
   it('calls acknowledge_tasks api endpoint on dismiss and reloads', () => {
@@ -55,7 +57,9 @@ describe('TasksService', () => {
     })
     req.flush([])
     // reload is then called
-    httpTestingController.expectOne(`${environment.apiBaseUrl}tasks/`).flush([])
+    httpTestingController
+      .expectOne(`${environment.apiBaseUrl}tasks/?type=file`)
+      .flush([])
   })
 
   it('sorts tasks returned from api', () => {
@@ -106,7 +110,7 @@ describe('TasksService', () => {
     tasksService.reload()
 
     const req = httpTestingController.expectOne(
-      `${environment.apiBaseUrl}tasks/`
+      `${environment.apiBaseUrl}tasks/?type=file`
     )
 
     req.flush(mockTasks)
index c3c8f1d2b8ee94630bd25fe5c4265b5ec9c45ed0..4e7226bad461123f3523bb6cc742b03621593922 100644 (file)
@@ -54,7 +54,7 @@ export class TasksService {
     this.loading = true
 
     this.http
-      .get<PaperlessTask[]>(`${this.baseUrl}tasks/`)
+      .get<PaperlessTask[]>(`${this.baseUrl}tasks/?type=file`)
       .pipe(takeUntil(this.unsubscribeNotifer), first())
       .subscribe((r) => {
         this.fileTasks = r.filter((t) => t.type == PaperlessTaskType.File) // they're all File tasks, for now
index 1ce782ee6806c766219d4cb2c3df8f4b4d918087..086c889e0e7c9d465997ad15bcacfa28e8192812 100644 (file)
@@ -35,6 +35,7 @@ from documents.models import CustomFieldInstance
 from documents.models import Document
 from documents.models import DocumentType
 from documents.models import Log
+from documents.models import PaperlessTask
 from documents.models import ShareLink
 from documents.models import StoragePath
 from documents.models import Tag
@@ -770,6 +771,15 @@ class ShareLinkFilterSet(FilterSet):
         }
 
 
+class PaperlessTaskFilterSet(FilterSet):  # query-parameter filters for PaperlessTask (e.g. ?type=file)
+    class Meta:
+        model = PaperlessTask
+        fields = {
+            "type": ["exact"],  # exact match on the task type value
+            "status": ["exact"],  # exact match on the task state value
+        }
+
+
 class ObjectOwnedOrGrantedPermissionsFilter(ObjectPermissionsFilter):
     """
     A filter backend that limits results to those where the requesting user
index f5df51aacc4a549757c767645eb1184ecc0d0359..f7903aac77f916e57f529de4513cc1f2b95d689e 100644 (file)
@@ -10,4 +10,4 @@ class Command(BaseCommand):
     )
 
     def handle(self, *args, **options):
-        train_classifier()
+        train_classifier(scheduled=False)
index 095781a9dbc0fb92ce9262e16e90d9731c394442..b634d4dc9ab5ae3a96f39db4c1f50f2b1530718c 100644 (file)
@@ -12,6 +12,6 @@ class Command(ProgressBarMixin, BaseCommand):
 
     def handle(self, *args, **options):
         self.handle_progress_bar_mixin(**options)
-        messages = check_sanity(progress=self.use_progress_bar)
+        messages = check_sanity(progress=self.use_progress_bar, scheduled=False)
 
         messages.log_messages()
diff --git a/src/documents/migrations/1063_paperlesstask_type.py b/src/documents/migrations/1063_paperlesstask_type.py
new file mode 100644 (file)
index 0000000..5a3c5d6
--- /dev/null
@@ -0,0 +1,28 @@
+# Generated by Django 5.1.6 on 2025-02-14 01:11
+
+from django.db import migrations
+from django.db import models
+
+
+class Migration(migrations.Migration):  # adds the "type" column to PaperlessTask
+    dependencies = [
+        ("documents", "1062_alter_savedviewfilterrule_rule_type"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="paperlesstask",
+            name="type",
+            field=models.CharField(
+                choices=[  # same values as PaperlessTask.TaskType in models.py
+                    ("file", "File Task"),
+                    ("scheduled_task", "Scheduled Task"),
+                    ("manual_task", "Manual Task"),
+                ],
+                default="file",  # matches the model default, TaskType.FILE
+                help_text="The type of task that was run",
+                max_length=30,
+                verbose_name="Task Type",
+            ),
+        ),
+    ]
index 4f9d3cb0ebd2f9dd19f6cdb4459fb1f6317234fb..d6cb91e03f2618e852df8a290619e3c519033a38 100644 (file)
@@ -650,6 +650,11 @@ class PaperlessTask(ModelWithOwner):
     ALL_STATES = sorted(states.ALL_STATES)
     TASK_STATE_CHOICES = sorted(zip(ALL_STATES, ALL_STATES))
 
+    class TaskType(models.TextChoices):  # what kind of work a task row records; stored in `type`
+        FILE = ("file", _("File Task"))  # field default; also set by before_task_publish_handler
+        SCHEDULED_TASK = ("scheduled_task", _("Scheduled Task"))  # chosen when `scheduled` is true
+        MANUAL_TASK = ("manual_task", _("Manual Task"))  # chosen when `scheduled` is false (management commands)
+
     task_id = models.CharField(
         max_length=255,
         unique=True,
@@ -684,24 +689,28 @@ class PaperlessTask(ModelWithOwner):
         verbose_name=_("Task State"),
         help_text=_("Current state of the task being run"),
     )
+
     date_created = models.DateTimeField(
         null=True,
         default=timezone.now,
         verbose_name=_("Created DateTime"),
         help_text=_("Datetime field when the task result was created in UTC"),
     )
+
     date_started = models.DateTimeField(
         null=True,
         default=None,
         verbose_name=_("Started DateTime"),
         help_text=_("Datetime field when the task was started in UTC"),
     )
+
     date_done = models.DateTimeField(
         null=True,
         default=None,
         verbose_name=_("Completed DateTime"),
         help_text=_("Datetime field when the task was completed in UTC"),
     )
+
     result = models.TextField(
         null=True,
         default=None,
@@ -711,6 +720,14 @@ class PaperlessTask(ModelWithOwner):
         ),
     )
 
+    type = models.CharField(
+        max_length=30,
+        choices=TaskType.choices,
+        default=TaskType.FILE,
+        verbose_name=_("Task Type"),
+        help_text=_("The type of task that was run"),
+    )
+
     def __str__(self) -> str:
         return f"Task {self.task_id}"
 
index 28d2024e7238fec9999f1dbfb07f068d50236c95..cfb30e584c328639306eaff9a9580606f4f62b23 100644 (file)
@@ -1,13 +1,17 @@
 import hashlib
 import logging
+import uuid
 from collections import defaultdict
 from pathlib import Path
 from typing import Final
 
+from celery import states
 from django.conf import settings
+from django.utils import timezone
 from tqdm import tqdm
 
 from documents.models import Document
+from documents.models import PaperlessTask
 
 
 class SanityCheckMessages:
@@ -57,7 +61,17 @@ class SanityCheckFailedException(Exception):
     pass
 
 
-def check_sanity(*, progress=False) -> SanityCheckMessages:
+def check_sanity(*, progress=False, scheduled=True) -> SanityCheckMessages:
+    task = PaperlessTask.objects.create(
+        task_id=uuid.uuid4(),
+        type=PaperlessTask.TaskType.SCHEDULED_TASK
+        if scheduled  # the management command passes scheduled=False
+        else PaperlessTask.TaskType.MANUAL_TASK,
+        task_name="check_sanity",
+        status=states.STARTED,  # celery constant; TASK_STATE_CHOICES is a list of tuples with no .STARTED
+        date_created=timezone.now(),
+        date_started=timezone.now(),
+    )
     messages = SanityCheckMessages()
 
     present_files = {
@@ -142,4 +156,9 @@ def check_sanity(*, progress=False) -> SanityCheckMessages:
     for extra_file in present_files:
         messages.warning(None, f"Orphaned file in media dir: {extra_file}")
 
+    task.status = states.SUCCESS if not messages.has_error else states.FAILURE  # celery defines FAILURE, not FAILED
+    # result is concatenated messages
+    task.result = str(messages)
+    task.date_done = timezone.now()
+    task.save()  # persist status/result/date_done; otherwise these updates never reach the database
     return messages
index 84894bff1b615eb7fd6fcf3613039c038d1c8d06..fe4385e6d750d5ac8316b1d8e3a9c566c7d18221 100644 (file)
@@ -1700,12 +1700,6 @@ class TasksViewSerializer(OwnedObjectSerializer):
             "owner",
         )
 
-    type = serializers.SerializerMethodField()
-
-    def get_type(self, obj) -> str:
-        # just file tasks, for now
-        return "file"
-
     related_document = serializers.SerializerMethodField()
     created_doc_re = re.compile(r"New document id (\d+) created")
     duplicate_doc_re = re.compile(r"It is a duplicate of .* \(#(\d+)\)")
index 0079e5f8ccddf79b2782c5c39a89f85d29831732..1a821ea167f59db04c0e36f1bc5f84e4f1a42aa1 100644 (file)
@@ -1221,6 +1221,7 @@ def before_task_publish_handler(sender=None, headers=None, body=None, **kwargs):
         user_id = overrides.owner_id if overrides else None
 
         PaperlessTask.objects.create(
+            type=PaperlessTask.TaskType.FILE,
             task_id=headers["id"],
             status=states.PENDING,
             task_file_name=task_file_name,
index d8539d1ab25f42374c66cfdbf92f698c1b3a8be5..83ac080a2c9864b27117afe6e480a37c40d44183 100644 (file)
@@ -9,6 +9,7 @@ from tempfile import TemporaryDirectory
 import tqdm
 from celery import Task
 from celery import shared_task
+from celery import states
 from django.conf import settings
 from django.contrib.contenttypes.models import ContentType
 from django.db import models
@@ -35,6 +36,7 @@ from documents.models import Correspondent
 from documents.models import CustomFieldInstance
 from documents.models import Document
 from documents.models import DocumentType
+from documents.models import PaperlessTask
 from documents.models import StoragePath
 from documents.models import Tag
 from documents.models import Workflow
@@ -74,19 +76,34 @@ def index_reindex(*, progress_bar_disable=False):
 
 
 @shared_task
-def train_classifier():
+def train_classifier(*, scheduled=True):
+    task = PaperlessTask.objects.create(
+        type=PaperlessTask.TaskType.SCHEDULED_TASK
+        if scheduled
+        else PaperlessTask.TaskType.MANUAL_TASK,
+        task_id=uuid.uuid4(),
+        task_name="train_classifier",
+        status=states.STARTED,
+        date_created=timezone.now(),
+        date_started=timezone.now(),
+    )
     if (
         not Tag.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists()
         and not DocumentType.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists()
         and not Correspondent.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists()
         and not StoragePath.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists()
     ):
-        logger.info("No automatic matching items, not training")
+        result = "No automatic matching items, not training"
+        logger.info(result)
         # Special case, items were once auto and trained, so remove the model
         # and prevent its use again
         if settings.MODEL_FILE.exists():
             logger.info(f"Removing {settings.MODEL_FILE} so it won't be used")
             settings.MODEL_FILE.unlink()
+        task.status = states.SUCCESS
+        task.result = result
+        task.date_done = timezone.now()
+        task.save()
         return
 
     classifier = load_classifier()
@@ -100,11 +117,22 @@ def train_classifier():
                 f"Saving updated classifier model to {settings.MODEL_FILE}...",
             )
             classifier.save()
+            task.status = states.SUCCESS
+            task.result = "Training completed successfully"
         else:
             logger.debug("Training data unchanged.")
+            task.status = states.SUCCESS
+            task.result = "Training data unchanged"
+
+        task.date_done = timezone.now()  # record completion time, as the early-return branch does
+        task.save(update_fields=["status", "result", "date_done"])
 
     except Exception as e:
         logger.warning("Classifier error: " + str(e))
+        task.status = states.FAILURE  # celery defines FAILURE, not FAILED
+        task.result = str(e)
+        task.date_done = timezone.now()
+        task.save(update_fields=["status", "result", "date_done"])  # persist the failure so the task does not stay STARTED
 
 
 @shared_task(bind=True)
index a856883f3c22eb69835df6b84c87259069ae1308..8193dc621e601aa32f40376ec7e8da261a0c56bb 100644 (file)
@@ -103,6 +103,7 @@ from documents.filters import DocumentsOrderingFilter
 from documents.filters import DocumentTypeFilterSet
 from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
 from documents.filters import ObjectOwnedPermissionsFilter
+from documents.filters import PaperlessTaskFilterSet
 from documents.filters import ShareLinkFilterSet
 from documents.filters import StoragePathFilterSet
 from documents.filters import TagFilterSet
@@ -2223,7 +2224,12 @@ class RemoteVersionView(GenericAPIView):
 class TasksViewSet(ReadOnlyModelViewSet):
     permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
     serializer_class = TasksViewSerializer
-    filter_backends = (ObjectOwnedOrGrantedPermissionsFilter,)
+    filter_backends = (
+        DjangoFilterBackend,
+        OrderingFilter,
+        ObjectOwnedOrGrantedPermissionsFilter,
+    )
+    filterset_class = PaperlessTaskFilterSet
 
     def get_queryset(self):
         queryset = (