mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-30 18:22:40 -04:00 
			
		
		
		
	Create paperlesstasks for sanity, classifier
[ci skip]
This commit is contained in:
		
							parent
							
								
									de5f66b3a0
								
							
						
					
					
						commit
						f897447a65
					
				| @ -33,7 +33,7 @@ describe('TasksService', () => { | ||||
|   it('calls tasks api endpoint on reload', () => { | ||||
|     tasksService.reload() | ||||
|     const req = httpTestingController.expectOne( | ||||
|       `${environment.apiBaseUrl}tasks/` | ||||
|       `${environment.apiBaseUrl}tasks/?type=file` | ||||
|     ) | ||||
|     expect(req.request.method).toEqual('GET') | ||||
|   }) | ||||
| @ -41,7 +41,9 @@ describe('TasksService', () => { | ||||
|   it('does not call tasks api endpoint on reload if already loading', () => { | ||||
|     tasksService.loading = true | ||||
|     tasksService.reload() | ||||
|     httpTestingController.expectNone(`${environment.apiBaseUrl}tasks/`) | ||||
|     httpTestingController.expectNone( | ||||
|       `${environment.apiBaseUrl}tasks/?type=file` | ||||
|     ) | ||||
|   }) | ||||
| 
 | ||||
|   it('calls acknowledge_tasks api endpoint on dismiss and reloads', () => { | ||||
| @ -55,7 +57,9 @@ describe('TasksService', () => { | ||||
|     }) | ||||
|     req.flush([]) | ||||
|     // reload is then called
 | ||||
|     httpTestingController.expectOne(`${environment.apiBaseUrl}tasks/`).flush([]) | ||||
|     httpTestingController | ||||
|       .expectOne(`${environment.apiBaseUrl}tasks/?type=file`) | ||||
|       .flush([]) | ||||
|   }) | ||||
| 
 | ||||
|   it('sorts tasks returned from api', () => { | ||||
| @ -106,7 +110,7 @@ describe('TasksService', () => { | ||||
|     tasksService.reload() | ||||
| 
 | ||||
|     const req = httpTestingController.expectOne( | ||||
|       `${environment.apiBaseUrl}tasks/` | ||||
|       `${environment.apiBaseUrl}tasks/?type=file` | ||||
|     ) | ||||
| 
 | ||||
|     req.flush(mockTasks) | ||||
|  | ||||
| @ -54,7 +54,7 @@ export class TasksService { | ||||
|     this.loading = true | ||||
| 
 | ||||
|     this.http | ||||
|       .get<PaperlessTask[]>(`${this.baseUrl}tasks/`) | ||||
|       .get<PaperlessTask[]>(`${this.baseUrl}tasks/?type=file`) | ||||
|       .pipe(takeUntil(this.unsubscribeNotifer), first()) | ||||
|       .subscribe((r) => { | ||||
|         this.fileTasks = r.filter((t) => t.type == PaperlessTaskType.File) // they're all File tasks, for now
 | ||||
|  | ||||
| @ -35,6 +35,7 @@ from documents.models import CustomFieldInstance | ||||
| from documents.models import Document | ||||
| from documents.models import DocumentType | ||||
| from documents.models import Log | ||||
| from documents.models import PaperlessTask | ||||
| from documents.models import ShareLink | ||||
| from documents.models import StoragePath | ||||
| from documents.models import Tag | ||||
| @ -770,6 +771,15 @@ class ShareLinkFilterSet(FilterSet): | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
class PaperlessTaskFilterSet(FilterSet):
    """Filter set exposing exact-match filtering on a task's type and status."""

    class Meta:
        model = PaperlessTask
        # Each listed model field accepts only the "exact" lookup.
        fields = {field_name: ["exact"] for field_name in ("type", "status")}
| 
 | ||||
| 
 | ||||
| class ObjectOwnedOrGrantedPermissionsFilter(ObjectPermissionsFilter): | ||||
|     """ | ||||
|     A filter backend that limits results to those where the requesting user | ||||
|  | ||||
| @ -10,4 +10,4 @@ class Command(BaseCommand): | ||||
|     ) | ||||
| 
 | ||||
|     def handle(self, *args, **options): | ||||
|         train_classifier() | ||||
|         train_classifier(scheduled=False) | ||||
|  | ||||
| @ -12,6 +12,6 @@ class Command(ProgressBarMixin, BaseCommand): | ||||
| 
 | ||||
|     def handle(self, *args, **options): | ||||
|         self.handle_progress_bar_mixin(**options) | ||||
|         messages = check_sanity(progress=self.use_progress_bar) | ||||
|         messages = check_sanity(progress=self.use_progress_bar, scheduled=False) | ||||
| 
 | ||||
|         messages.log_messages() | ||||
|  | ||||
							
								
								
									
										28
									
								
								src/documents/migrations/1063_paperlesstask_type.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								src/documents/migrations/1063_paperlesstask_type.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,28 @@ | ||||
| # Generated by Django 5.1.6 on 2025-02-14 01:11 | ||||
| 
 | ||||
| from django.db import migrations | ||||
| from django.db import models | ||||
| 
 | ||||
| 
 | ||||
class Migration(migrations.Migration):
    """Add the ``type`` choice field to the ``paperlesstask`` model.

    The field is a 30-character ``CharField`` whose default is ``"file"``.
    """

    dependencies = [
        ("documents", "1062_alter_savedviewfilterrule_rule_type"),
    ]

    operations = [
        migrations.AddField(
            model_name="paperlesstask",
            name="type",
            field=models.CharField(
                max_length=30,
                # Stored value / human-readable label pairs.
                choices=[
                    ("file", "File Task"),
                    ("scheduled_task", "Scheduled Task"),
                    ("manual_task", "Manual Task"),
                ],
                default="file",
                verbose_name="Task Type",
                help_text="The type of task that was run",
            ),
        ),
    ]
| @ -650,6 +650,11 @@ class PaperlessTask(ModelWithOwner): | ||||
|     ALL_STATES = sorted(states.ALL_STATES) | ||||
|     TASK_STATE_CHOICES = sorted(zip(ALL_STATES, ALL_STATES)) | ||||
| 
 | ||||
|     class TaskType(models.TextChoices): | ||||
|         FILE = ("file", _("File Task")) | ||||
|         SCHEDULED_TASK = ("scheduled_task", _("Scheduled Task")) | ||||
|         MANUAL_TASK = ("manual_task", _("Manual Task")) | ||||
| 
 | ||||
|     task_id = models.CharField( | ||||
|         max_length=255, | ||||
|         unique=True, | ||||
| @ -684,24 +689,28 @@ class PaperlessTask(ModelWithOwner): | ||||
|         verbose_name=_("Task State"), | ||||
|         help_text=_("Current state of the task being run"), | ||||
|     ) | ||||
| 
 | ||||
|     date_created = models.DateTimeField( | ||||
|         null=True, | ||||
|         default=timezone.now, | ||||
|         verbose_name=_("Created DateTime"), | ||||
|         help_text=_("Datetime field when the task result was created in UTC"), | ||||
|     ) | ||||
| 
 | ||||
|     date_started = models.DateTimeField( | ||||
|         null=True, | ||||
|         default=None, | ||||
|         verbose_name=_("Started DateTime"), | ||||
|         help_text=_("Datetime field when the task was started in UTC"), | ||||
|     ) | ||||
| 
 | ||||
|     date_done = models.DateTimeField( | ||||
|         null=True, | ||||
|         default=None, | ||||
|         verbose_name=_("Completed DateTime"), | ||||
|         help_text=_("Datetime field when the task was completed in UTC"), | ||||
|     ) | ||||
| 
 | ||||
|     result = models.TextField( | ||||
|         null=True, | ||||
|         default=None, | ||||
| @ -711,6 +720,14 @@ class PaperlessTask(ModelWithOwner): | ||||
|         ), | ||||
|     ) | ||||
| 
 | ||||
|     type = models.CharField( | ||||
|         max_length=30, | ||||
|         choices=TaskType.choices, | ||||
|         default=TaskType.FILE, | ||||
|         verbose_name=_("Task Type"), | ||||
|         help_text=_("The type of task that was run"), | ||||
|     ) | ||||
| 
 | ||||
|     def __str__(self) -> str: | ||||
|         return f"Task {self.task_id}" | ||||
| 
 | ||||
|  | ||||
| @ -1,13 +1,17 @@ | ||||
| import hashlib | ||||
| import logging | ||||
| import uuid | ||||
| from collections import defaultdict | ||||
| from pathlib import Path | ||||
| from typing import Final | ||||
| 
 | ||||
| from celery import states | ||||
| from django.conf import settings | ||||
| from django.utils import timezone | ||||
| from tqdm import tqdm | ||||
| 
 | ||||
| from documents.models import Document | ||||
| from documents.models import PaperlessTask | ||||
| 
 | ||||
| 
 | ||||
| class SanityCheckMessages: | ||||
| @ -57,7 +61,17 @@ class SanityCheckFailedException(Exception): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| def check_sanity(*, progress=False) -> SanityCheckMessages: | ||||
| def check_sanity(*, progress=False, scheduled=True) -> SanityCheckMessages: | ||||
|     task = PaperlessTask.objects.create( | ||||
|         task_id=uuid.uuid4(), | ||||
|         type=PaperlessTask.TaskType.SCHEDULED_TASK | ||||
|         if scheduled | ||||
|         else PaperlessTask.TaskType.MANUAL_TASK, | ||||
|         task_name="check_sanity", | ||||
|         status=PaperlessTask.TASK_STATE_CHOICES.STARTED, | ||||
|         date_created=timezone.now(), | ||||
|         date_started=timezone.now(), | ||||
|     ) | ||||
|     messages = SanityCheckMessages() | ||||
| 
 | ||||
|     present_files = { | ||||
| @ -142,4 +156,8 @@ def check_sanity(*, progress=False) -> SanityCheckMessages: | ||||
|     for extra_file in present_files: | ||||
|         messages.warning(None, f"Orphaned file in media dir: {extra_file}") | ||||
| 
 | ||||
|     task.status = states.SUCCESS if not messages.has_error else states.FAILED | ||||
|     # result is concatenated messages | ||||
|     task.result = str(messages) | ||||
|     task.date_done = timezone.now() | ||||
|     return messages | ||||
|  | ||||
| @ -1700,12 +1700,6 @@ class TasksViewSerializer(OwnedObjectSerializer): | ||||
|             "owner", | ||||
|         ) | ||||
| 
 | ||||
|     type = serializers.SerializerMethodField() | ||||
| 
 | ||||
|     def get_type(self, obj) -> str: | ||||
|         # just file tasks, for now | ||||
|         return "file" | ||||
| 
 | ||||
|     related_document = serializers.SerializerMethodField() | ||||
|     created_doc_re = re.compile(r"New document id (\d+) created") | ||||
|     duplicate_doc_re = re.compile(r"It is a duplicate of .* \(#(\d+)\)") | ||||
|  | ||||
| @ -1221,6 +1221,7 @@ def before_task_publish_handler(sender=None, headers=None, body=None, **kwargs): | ||||
|         user_id = overrides.owner_id if overrides else None | ||||
| 
 | ||||
|         PaperlessTask.objects.create( | ||||
|             type=PaperlessTask.TaskType.FILE, | ||||
|             task_id=headers["id"], | ||||
|             status=states.PENDING, | ||||
|             task_file_name=task_file_name, | ||||
|  | ||||
| @ -9,6 +9,7 @@ from tempfile import TemporaryDirectory | ||||
| import tqdm | ||||
| from celery import Task | ||||
| from celery import shared_task | ||||
| from celery import states | ||||
| from django.conf import settings | ||||
| from django.contrib.contenttypes.models import ContentType | ||||
| from django.db import models | ||||
| @ -35,6 +36,7 @@ from documents.models import Correspondent | ||||
| from documents.models import CustomFieldInstance | ||||
| from documents.models import Document | ||||
| from documents.models import DocumentType | ||||
| from documents.models import PaperlessTask | ||||
| from documents.models import StoragePath | ||||
| from documents.models import Tag | ||||
| from documents.models import Workflow | ||||
| @ -74,19 +76,34 @@ def index_reindex(*, progress_bar_disable=False): | ||||
| 
 | ||||
| 
 | ||||
| @shared_task | ||||
| def train_classifier(): | ||||
| def train_classifier(*, scheduled=True): | ||||
|     task = PaperlessTask.objects.create( | ||||
|         type=PaperlessTask.TaskType.SCHEDULED_TASK | ||||
|         if scheduled | ||||
|         else PaperlessTask.TaskType.MANUAL_TASK, | ||||
|         task_id=uuid.uuid4(), | ||||
|         task_name="train_classifier", | ||||
|         status=states.STARTED, | ||||
|         date_created=timezone.now(), | ||||
|         date_started=timezone.now(), | ||||
|     ) | ||||
|     if ( | ||||
|         not Tag.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists() | ||||
|         and not DocumentType.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists() | ||||
|         and not Correspondent.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists() | ||||
|         and not StoragePath.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists() | ||||
|     ): | ||||
|         logger.info("No automatic matching items, not training") | ||||
|         result = "No automatic matching items, not training" | ||||
|         logger.info(result) | ||||
|         # Special case, items were once auto and trained, so remove the model | ||||
|         # and prevent its use again | ||||
|         if settings.MODEL_FILE.exists(): | ||||
|             logger.info(f"Removing {settings.MODEL_FILE} so it won't be used") | ||||
|             settings.MODEL_FILE.unlink() | ||||
|         task.status = states.SUCCESS | ||||
|         task.result = result | ||||
|         task.date_done = timezone.now() | ||||
|         task.save() | ||||
|         return | ||||
| 
 | ||||
|     classifier = load_classifier() | ||||
| @ -100,11 +117,19 @@ def train_classifier(): | ||||
|                 f"Saving updated classifier model to {settings.MODEL_FILE}...", | ||||
|             ) | ||||
|             classifier.save() | ||||
|             task.status = states.SUCCESS | ||||
|             task.result = "Training completed successfully" | ||||
|         else: | ||||
|             logger.debug("Training data unchanged.") | ||||
|             task.status = states.SUCCESS | ||||
|             task.result = "Training data unchanged" | ||||
| 
 | ||||
|         task.save(update_fields=["status", "result"]) | ||||
| 
 | ||||
|     except Exception as e: | ||||
|         logger.warning("Classifier error: " + str(e)) | ||||
|         task.status = states.FAILED | ||||
|         task.result = str(e) | ||||
| 
 | ||||
| 
 | ||||
| @shared_task(bind=True) | ||||
|  | ||||
| @ -103,6 +103,7 @@ from documents.filters import DocumentsOrderingFilter | ||||
| from documents.filters import DocumentTypeFilterSet | ||||
| from documents.filters import ObjectOwnedOrGrantedPermissionsFilter | ||||
| from documents.filters import ObjectOwnedPermissionsFilter | ||||
| from documents.filters import PaperlessTaskFilterSet | ||||
| from documents.filters import ShareLinkFilterSet | ||||
| from documents.filters import StoragePathFilterSet | ||||
| from documents.filters import TagFilterSet | ||||
| @ -2223,7 +2224,12 @@ class RemoteVersionView(GenericAPIView): | ||||
| class TasksViewSet(ReadOnlyModelViewSet): | ||||
|     permission_classes = (IsAuthenticated, PaperlessObjectPermissions) | ||||
|     serializer_class = TasksViewSerializer | ||||
|     filter_backends = (ObjectOwnedOrGrantedPermissionsFilter,) | ||||
|     filter_backends = ( | ||||
|         DjangoFilterBackend, | ||||
|         OrderingFilter, | ||||
|         ObjectOwnedOrGrantedPermissionsFilter, | ||||
|     ) | ||||
|     filterset_class = PaperlessTaskFilterSet | ||||
| 
 | ||||
|     def get_queryset(self): | ||||
|         queryset = ( | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user