diff --git a/.envs/.local/.django b/.envs/.local/.django index e646e46eb..9f738881a 100644 --- a/.envs/.local/.django +++ b/.envs/.local/.django @@ -48,3 +48,7 @@ MINIO_BROWSER_REDIRECT_URL=http://minio:9001 DEFAULT_PROCESSING_SERVICE_NAME=Local Processing Service DEFAULT_PROCESSING_SERVICE_ENDPOINT=http://ml_backend:2000 # DEFAULT_PIPELINES_ENABLED=random,constant # When set to None, all pipelines will be enabled. + +CELERY_BROKER_URL=amqp://rabbituser:rabbitpass@rabbitmq:5672// +RABBITMQ_DEFAULT_USER=rabbituser +RABBITMQ_DEFAULT_PASS=rabbitpass diff --git a/ami/jobs/admin.py b/ami/jobs/admin.py index b5c921502..67236c342 100644 --- a/ami/jobs/admin.py +++ b/ami/jobs/admin.py @@ -4,7 +4,7 @@ from ami.main.admin import AdminBase -from .models import Job, get_job_type_by_inferred_key +from .models import Job, MLTaskRecord, get_job_type_by_inferred_key @admin.register(Job) @@ -54,3 +54,23 @@ def inferred_job_type(self, obj: Job) -> str: "progress", "result", ) + + +@admin.register(MLTaskRecord) +class MLTaskRecordAdmin(AdminBase): + """Admin panel example for ``MLTaskRecord`` model.""" + + list_display = ( + "job", + "task_id", + "task_name", + "status", + ) + + @admin.action() + def kill_task(self, request: HttpRequest, queryset: QuerySet[MLTaskRecord]) -> None: + for ml_task_record in queryset: + ml_task_record.kill_task() + self.message_user(request, f"Killed {queryset.count()} ML task(s).") + + actions = [kill_task] diff --git a/ami/jobs/migrations/0019_mltaskrecord.py b/ami/jobs/migrations/0019_mltaskrecord.py new file mode 100644 index 000000000..6c9741372 --- /dev/null +++ b/ami/jobs/migrations/0019_mltaskrecord.py @@ -0,0 +1,72 @@ +# Generated by Django 4.2.10 on 2025-08-10 22:17 + +import ami.ml.schemas +from django.db import migrations, models +import django.db.models.deletion +import django_pydantic_field.fields + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0060_alter_sourceimagecollection_method"), + ("jobs", "0018_alter_job_job_type_key"), + ] + + operations = [ + migrations.CreateModel( + name="MLTaskRecord", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("task_id", models.CharField(max_length=255)), + ( + "task_name", + models.CharField( + choices=[ + ("process_pipeline_request", "process_pipeline_request"), + ("save_results", "save_results"), + ], + default="process_pipeline_request", + max_length=255, + ), + ), + ( + "status", + models.CharField( + choices=[("STARTED", "STARTED"), ("SUCCESS", "SUCCESS"), ("FAIL", "FAIL")], + default="STARTED", + max_length=255, + ), + ), + ("raw_results", models.JSONField(blank=True, default=dict, null=True)), + ("raw_traceback", models.TextField(blank=True, null=True)), + ( + "pipeline_request", + django_pydantic_field.fields.PydanticSchemaField( + blank=True, config=None, null=True, schema=ami.ml.schemas.PipelineRequest + ), + ), + ( + "pipeline_response", + django_pydantic_field.fields.PydanticSchemaField( + blank=True, config=None, null=True, schema=ami.ml.schemas.PipelineResultsResponse + ), + ), + ("num_captures", models.IntegerField(default=0, help_text="Same as number of source_images")), + ("num_detections", models.IntegerField(default=0)), + ("num_classifications", models.IntegerField(default=0)), + ("subtask_id", models.CharField(blank=True, max_length=255, null=True)), + ( + "job", + models.ForeignKey( + 
on_delete=django.db.models.deletion.CASCADE, related_name="ml_task_records", to="jobs.job" + ), + ), + ("source_images", models.ManyToManyField(related_name="ml_task_records", to="main.sourceimage")), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/ami/jobs/migrations/0020_alter_job_logs_alter_job_progress.py b/ami/jobs/migrations/0020_alter_job_logs_alter_job_progress.py new file mode 100644 index 000000000..755cf3307 --- /dev/null +++ b/ami/jobs/migrations/0020_alter_job_logs_alter_job_progress.py @@ -0,0 +1,28 @@ +# Generated by Django 4.2.10 on 2025-09-04 10:42 + +import ami.jobs.models +from django.db import migrations +import django_pydantic_field.fields + + +class Migration(migrations.Migration): + dependencies = [ + ("jobs", "0019_mltaskrecord"), + ] + + operations = [ + migrations.AlterField( + model_name="job", + name="logs", + field=django_pydantic_field.fields.PydanticSchemaField( + config=None, default=ami.jobs.models.JobLogs, schema=ami.jobs.models.JobLogs + ), + ), + migrations.AlterField( + model_name="job", + name="progress", + field=django_pydantic_field.fields.PydanticSchemaField( + config=None, default=ami.jobs.models.default_job_progress, schema=ami.jobs.models.JobProgress + ), + ), + ] diff --git a/ami/jobs/migrations/0021_remove_mltaskrecord_subtask_id_and_more.py b/ami/jobs/migrations/0021_remove_mltaskrecord_subtask_id_and_more.py new file mode 100644 index 000000000..6344d7b1f --- /dev/null +++ b/ami/jobs/migrations/0021_remove_mltaskrecord_subtask_id_and_more.py @@ -0,0 +1,30 @@ +# Generated by Django 4.2.10 on 2025-10-16 19:31 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("jobs", "0020_alter_job_logs_alter_job_progress"), + ] + + operations = [ + migrations.RemoveField( + model_name="mltaskrecord", + name="subtask_id", + ), + migrations.AlterField( + model_name="mltaskrecord", + name="status", + field=models.CharField( + choices=[("PENDING", "PENDING"), ("STARTED", "STARTED"), ("SUCCESS", "SUCCESS"), ("FAIL", "FAIL")], + default="STARTED", + max_length=255, + ), + ), + migrations.AlterField( + model_name="mltaskrecord", + name="task_id", + field=models.CharField(blank=True, max_length=255, null=True), + ), + ] diff --git a/ami/jobs/migrations/0022_job_last_checked.py b/ami/jobs/migrations/0022_job_last_checked.py new file mode 100644 index 000000000..9e6e608cd --- /dev/null +++ b/ami/jobs/migrations/0022_job_last_checked.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.10 on 2025-10-17 01:27 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("jobs", "0021_remove_mltaskrecord_subtask_id_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="job", + name="last_checked", + field=models.DateTimeField(blank=True, null=True), + ), + ] diff --git a/ami/jobs/migrations/0023_alter_job_last_checked_alter_mltaskrecord_status.py b/ami/jobs/migrations/0023_alter_job_last_checked_alter_mltaskrecord_status.py new file mode 100644 index 000000000..217535d22 --- /dev/null +++ b/ami/jobs/migrations/0023_alter_job_last_checked_alter_mltaskrecord_status.py @@ -0,0 +1,33 @@ +# Generated by Django 4.2.10 on 2025-11-04 11:44 + +import datetime +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("jobs", "0022_job_last_checked"), + ] + + operations = [ + migrations.AlterField( + model_name="job", + name="last_checked", + field=models.DateTimeField(blank=True, 
default=datetime.datetime.now, null=True), + ), + migrations.AlterField( + model_name="mltaskrecord", + name="status", + field=models.CharField( + choices=[ + ("PENDING", "PENDING"), + ("STARTED", "STARTED"), + ("SUCCESS", "SUCCESS"), + ("FAIL", "FAIL"), + ("REVOKED", "REVOKED"), + ], + default="STARTED", + max_length=255, + ), + ), + ] diff --git a/ami/jobs/models.py b/ami/jobs/models.py index ac0078d76..c3f616c4d 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -1,4 +1,5 @@ import datetime +import json import logging import random import time @@ -8,6 +9,7 @@ import pydantic from celery import uuid from celery.result import AsyncResult +from django.core.serializers.json import DjangoJSONEncoder from django.db import models, transaction from django.utils.text import slugify from django_pydantic_field import SchemaField @@ -19,6 +21,9 @@ from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline from ami.ml.post_processing.registry import get_postprocessing_task +from ami.ml.schemas import PipelineRequest, PipelineResultsResponse +from ami.ml.signals import get_worker_name, subscribe_celeryworker_to_pipeline_queues +from ami.ml.tasks import check_ml_job_status from ami.utils.schemas import OrderedEnum logger = logging.getLogger(__name__) @@ -72,7 +77,7 @@ def get_status_label(status: JobState, progress: float) -> str: if status in [JobState.CREATED, JobState.PENDING, JobState.RECEIVED]: return "Waiting to start" elif status in [JobState.STARTED, JobState.RETRY, JobState.SUCCESS]: - return f"{progress:.0%} complete" + return f"{progress: .0%} complete" else: return f"{status.name}" @@ -133,14 +138,14 @@ def get_stage(self, stage_key: str) -> JobProgressStageDetail: for stage in self.stages: if stage.key == stage_key: return stage - raise ValueError(f"Job stage with key '{stage_key}' not found in progress") + raise ValueError(f"Job stage with key '{stage_key}' not in progress") def get_stage_param(self, stage_key: str, param_key: str) -> ConfigurableStageParam: stage = self.get_stage(stage_key) for param in stage.params: if param.key == param_key: return param - raise ValueError(f"Job stage parameter with key '{param_key}' not found in stage '{stage_key}'") + raise ValueError(f"Job stage parameter with key '{param_key}' not in stage '{stage_key}'") def add_stage_param(self, stage_key: str, param_name: str, value: typing.Any = None) -> ConfigurableStageParam: stage = self.get_stage(stage_key) @@ -305,6 +310,13 @@ class JobType: # present_participle: str = "syncing" # past_participle: str = "synced" + @classmethod + def check_inprogress_subtasks(cls, job: "Job") -> bool | None: + """ + Check on the status of inprogress subtasks and update the job progress accordingly. + """ + pass + @classmethod def run(cls, job: "Job"): """ @@ -317,6 +329,355 @@ class MLJob(JobType): name = "ML pipeline" key = "ml" + @classmethod + def check_inprogress_subtasks(cls, job: "Job") -> bool: + """ + Check the status of the MLJob subtasks and update/create MLTaskRecords + based on if the subtasks fail/succeed. + This is the main function that keeps track of the MLJob's state and all of its subtasks. + + Returns True if all subtasks are completed. 
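+ Returns False if any subtasks are still in progress, so the caller can schedule another check.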
+ """ + assert job.pipeline is not None, "Job pipeline is not set" + + inprogress_subtasks = ( + job.ml_task_records.exclude( + status__in=[ + MLSubtaskState.FAIL.name, + MLSubtaskState.SUCCESS.name, + ] + ) + .filter( + created_at__gte=job.started_at, + ) + .all() + ) + if len(inprogress_subtasks) == 0: + # No tasks inprogress, update the job progress + job.logger.info("No inprogress subtasks left.") + cls.update_job_progress(job) + return True + + save_results_task_record = { # if pipeline responses are produced, this task will be saved to the db + "job": job, + "task_id": None, + "status": MLSubtaskState.PENDING.name, # save result tasks are not started immediately + "task_name": MLSubtaskNames.save_results.name, + "num_captures": 0, + "num_detections": 0, + "num_classifications": 0, + } + save_results_to_save = [] # list of tuples (pipeline response, source images) + inprogress_subtasks_to_update = [] + for inprogress_subtask in inprogress_subtasks: + task_name = inprogress_subtask.task_name + task_id = inprogress_subtask.task_id + if not task_id: + assert ( + task_name == MLSubtaskNames.save_results.name + ), "Only save results tasks can have no task_id and be in a PENDING state." + # Ensure no other STARTED save_results tasks + if ( + job.ml_task_records.filter( + status=MLSubtaskState.STARTED.name, + task_name=MLSubtaskNames.save_results.name, + created_at__gte=job.started_at, + ).count() + == 0 + ): + assert ( + inprogress_subtask.pipeline_response is not None + ), "Save results task must have a pipeline response" + # Start the save results task now + save_results_task = job.pipeline.save_results_async( + results=inprogress_subtask.pipeline_response, job_id=job.pk + ) + inprogress_subtask.status = MLSubtaskState.STARTED.name + inprogress_subtask.task_id = save_results_task.id + task_id = save_results_task.id + inprogress_subtask.save() + job.logger.debug(f"Started save results task {inprogress_subtask.task_id}") + else: + job.logger.debug("A save results task is already in progress, will not start another one yet.") + continue + + task = AsyncResult(task_id) + if task.ready(): + inprogress_subtasks_to_update.append(inprogress_subtask) + inprogress_subtask.status = ( + MLSubtaskState.SUCCESS.name if task.successful() else MLSubtaskState.FAIL.name + ) + + if task.traceback: + job.logger.error(f"Subtask {task_name} ({task_id}) failed: {task.traceback}") + inprogress_subtask.status = MLSubtaskState.FAIL.name + inprogress_subtask.raw_traceback = task.traceback + continue + + results_dict = task.result + if task_name == MLSubtaskNames.process_pipeline_request.name: + try: + results = PipelineResultsResponse(**results_dict) + + except Exception as e: + error_msg = ( + f"Subtask {task_name} ({task_id}) failed since it received " + f"an invalid PipelineResultsResponse.\n" + f"Error: {e}\n" + f"Raw result: {results_dict}" + ) + job.logger.error(error_msg) + inprogress_subtask.status = MLSubtaskState.FAIL.name + inprogress_subtask.raw_traceback = error_msg + continue + + if results.errors: + error_detail = results.errors if isinstance(results.errors, str) else f"{results.errors}" + error_msg = ( + f"Subtask {task_name} ({task_id}) failed since the " + f"PipelineResultsResponse contains errors: {error_detail}" + ) + job.logger.error(error_msg) + inprogress_subtask.status = MLSubtaskState.FAIL.name + inprogress_subtask.raw_traceback = error_msg + continue + + num_captures = len(results.source_images) + num_detections = len(results.detections) + num_classifications = len([c for d in 
results.detections for c in d.classifications]) + # Update the process_pipeline_request MLTaskRecord + inprogress_subtask.raw_results = json.loads(json.dumps(results.dict(), cls=DjangoJSONEncoder)) + inprogress_subtask.num_captures = num_captures + inprogress_subtask.num_detections = num_detections + inprogress_subtask.num_classifications = num_classifications + + if results.source_images or results.detections: + save_results_to_save.append((results, inprogress_subtask.source_images.all())) + save_results_task_record["num_captures"] += num_captures + save_results_task_record["num_detections"] += num_detections + save_results_task_record["num_classifications"] += num_classifications + elif task_name == MLSubtaskNames.save_results.name: + pass + else: + raise Exception(f"Unexpected task_name: {task_name}") + + # To avoid long running jobs from taking a long time to update, bulk update every 10 tasks + # Bulk save the updated inprogress subtasks + if len(inprogress_subtasks_to_update) >= 10: + MLTaskRecord.objects.bulk_update( + inprogress_subtasks_to_update, + [ + "status", + "raw_traceback", + "raw_results", + "num_captures", + "num_detections", + "num_classifications", + ], + ) + + cls.update_job_progress(job) + + # Reset the lists + inprogress_subtasks_to_update = [] + + assert job.pipeline is not None, "Job pipeline is not set" + # submit a single save results task + if len(save_results_to_save) > 0: + created_task_record = MLTaskRecord.objects.create(**save_results_task_record) + for _, source_images in save_results_to_save: + created_task_record.source_images.add(*source_images) + pipeline_results = [t[0] for t in save_results_to_save] + combined_pipeline_results = ( + pipeline_results[0].combine_with(pipeline_results[1:]) + if len(pipeline_results) > 1 + else pipeline_results[0] + ) + created_task_record.pipeline_response = combined_pipeline_results + created_task_record.save() + + # Bulk save the remaining items + # Bulk save the updated inprogress subtasks + MLTaskRecord.objects.bulk_update( + inprogress_subtasks_to_update, + [ + "status", + "raw_traceback", + "raw_results", + "num_captures", + "num_detections", + "num_classifications", + ], + ) + + cls.update_job_progress(job) + + inprogress_subtasks = ( + job.ml_task_records.exclude( + status__in=[ + MLSubtaskState.FAIL.name, + MLSubtaskState.SUCCESS.name, + ] + ) + .filter( + created_at__gte=job.started_at, + ) + .all() + ) + total_subtasks = job.ml_task_records.all().count() + if inprogress_subtasks.count() > 0: + job.logger.info( + f"{inprogress_subtasks.count()} inprogress subtasks remaining out of {total_subtasks} total subtasks." + ) + inprogress_task_ids = [task.task_id for task in inprogress_subtasks] + job.logger.debug(f"Subtask ids: {inprogress_task_ids}") + return False + else: + job.logger.info("No inprogress subtasks left.") + return True + + @classmethod + def update_job_progress(cls, job: "Job"): + """ + Using the MLTaskRecords of a related Job, update the job progress. + This function only updates the UI's job status. No new data is created here. 
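+ Stage progress is derived from the MLTaskRecord counts: the process stage compares completed vs. in-progress process_pipeline_request tasks, and the results stage compares completed save_results tasks against the total number of process tasks.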
+ """ + # At any time, we should have all process_pipeline_request in queue + # That is: len(inprogress_process_pipeline) + len(completed_process_pipeline) + # = total process_pipeline_request tasks + inprogress_process_pipeline = job.ml_task_records.filter( + status=MLSubtaskState.STARTED.name, + task_name=MLSubtaskNames.process_pipeline_request.name, + created_at__gte=job.started_at, + ) + completed_process_pipelines = job.ml_task_records.filter( + status__in=[MLSubtaskState.FAIL.name, MLSubtaskState.SUCCESS.name], + task_name=MLSubtaskNames.process_pipeline_request.name, + created_at__gte=job.started_at, + ) + + # Calculate process stage stats + inprogress_process_captures = sum([ml_task.num_captures for ml_task in inprogress_process_pipeline], 0) + completed_process_captures = sum([ml_task.num_captures for ml_task in completed_process_pipelines], 0) + failed_process_captures = sum( + [ + ml_task.num_captures + for ml_task in completed_process_pipelines + if ml_task.status != MLSubtaskState.SUCCESS.name + ], + 0, + ) + + # Update the process stage + if inprogress_process_pipeline.count() > 0: + job.progress.update_stage( + "process", + status=JobState.STARTED, + progress=completed_process_pipelines.count() + / (completed_process_pipelines.count() + inprogress_process_pipeline.count()), + processed=completed_process_captures, + remaining=inprogress_process_captures, + failed=failed_process_captures, + ) + else: + job.progress.update_stage( # @TODO: should we have a failure threshold of 50%? + "process", + status=JobState.FAILURE if failed_process_captures > 0 else JobState.SUCCESS, + progress=1, + processed=completed_process_captures, + remaining=inprogress_process_captures, + failed=failed_process_captures, + ) + + inprogress_save_results = job.ml_task_records.filter( + status__in=[ + MLSubtaskState.STARTED.name, + MLSubtaskState.PENDING.name, + ], + task_name=MLSubtaskNames.save_results.name, + created_at__gte=job.started_at, + ) + completed_save_results = job.ml_task_records.filter( + status__in=[MLSubtaskState.FAIL.name, MLSubtaskState.SUCCESS.name], + task_name=MLSubtaskNames.save_results.name, + created_at__gte=job.started_at, + ) + succeeded_save_results = job.ml_task_records.filter( + status=MLSubtaskState.SUCCESS.name, + task_name=MLSubtaskNames.save_results.name, + created_at__gte=job.started_at, + ) + + # Calculate results stage stats + failed_process_tasks = ( + True + if any([task_record.status != MLSubtaskState.SUCCESS.name for task_record in completed_process_pipelines]) + else False + ) + num_failed_save_tasks = sum( + [1 for ml_task in completed_save_results if ml_task.status != MLSubtaskState.SUCCESS.name], + 0, + ) + failed_save_tasks = num_failed_save_tasks > 0 + any_failed_tasks = failed_process_tasks or failed_save_tasks + + # only include captures/detections/classifications which we successfully saved + total_results_captures = sum([ml_task.num_captures for ml_task in succeeded_save_results], 0) + total_results_detections = sum([ml_task.num_detections for ml_task in succeeded_save_results], 0) + total_results_classifications = sum([ml_task.num_classifications for ml_task in succeeded_save_results], 0) + + # Update the results stage + if inprogress_save_results.count() > 0 or inprogress_process_pipeline.count() > 0: + job.progress.update_stage( + "results", + status=JobState.STARTED, + # Save results tasks may not have been submitted, or they may be in progress + # progress denominator is based on the total number of process_pipeline_request tasks + # 1:1 ratio 
between save_results and process_pipeline_request tasks + progress=completed_save_results.count() + / (completed_process_pipelines.count() + inprogress_process_pipeline.count()), + captures=total_results_captures, + detections=total_results_detections, + classifications=total_results_classifications, + failed=num_failed_save_tasks, + ) + else: + job.progress.update_stage( + "results", + status=JobState.FAILURE if failed_save_tasks else JobState.SUCCESS, + progress=1, + captures=total_results_captures, + detections=total_results_detections, + classifications=total_results_classifications, + failed=num_failed_save_tasks, + ) + + # The ML job is completed, log general job stats + job.update_status(JobState.FAILURE if any_failed_tasks else JobState.SUCCESS, save=False) + + if any_failed_tasks: + failed_save_task_ids = [ + completed_save_result.task_id + for completed_save_result in completed_save_results + if completed_save_result.status == MLSubtaskState.FAIL.name + ] + job.logger.error( + f"Failed save result task ids = {failed_save_task_ids}" + ) # TODO: more for dev debugging? + + failed_process_task_ids = [ + completed_process_pipeline.task_id + for completed_process_pipeline in completed_process_pipelines + if completed_process_pipeline.status == MLSubtaskState.FAIL.name + ] + job.logger.error( + f"Failed process task ids = {failed_process_task_ids}" + ) # TODO: more for dev debugging? + + job.finished_at = datetime.datetime.now() + + job.save() + @classmethod def run(cls, job: "Job"): """ @@ -327,10 +688,6 @@ def run(cls, job: "Job"): job.finished_at = None job.save() - # Keep track of sub-tasks for saving results, pair with batch number - save_tasks: list[tuple[int, AsyncResult]] = [] - save_tasks_completed: list[tuple[int, AsyncResult]] = [] - if job.delay: update_interval_seconds = 2 last_update = time.time() @@ -401,116 +758,57 @@ def run(cls, job: "Job"): # End image collection stage job.save() - total_captures = 0 - total_detections = 0 - total_classifications = 0 - - config = job.pipeline.get_config(project_id=job.project.pk) - chunk_size = config.get("request_source_image_batch_size", 1) - chunks = [images[i : i + chunk_size] for i in range(0, image_count, chunk_size)] # noqa - request_failed_images = [] - job.logger.info(f"Processing {image_count} images in {len(chunks)} batches of up to {chunk_size}") - - for i, chunk in enumerate(chunks): - request_sent = time.time() - job.logger.info(f"Processing image batch {i+1} of {len(chunks)}") - try: - results = job.pipeline.process_images( - images=chunk, - job_id=job.pk, - project_id=job.project.pk, - ) - job.logger.info(f"Processed image batch {i+1} in {time.time() - request_sent:.2f}s") - except Exception as e: - # Log error about image batch and continue - job.logger.error(f"Failed to process image batch {i+1}: {e}") - request_failed_images.extend([img.pk for img in chunk]) - else: - total_captures += len(results.source_images) - total_detections += len(results.detections) - total_classifications += len([c for d in results.detections for c in d.classifications]) - - if results.source_images or results.detections: - # @TODO add callback to report errors while saving results marking the job as failed - save_results_task: AsyncResult = job.pipeline.save_results_async(results=results, job_id=job.pk) - save_tasks.append((i + 1, save_results_task)) - job.logger.info(f"Saving results for batch {i+1} in sub-task {save_results_task.id}") + job.logger.info(f"Processing {image_count} images with pipeline {job.pipeline.slug}") + request_sent 
= time.time() + try: + # Ensures queues we subscribe to are always up to date + logger.info("Subscribe to all pipeline queues prior to processing...") + worker_name = get_worker_name() + subscribe_celeryworker_to_pipeline_queues(worker_name) - job.progress.update_stage( - "process", - status=JobState.STARTED, - progress=(i + 1) / len(chunks), - processed=min((i + 1) * chunk_size, image_count), - failed=len(request_failed_images), - remaining=max(image_count - ((i + 1) * chunk_size), 0), + job.pipeline.schedule_process_images( + images=images, + job_id=job.pk, + project_id=job.project.pk, ) - - # count the completed, successful, and failed save_tasks: - save_tasks_completed = [t for t in save_tasks if t[1].ready()] - failed_save_tasks = [t for t in save_tasks_completed if not t[1].successful()] - - for failed_batch_num, failed_task in failed_save_tasks: - # First log all errors and update the job status. Then raise exception if any failed. - job.logger.error(f"Failed to save results from batch {failed_batch_num} (sub-task {failed_task.id})") - - job.progress.update_stage( - "results", - status=JobState.FAILURE if failed_save_tasks else JobState.STARTED, - progress=len(save_tasks_completed) / len(chunks), - captures=total_captures, - detections=total_detections, - classifications=total_classifications, + job.logger.info( + "Submitted batch image processing tasks " + f"(task_name={MLSubtaskNames.process_pipeline_request.name}) in " + f"{time.time() - request_sent: .2f}s" ) - job.save() - # Stop processing if any save tasks have failed - # Otherwise, calculate the percent of images that have failed to save - throw_on_save_error = True - for failed_batch_num, failed_task in failed_save_tasks: - if throw_on_save_error: - failed_task.maybe_throw() - - if image_count: - percent_successful = 1 - len(request_failed_images) / image_count if image_count else 0 - job.logger.info(f"Processed {percent_successful:.0%} of images successfully.") - - # Check all Celery sub-tasks if they have completed saving results - save_tasks_remaining = set(save_tasks) - set(save_tasks_completed) - job.logger.info( - f"Checking the status of {len(save_tasks_remaining)} remaining sub-tasks that are still saving results." - ) - for batch_num, sub_task in save_tasks: - if not sub_task.ready(): - job.logger.info(f"Waiting for batch {batch_num} to finish saving results (sub-task {sub_task.id})") - # @TODO this is not recommended! Use a group or chain. But we need to refactor. - # https://docs.celeryq.dev/en/latest/userguide/tasks.html#avoid-launching-synchronous-subtasks - sub_task.wait(disable_sync_subtasks=False, timeout=60) - if not sub_task.successful(): - error: Exception = sub_task.result - job.logger.error(f"Failed to save results from batch {batch_num}! 
(sub-task {sub_task.id}): {error}") - sub_task.maybe_throw() - - job.logger.info(f"All tasks completed for job {job.pk}") - - FAILURE_THRESHOLD = 0.5 - if image_count and (percent_successful < FAILURE_THRESHOLD): - job.progress.update_stage("process", status=JobState.FAILURE) + except Exception as e: + job.logger.error(f"Failed to submit all images: {e}") + job.update_status(JobState.FAILURE) job.save() - raise Exception(f"Failed to process more than {int(FAILURE_THRESHOLD * 100)}% of images") + else: + subtasks = job.ml_task_records.filter(created_at__gte=job.started_at) + if subtasks.count() == 0: + # No tasks were scheduled, mark the job as done + job.logger.info("No subtasks were scheduled, ending the job.") + job.progress.update_stage( + "process", + status=JobState.SUCCESS, + progress=1, + ) + job.progress.update_stage( + "results", + status=JobState.SUCCESS, + progress=1, + ) + job.update_status(JobState.SUCCESS, save=False) + job.finished_at = datetime.datetime.now() + job.save() + else: + job.logger.info( + f"Continue processing the remaining {subtasks.count()} process image request subtasks." + ) + from django.db import transaction - job.progress.update_stage( - "process", - status=JobState.SUCCESS, - progress=1, - ) - job.progress.update_stage( - "results", - status=JobState.SUCCESS, - progress=1, - ) - job.update_status(JobState.SUCCESS, save=False) - job.finished_at = datetime.datetime.now() - job.save() + transaction.on_commit(lambda: check_ml_job_status.apply_async([job.pk])) + finally: + # TODO: clean up? + pass class DataStorageSyncJob(JobType): @@ -716,6 +1014,71 @@ def get_job_type_by_inferred_key(job: "Job") -> type[JobType] | None: return job_type +class MLSubtaskNames(str, OrderedEnum): + process_pipeline_request = "process_pipeline_request" + save_results = "save_results" + + +class MLSubtaskState(str, OrderedEnum): + PENDING = "PENDING" + STARTED = "STARTED" + SUCCESS = "SUCCESS" + FAIL = "FAIL" + REVOKED = "REVOKED" + + +class MLTaskRecord(BaseModel): + """ + A model to track the history of MLJob subtasks. + Allows us to track the history of source images in a job. 
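+ Each record corresponds to one Celery subtask (process_pipeline_request or save_results) and stores the pipeline request/response needed to recreate it, along with the capture/detection/classification counts used to report job progress.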
+ """ + + job = models.ForeignKey("Job", on_delete=models.CASCADE, related_name="ml_task_records") + task_id = models.CharField(max_length=255, null=True, blank=True) + source_images = models.ManyToManyField(SourceImage, related_name="ml_task_records") + task_name = models.CharField( + max_length=255, + default=MLSubtaskNames.process_pipeline_request.name, + choices=MLSubtaskNames.choices(), + ) + status = models.CharField( + max_length=255, + default=MLSubtaskState.STARTED.name, + choices=MLSubtaskState.choices(), + ) + + raw_results = models.JSONField(null=True, blank=True, default=dict) + raw_traceback = models.TextField(null=True, blank=True) + + # recreate a process_pipeline_request task + pipeline_request = SchemaField(PipelineRequest, null=True, blank=True) + # recreate a save_results task + pipeline_response = SchemaField(PipelineResultsResponse, null=True, blank=True) + + # track the progress of the job + num_captures = models.IntegerField(default=0, help_text="Same as number of source_images") + num_detections = models.IntegerField(default=0) + num_classifications = models.IntegerField(default=0) + + def __str__(self): + return f"MLTaskRecord(job={self.job.pk}, task_id={self.task_id}, task_name={self.task_name})" + + def clean(self): + if self.status == MLSubtaskState.PENDING.name and self.task_name != MLSubtaskNames.save_results.name: + raise ValueError(f"{self.task_name} tasks cannot have a PENDING status.") + + def kill_task(self): + """ + Kill the celery task associated with this MLTaskRecord. + """ + from config.celery_app import app as celery_app + + if self.task_id: + celery_app.control.revoke(self.task_id, terminate=True, signal="SIGTERM") + self.status = MLSubtaskState.REVOKED.name + self.save(update_fields=["status"]) + + class Job(BaseModel): """A job to be run by the scheduler""" @@ -724,6 +1087,7 @@ class Job(BaseModel): name = models.CharField(max_length=255) queue = models.CharField(max_length=255, default="default") + last_checked = models.DateTimeField(null=True, blank=True, default=datetime.datetime.now) scheduled_at = models.DateTimeField(null=True, blank=True) started_at = models.DateTimeField(null=True, blank=True) finished_at = models.DateTimeField(null=True, blank=True) @@ -784,6 +1148,9 @@ class Job(BaseModel): related_name="jobs", ) + # For type hints + ml_task_records: models.QuerySet["MLTaskRecord"] + def __str__(self) -> str: return f'#{self.pk} "{self.name}" ({self.status})' @@ -845,6 +1212,15 @@ def setup(self, save=True): if save: self.save() + def check_inprogress_subtasks(self) -> bool | None: + """ + Check the status of the sub-tasks and update the job progress accordingly. + + Returns True if all subtasks are completed, False if any are still in progress. + """ + job_type = self.job_type() + return job_type.check_inprogress_subtasks(job=self) + def run(self): """ Run the job. 
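For orientation, a minimal usage sketch of the subtask-tracking API introduced above. This is illustrative only and assumes a configured Django environment and an existing ML job; the primary key used here is hypothetical.

# Illustrative sketch: assumes Django is configured and job #1 is an ML job.
from ami.jobs.models import Job, MLJob, MLSubtaskState

job = Job.objects.get(pk=1, job_type_key=MLJob.key)

# Reconcile the Celery subtasks with their MLTaskRecord rows and update progress.
all_done = job.check_inprogress_subtasks()

if not all_done:
    # Inspect whatever is still outstanding.
    pending = job.ml_task_records.exclude(
        status__in=[MLSubtaskState.SUCCESS.name, MLSubtaskState.FAIL.name]
    )
    for record in pending:
        print(record.task_name, record.task_id, record.status)
        # A stuck record can be revoked, same as the admin kill_task action:
        # record.kill_task()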
diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index b12271178..48884780e 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -33,7 +33,7 @@ def run_job(self, job_id: int) -> None: @task_postrun.connect(sender=run_job) @task_prerun.connect(sender=run_job) def update_job_status(sender, task_id, task, *args, **kwargs): - from ami.jobs.models import Job + from ami.jobs.models import Job, MLJob job_id = task.request.kwargs["job_id"] if job_id is None: @@ -48,9 +48,12 @@ def update_job_status(sender, task_id, task, *args, **kwargs): logger.error(f"No job found for task {task_id} or job_id {job_id}") return - task = AsyncResult(task_id) # I'm not sure if this is reliable - job.update_status(task.status, save=False) - job.save() + # NOTE: After calling run_job, only update the status if the job + # is not an ML job (this job should handle it's own status updates) + if job.job_type_key != MLJob.key: + task = AsyncResult(task_id) # I'm not sure if this is reliable + job.update_status(task.status, save=False) + job.save() @task_failure.connect(sender=run_job, retry=False) diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 7ab066506..eb2ff30fe 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -1,13 +1,26 @@ # from rich import print +import datetime import logging +import time -from django.test import TestCase +from django.db import connection + +# import pytest +from django.test import TestCase, TransactionTestCase from guardian.shortcuts import assign_perm from rest_framework import status from rest_framework.test import APIRequestFactory, APITestCase from ami.base.serializers import reverse_with_params -from ami.jobs.models import Job, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob +from ami.jobs.models import ( + Job, + JobProgress, + JobState, + MLJob, + MLSubtaskNames, + MLSubtaskState, + SourceImageCollectionPopulateJob, +) from ami.main.models import Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline from ami.users.models import User @@ -201,3 +214,290 @@ def test_cancel_job(self): # This cannot be tested until we have a way to cancel jobs # and a way to run async tasks in tests. 
pass + + +class TestMLJobBatchProcessing(TransactionTestCase): + def setUp(self): + self.project = Project.objects.first() # get the original test project + assert self.project + self.source_image_collection = self.project.sourceimage_collections.get(name="Test Source Image Collection") + self.pipeline = Pipeline.objects.get(slug="constant") + + # remove existing detections from the source image collection + for image in self.source_image_collection.images.all(): + image.detections.all().delete() + image.save() + + def test_run_ml_job(self): + """Test running a batch processing job end-to-end.""" + from celery.result import AsyncResult + + from ami.ml.tasks import check_ml_job_status + from config import celery_app + + logger.info( + f"Starting test_batch_processing_job using collection " + f"{self.source_image_collection} which contains " + f"{self.source_image_collection.images.count()} images " + f"and project {self.project}" + ) + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test batch processing", + delay=1, + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + + self.assertEqual(job.progress.stages[0].key, "delay") + self.assertEqual(job.progress.stages[0].progress, 0) + self.assertEqual(job.progress.stages[0].status, JobState.CREATED) + self.assertEqual(job.progress.stages[1].key, "collect") + self.assertEqual(job.progress.stages[1].progress, 0) + self.assertEqual(job.progress.stages[1].status, JobState.CREATED) + self.assertEqual(job.progress.stages[2].key, "process") + self.assertEqual(job.progress.stages[2].progress, 0) + self.assertEqual(job.progress.stages[2].status, JobState.CREATED) + self.assertEqual(job.progress.stages[3].key, "results") + self.assertEqual(job.progress.stages[3].progress, 0) + self.assertEqual(job.progress.stages[3].status, JobState.CREATED) + + self.assertEqual(job.status, JobState.CREATED.value) + self.assertEqual(job.progress.summary.progress, 0) + self.assertEqual(job.progress.summary.status, JobState.CREATED) + + inspector = celery_app.control.inspect() + # Ensure workers are available + self.assertEqual(len(inspector.active()), 1, "No celery workers are running.") + + # -- Begin helper functions for checking celery tasks and worker stats --# + + def check_all_celery_tasks(): + active = inspector.active() + scheduled = inspector.scheduled() + reserved = inspector.reserved() + active_tasks = sum(len(v) for v in active.values()) if active else 0 + scheduled_tasks = sum(len(v) for v in scheduled.values()) if scheduled else 0 + reserved_tasks = sum(len(v) for v in reserved.values()) if reserved else 0 + total_tasks = active_tasks + scheduled_tasks + reserved_tasks + # Log the number of tasks for debugging + logger.info( + f"Celery tasks - Active: {active_tasks}, Scheduled: {scheduled_tasks}, Reserved: {reserved_tasks}, " + f"Total: {total_tasks}" + ) + return total_tasks + + def check_celery_worker_stats(): + i = celery_app.control.inspect() + results = i.stats() + if not results: + logger.warning("No celery results available.") + return False + for worker, stats in results.items(): + if stats.get("total", 0) == 0: + logger.warning(f"No tasks have been processed by worker {worker}.") + return False + else: + logger.info(f"Worker {worker} stats: {stats}") + return True + + def get_ml_job_subtask_details(task_name, job): + from ami.jobs.models import MLSubtaskNames + + assert task_name in [name.value for name in MLSubtaskNames] + logger.info(f"Checking ML task details for task: 
{task_name}") + + task_ids = job.ml_task_records.filter(task_name=task_name).values_list("task_id", flat=True) + + details = {} + for task_id in task_ids: + try: + async_result = AsyncResult(task_id, app=celery_app) + task_info = { + "id": async_result.id, + "status": async_result.status, + "successful": async_result.successful() if async_result.ready() else None, + "result": async_result.result if async_result.ready() else None, + "traceback": async_result.traceback if async_result.failed() else None, + "date_done": str(getattr(async_result, "date_done", None)), + "name": async_result.name, + } + details[task_id] = task_info + logger.info(f"Task {task_id} details: {task_info}") + except Exception as e: + logger.error(f"Error fetching details for task {task_id}: {e}") + details[task_id] = {"error": str(e)} + + return details + + # -- End helper functions --# + + job.run() + connection.commit() + job.refresh_from_db() + + start_time = time.time() + timeout = 30 # seconds + elapsed_time = 0 + + while elapsed_time < timeout: + if job.status == JobState.SUCCESS.value or job.status == JobState.FAILURE.value: + break + elapsed_time = time.time() - start_time + logger.info(f"Waiting for job to complete... elapsed time: {elapsed_time:.2f} seconds") + + check_all_celery_tasks() + check_celery_worker_stats() + + get_ml_job_subtask_details("process_pipeline_request", job) + get_ml_job_subtask_details("save_results", job) + + # Update the job status/progress within the test to get the latest db values + check_ml_job_status(job.pk) + MLJob.update_job_progress(job) + + # Check all subtasks were successful + ml_subtask_records = job.ml_task_records.all() + self.assertTrue(all(subtask.status == MLSubtaskState.SUCCESS.value for subtask in ml_subtask_records)) + + # Ensure a unique process_pipeline_request task was created per image + total_images = self.source_image_collection.images.count() + process_tasks = ml_subtask_records.filter(task_name=MLSubtaskNames.process_pipeline_request.value) + self.assertEqual(process_tasks.count(), total_images) + task_ids = process_tasks.values_list("task_id", flat=True).distinct() + self.assertEqual(task_ids.count(), total_images) + + # Each source image should be part of 2 tasks: process_pipeline_request and save_results + for image in self.source_image_collection.images.all(): + tasks_for_image = ml_subtask_records.filter(source_images=image) + self.assertEqual( + tasks_for_image.count(), + 2, + f"Image {image.id} is part of {tasks_for_image.count()} tasks instead of 2", + ) + + task_names = set(tasks_for_image.values_list("task_name", flat=True)) + self.assertEqual( + task_names, + {MLSubtaskNames.process_pipeline_request.value, MLSubtaskNames.save_results.value}, + f"Image {image.id} has tasks {task_names} instead of the expected ones", + ) + + logger.info( + f"Every source image was part of 2 tasks " + f"(process_pipeline_request and save_results). 
" + f"Job {job.pk} completed in {elapsed_time:.2f} seconds " + f"with status {job.status}" + ) + + # Check all the progress stages are marked as SUCCESS + self.assertEqual(job.status, JobState.SUCCESS.value) + self.assertEqual(job.progress.stages[0].key, "delay") + self.assertEqual(job.progress.stages[0].progress, 1) + self.assertEqual(job.progress.stages[0].status, JobState.SUCCESS) + self.assertEqual(job.progress.stages[1].key, "collect") + self.assertEqual(job.progress.stages[1].progress, 1) + self.assertEqual(job.progress.stages[1].status, JobState.SUCCESS) + self.assertEqual(job.progress.stages[2].key, "process") + self.assertEqual(job.progress.stages[2].progress, 1) + self.assertEqual(job.progress.stages[2].status, JobState.SUCCESS) + self.assertEqual(job.progress.stages[3].key, "results") + self.assertEqual(job.progress.stages[3].progress, 1) + self.assertEqual(job.progress.stages[3].status, JobState.SUCCESS) + + self.assertEqual(job.status, JobState.SUCCESS.value) + self.assertEqual(job.progress.summary.progress, 1) + self.assertEqual(job.progress.summary.status, JobState.SUCCESS) + + +class TestStaleMLJob(TransactionTestCase): + def setUp(self): + self.project = Project.objects.first() # get the original test project + assert self.project + self.source_image_collection = self.project.sourceimage_collections.get(name="Test Source Image Collection") + self.pipeline = Pipeline.objects.get(slug="constant") + + # remove existing detections from the source image collection + for image in self.source_image_collection.images.all(): + image.detections.all().delete() + image.save() + + def test_kill_dangling_ml_job(self): + """Test killing a dangling ML job.""" + from ami.ml.tasks import check_dangling_ml_jobs + from config import celery_app + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test dangling job", + delay=0, + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + + job.run() + connection.commit() + job.refresh_from_db() + + # Simulate last_checked being older than 24 hours + job.last_checked = datetime.datetime.now() - datetime.timedelta(hours=25) + job.save(update_fields=["last_checked"]) + + # Run the dangling job checker + check_dangling_ml_jobs() + + # Refresh job from DB + job.refresh_from_db() + + # Make sure no tasks are still in progress + for ml_task_record in job.ml_task_records.all(): + self.assertEqual(ml_task_record.status, MLSubtaskState.REVOKED.value) + + # Also check celery queue to make sure all tasks have been revoked + task_id = ml_task_record.task_id + + inspector = celery_app.control.inspect() + active = inspector.active() or {} + reserved = inspector.reserved() or {} + + not_running = all( + task_id not in [t["id"] for w in active.values() for t in w] for w in active.values() + ) and all(task_id not in [t["id"] for w in reserved.values() for t in w] for w in reserved.values()) + + self.assertTrue(not_running) + + self.assertEqual(job.status, JobState.REVOKED.value) + + def test_kill_task_prevents_execution(self): + from ami.jobs.models import Job, MLSubtaskNames, MLTaskRecord + from ami.ml.models.pipeline import process_pipeline_request + from config import celery_app + + logger.info("Testing that killing a task prevents its execution.") + result = process_pipeline_request.apply_async(args=[{}, 1], countdown=5) + logger.info(f"Scheduled task with id {result.id} to run in 5 seconds.") + task_id = result.id + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + 
name="Test killing job tasks", + delay=0, + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + + ml_task_record = MLTaskRecord.objects.create( + job=job, task_name=MLSubtaskNames.process_pipeline_request.value, task_id=task_id + ) + logger.info(f"Killing task {task_id} immediately.") + ml_task_record.kill_task() + + async_result = celery_app.AsyncResult(task_id) + time.sleep(5) # the REVOKED STATUS isn't visible until the task is actually run after the delay + + self.assertIn(async_result.state, ["REVOKED"]) + self.assertEqual(ml_task_record.status, "REVOKED") diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 5fffdb6fd..001947f76 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -14,7 +14,7 @@ from ami.utils.fields import url_boolean_param from ami.utils.requests import project_id_doc_param -from .models import Job, JobState +from .models import Job, JobState, MLJob from .serializers import JobListSerializer, JobSerializer logger = logging.getLogger(__name__) @@ -156,3 +156,21 @@ def get_queryset(self) -> QuerySet: @extend_schema(parameters=[project_id_doc_param]) def list(self, request, *args, **kwargs): return super().list(request, *args, **kwargs) + + @action(detail=True, methods=["post"], name="check-inprogress-subtasks") + def check_inprogress_subtasks(self, request, pk=None): + """ + Check in-progress subtasks for a job. + """ + # @TODO: add additional stats here? i.e. time fo each task, progress stats + job: Job = self.get_object() + assert job.job_type_key == MLJob.key, f"{job} is not an ML job." + has_inprogress_tasks = job.check_inprogress_subtasks() + if has_inprogress_tasks: + # Schedule task to update the job status + from django.db import transaction + + from ami.ml.tasks import check_ml_job_status + + transaction.on_commit(lambda: check_ml_job_status.apply_async((job.pk,))) + return Response({"inprogress_subtasks": has_inprogress_tasks}) diff --git a/ami/main/models.py b/ami/main/models.py index f672c2832..dc38d837c 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -904,7 +904,9 @@ def save(self, update_calculated_fields=True, regroup_async=True, *args, **kwarg if deployment_events_need_update(self): logger.info(f"Deployment {self} has events that need to be regrouped") if regroup_async: - ami.tasks.regroup_events.delay(self.pk) + transaction.on_commit( + lambda: ami.tasks.regroup_events.delay(self.pk) + ) # enqueue the task only after the DB commit completes else: group_images_into_events(self) self.update_calculated_fields(save=True) diff --git a/ami/ml/apps.py b/ami/ml/apps.py index 6b6752c1c..31e208229 100644 --- a/ami/ml/apps.py +++ b/ami/ml/apps.py @@ -5,3 +5,6 @@ class MLConfig(AppConfig): name = "ami.ml" verbose_name = _("Machine Learning") + + def ready(self): + import ami.ml.signals # noqa: F401 diff --git a/ami/ml/migrations/0026_check_dangling_ml_jobs_celery_beat.py b/ami/ml/migrations/0026_check_dangling_ml_jobs_celery_beat.py new file mode 100644 index 000000000..1bed47247 --- /dev/null +++ b/ami/ml/migrations/0026_check_dangling_ml_jobs_celery_beat.py @@ -0,0 +1,33 @@ +from django.db import migrations +from django_celery_beat.models import PeriodicTask, CrontabSchedule + + +def create_periodic_task(apps, schema_editor): + crontab_schedule, _ = CrontabSchedule.objects.get_or_create( + minute="*/5", # Every 5 minutes + hour="*", # Every hour + day_of_week="*", # Every day + day_of_month="*", # Every day of month + month_of_year="*", # Every month + ) + + PeriodicTask.objects.get_or_create( + 
name="celery.check_dangling_ml_jobs", + task="ami.ml.tasks.check_dangling_ml_jobs", + crontab=crontab_schedule, + ) + + +def delete_periodic_task(apps, schema_editor): + # Delete the task if rolling back + PeriodicTask.objects.filter(name="celery.check_dangling_ml_jobs").delete() + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0025_alter_algorithm_task_type"), + ] + + operations = [ + migrations.RunPython(create_periodic_task, delete_periodic_task), + ] diff --git a/ami/ml/models/__init__.py b/ami/ml/models/__init__.py index 5000c7f53..1e4202e92 100644 --- a/ami/ml/models/__init__.py +++ b/ami/ml/models/__init__.py @@ -1,5 +1,5 @@ from ami.ml.models.algorithm import Algorithm, AlgorithmCategoryMap -from ami.ml.models.pipeline import Pipeline +from ami.ml.models.pipeline import Pipeline, PipelineSaveResults from ami.ml.models.processing_service import ProcessingService from ami.ml.models.project_pipeline_config import ProjectPipelineConfig @@ -7,6 +7,7 @@ "Algorithm", "AlgorithmCategoryMap", "Pipeline", + "PipelineSaveResults", "ProcessingService", "ProjectPipelineConfig", ] diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py index 6486e26d9..5db3b9e91 100644 --- a/ami/ml/models/pipeline.py +++ b/ami/ml/models/pipeline.py @@ -15,7 +15,7 @@ from urllib.parse import urljoin import requests -from django.db import models +from django.db import models, transaction from django.utils.text import slugify from django.utils.timezone import now from django_pydantic_field import SchemaField @@ -160,15 +160,45 @@ def collect_images( return images +@celery_app.task(name="process_pipeline_request") +def process_pipeline_request(pipeline_request: dict, project_id: int): + # TODO: instead of dict can we use pipeline request object? + """ + This is the primary function for processing images on the antenna side. + Workers have a function of the same name which will run their own inference/processing logic. + On the antenna side, we use external servers via an API to process images. + """ + request_data = PipelineRequest(**pipeline_request) + source_image_requests = request_data.source_images + source_images = [] + for req in source_image_requests: + source_images.append(SourceImage.objects.get(pk=req.id)) + + results = process_images( + pipeline=Pipeline.objects.get(slug=request_data.pipeline), + images=source_images, + process_sync=True, + project_id=project_id, + ) + assert results is not None, "process_sync=True should return a valid PipelineResultsResponse, not None." + return results.dict() + + def process_images( pipeline: Pipeline, - endpoint_url: str, images: typing.Iterable[SourceImage], + project_id: int, job_id: int | None = None, - project_id: int | None = None, -) -> PipelineResultsResponse: + process_sync: bool = False, +) -> PipelineResultsResponse | None: """ - Process images using ML pipeline API. + Process images. + + If process_sync is True, immediately process the images via requests to the /process endpoint + and return a PipelineResultsResponse. + + Otherwise, submit async processing tasks and return None. + This is only applicable to MLJobs which check the status of these tasks. 
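+ In the async path, one process_pipeline_request Celery task is submitted per batch (batch size comes from the pipeline config) and an MLTaskRecord is created for each task so the owning MLJob can poll its status.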
""" job = None task_logger = logger @@ -179,13 +209,12 @@ def process_images( job = Job.objects.get(pk=job_id) task_logger = job.logger - if project_id: - project = Project.objects.get(pk=project_id) - else: - task_logger.warning(f"Pipeline {pipeline} is not associated with a project") - project = None + # Pipelines must be associated with a project in order to select a processing service + # A processing service is required to send requests to the /process endpoint + project = Project.objects.get(pk=project_id) + task_logger.info(f"Using project: {project}") - pipeline_config = pipeline.get_config(project_id=project_id) + pipeline_config = pipeline.get_config(project_id=project.pk) task_logger.info(f"Using pipeline config: {pipeline_config}") prefiltered_images = list(images) @@ -195,13 +224,16 @@ def process_images( task_logger.info(f"Ignoring {len(prefiltered_images) - len(images)} images that have already been processed") if not images: - task_logger.info("No images to process") - return PipelineResultsResponse( - pipeline=pipeline.slug, - source_images=[], - detections=[], - total_time=0, - ) + task_logger.info("No images to process.") + if process_sync: + return PipelineResultsResponse( + pipeline=pipeline.slug, + source_images=[], + detections=[], + total_time=0, + ) + else: + return None task_logger.info(f"Sending {len(images)} images to Pipeline {pipeline}") urls = [source_image.public_url() for source_image in images if source_image.public_url()] @@ -231,6 +263,120 @@ def process_images( else: task_logger.info("Reprocessing of existing detections is disabled, sending images without detections.") + task_logger.info(f"Found {len(detection_requests)} existing detections.") + + if not process_sync: + assert job_id is not None, "job_id is required to process images using async tasks." 
+ handle_async_process_images( + pipeline.slug, + source_image_requests, + images, + pipeline_config, + detection_requests, + project_id, + job_id, + task_logger, + ) + return + else: + results = handle_sync_process_images( + pipeline, source_image_requests, pipeline_config, detection_requests, job_id, task_logger, project_id, job + ) + return results + + +def handle_async_process_images( + pipeline: str, + source_image_requests: list[SourceImageRequest], + source_images: list[SourceImage], + pipeline_config: PipelineRequestConfigParameters, + detection_requests: list[DetectionRequest], + project_id: int, + job_id: int, + task_logger: logging.Logger = logger, +): + """Handle asynchronous processing by submitting tasks to the appropriate pipeline queue.""" + batch_size = pipeline_config.get("batch_size", 1) + + # Group source images into batches + source_image_request_batches = [] + source_image_batches = [] + + for i in range(0, len(source_image_requests), batch_size): + request_batch = [] + image_batch = [] + for j in range(batch_size): + if i + j >= len(source_image_requests): + break + request_batch.append(source_image_requests[i + j]) + image_batch.append(source_images[i + j]) + source_image_request_batches.append(request_batch) + source_image_batches.append(image_batch) + + # Group the detections into batches based on its source image + for idx, source_images_batch in enumerate(source_image_request_batches): + detections_batch = [ + detection + for detection in detection_requests + if detection.source_image.id in [img.id for img in source_images_batch] + ] + prediction_request = PipelineRequest( + pipeline=pipeline, + source_images=source_images_batch, + detections=detections_batch, + config=pipeline_config, + ) + + task_id = str(uuid.uuid4()) + # use transaction on commit to ensure source images and other project details are finished saving + transaction.on_commit( + lambda: process_pipeline_request.apply_async( + args=[prediction_request.dict(), project_id], + task_id=task_id, + # TODO: make ml-pipeline an environment variable (i.e. PIPELINE_QUEUE_PREFIX)? 
+ queue=f"ml-pipeline-{pipeline}", + # all pipelines have their own queue beginning with "ml-pipeline-" + # the antenna celeryworker should subscribe to all pipeline queues + ) + ) + + if job_id: + from ami.jobs.models import Job, MLTaskRecord + + job = Job.objects.get(pk=job_id) + # Create a new MLTaskRecord for this task + ml_task_record = MLTaskRecord.objects.create( + job=job, + task_id=task_id, + task_name="process_pipeline_request", + pipeline_request=prediction_request, + num_captures=len(source_image_batches[idx]), + ) + ml_task_record.source_images.set(source_image_batches[idx]) + ml_task_record.save() + task_logger.debug(f"Created MLTaskRecord {ml_task_record} for task {task_id}") + else: + task_logger.warning("No job ID provided, MLTaskRecord will not be created.") + + task_logger.info(f"Submitted {len(source_image_request_batches)} batch image processing task(s).") + + +def handle_sync_process_images( + pipeline: Pipeline, + source_image_requests: list[SourceImageRequest], + pipeline_config: PipelineRequestConfigParameters, + detection_requests: list[DetectionRequest], + job_id: int | None, + task_logger: logging.Logger, + project_id: int, + job: Job | None, +) -> PipelineResultsResponse: + """Handle synchronous processing by sending HTTP requests to the processing service.""" + processing_service = pipeline.choose_processing_service_for_pipeline(job_id, pipeline.name, project_id) + if not processing_service.endpoint_url: + raise ValueError(f"No endpoint URL configured for this pipeline's processing service ({processing_service})") + endpoint_url = urljoin(processing_service.endpoint_url, "/process") + request_data = PipelineRequest( pipeline=pipeline.slug, source_images=source_image_requests, @@ -1021,6 +1167,9 @@ def get_config(self, project_id: int | None = None) -> PipelineRequestConfigPara ) except self.project_pipeline_configs.model.DoesNotExist as e: logger.warning(f"No project-pipeline config for Pipeline {self} " f"and Project #{project_id}: {e}") + else: + logger.warning("No project_id. No pipeline config is used. 
Using default empty config instead.") + return config def collect_images( @@ -1089,20 +1238,32 @@ def choose_processing_service_for_pipeline( return processing_service_lowest_latency - def process_images(self, images: typing.Iterable[SourceImage], project_id: int, job_id: int | None = None): - processing_service = self.choose_processing_service_for_pipeline(job_id, self.name, project_id) - - if not processing_service.endpoint_url: - raise PipelineNotConfigured( - f"No endpoint URL configured for this pipeline's processing service ({processing_service})" - ) + def process_images( + self, + images: typing.Iterable[SourceImage], + project_id: int, + job_id: int | None = None, + ): + return process_images( + pipeline=self, + images=images, + job_id=job_id, + project_id=project_id, + process_sync=True, + ) + def schedule_process_images( + self, + images: typing.Iterable[SourceImage], + project_id: int, + job_id: int | None = None, + ): return process_images( - endpoint_url=urljoin(processing_service.endpoint_url, "/process"), pipeline=self, images=images, job_id=job_id, project_id=project_id, + process_sync=False, ) def save_results(self, results: PipelineResultsResponse, job_id: int | None = None): diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py index 49e5efd8f..81ec2dcae 100644 --- a/ami/ml/schemas.py +++ b/ami/ml/schemas.py @@ -1,6 +1,7 @@ import datetime import logging import typing +from statistics import mean import pydantic @@ -213,6 +214,41 @@ class PipelineResultsResponse(pydantic.BaseModel): detections: list[DetectionResponse] errors: list | str | None = None + def combine_with(self, others: list["PipelineResultsResponse"]) -> "PipelineResultsResponse": + """ + Combine this PipelineResultsResponse with others. + Returns a new combined PipelineResultsResponse. + """ + if not others: + return self + + all_responses = [self] + others + + pipelines = {r.pipeline for r in all_responses} + if len(pipelines) != 1: + raise AssertionError(f"Inconsistent pipelines: {pipelines}") + + algorithms_list = [r.algorithms for r in all_responses] + if not all(a == algorithms_list[0] for a in algorithms_list): + raise AssertionError("Algorithm configurations differ among responses.") + + errors_found = [r.errors for r in all_responses if r.errors] + if errors_found: + raise AssertionError(f"Some responses contain errors: {errors_found}") + + combined_source_images = [img for r in all_responses for img in r.source_images] + combined_detections = [det for r in all_responses for det in r.detections] + avg_total_time = mean(r.total_time for r in all_responses) + + return PipelineResultsResponse( + pipeline=self.pipeline, + algorithms=self.algorithms, + total_time=avg_total_time, + source_images=combined_source_images, + detections=combined_detections, + errors=None, + ) + class PipelineStageParam(pydantic.BaseModel): """A configurable parameter of a stage of a pipeline.""" diff --git a/ami/ml/signals.py b/ami/ml/signals.py new file mode 100644 index 000000000..fc03a5b4e --- /dev/null +++ b/ami/ml/signals.py @@ -0,0 +1,104 @@ +import logging + +from celery.signals import worker_ready +from django.db.models.signals import post_delete, post_save +from django.dispatch import receiver + +from ami.ml.models.pipeline import Pipeline +from config.celery_app import app as celery_app + +logger = logging.getLogger(__name__) + +ANTENNA_CELERY_WORKER_NAME = "antenna_celeryworker" + + +def get_worker_name(): + """ + Find the antenna celery worker's node name. 
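+ Returns None if no active worker whose name contains ANTENNA_CELERY_WORKER_NAME can be found.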
+    This is not always possible, especially if called too early during startup.
+    """
+    try:
+        inspector = celery_app.control.inspect()
+        active_workers = inspector.active()
+        if active_workers:  # TODO: currently only works if there is one worker
+            # NOTE: all antenna celery workers should have ANTENNA_CELERY_WORKER_NAME
+            # in their name instead of the default "celery"
+            return next((worker for worker in active_workers.keys() if ANTENNA_CELERY_WORKER_NAME in worker), None)
+    except Exception as e:
+        logger.warning(f"Could not find antenna celery worker name: {e}")
+
+
+@worker_ready.connect
+def subscribe_celeryworker_to_pipeline_queues(sender, **kwargs) -> bool:
+    """
+    When the antenna worker is fully up, subscribe it to all existing pipeline queues.
+
+    Returns True if the subscriptions were attempted for all pipeline queues, False otherwise.
+    """
+    if isinstance(sender, str):
+        worker_name = sender
+    elif sender is None:
+        worker_name = get_worker_name()
+    else:
+        worker_name = sender.hostname  # e.g. "antenna_celeryworker@<hostname>"
+    assert worker_name, "Could not determine worker name; cannot subscribe to pipeline queues."
+    pipelines = Pipeline.objects.values_list("slug", flat=True)
+
+    if not worker_name.startswith(f"{ANTENNA_CELERY_WORKER_NAME}@"):
+        logger.warning(
+            f"Worker name '{worker_name}' does not match expected pattern "
+            f"'{ANTENNA_CELERY_WORKER_NAME}@'. Cannot subscribe to pipeline queues.",
+        )
+        return False
+
+    if not pipelines:
+        # TODO: Kinda hacky. Is there a way to unify the Django and Celery logs
+        # to more easily see which queues the worker is subscribed to?
+        raise ValueError("No pipelines found; cannot subscribe to any queues.")
+
+    for slug in pipelines:
+        queue_name = f"ml-pipeline-{slug}"
+        try:
+            celery_app.control.add_consumer(queue_name, destination=[worker_name])
+            logger.info(f"Subscribed worker '{worker_name}' to queue '{queue_name}'")
+        except Exception as e:
+            logger.exception(f"Failed to subscribe '{worker_name}' to queue '{queue_name}': {e}")
+
+    return True
+
+
+@receiver(post_save, sender=Pipeline)
+def pipeline_created(sender, instance, created, **kwargs):
+    if not created:
+        return
+
+    try:
+        queue_name = f"ml-pipeline-{instance.slug}"
+        worker_name = get_worker_name()
+
+        assert worker_name, (
+            "Could not determine worker name; cannot subscribe to new queue "
+            f"{queue_name}. This might be an expected error if the worker hasn't "
+            "started or isn't ready to accept connections."
+ ) + + celery_app.control.add_consumer(queue_name, destination=[worker_name]) + logger.info(f"Queue '{queue_name}' successfully added to worker '{worker_name}'") + except Exception as e: + logger.exception(f"Failed to add queue '{queue_name}' to worker '{worker_name}': {e}.") + + +@receiver(post_delete, sender=Pipeline) +def pipeline_deleted(sender, instance, **kwargs): + queue_name = f"ml-pipeline-{instance.slug}" + logger.info(f"Unsubscribing queue '{queue_name}' from the celeryworker...") + worker_name = get_worker_name() + + try: + if not worker_name: + raise ValueError("Could not determine worker name; cannot unsubscribe from queue.") + + celery_app.control.cancel_consumer(queue_name, destination=[worker_name]) + logger.info(f"Queue '{queue_name}' successfully unsubscribed from worker '{worker_name}'") + except Exception as e: + logger.exception(f"Failed to unsubscribe queue '{queue_name}' for worker '{worker_name}': {e}") diff --git a/ami/ml/tasks.py b/ami/ml/tasks.py index 47a8ef857..c76e5af78 100644 --- a/ami/ml/tasks.py +++ b/ami/ml/tasks.py @@ -1,3 +1,4 @@ +import datetime import logging import time @@ -8,8 +9,9 @@ logger = logging.getLogger(__name__) +# @TODO: Deprecate this? is this still needed? @celery_app.task(soft_time_limit=default_soft_time_limit, time_limit=default_time_limit) -def process_source_images_async(pipeline_choice: str, endpoint_url: str, image_ids: list[int], job_id: int | None): +def process_source_images_async(pipeline_choice: str, image_ids: list[int], job_id: int | None): from ami.jobs.models import Job from ami.main.models import SourceImage from ami.ml.models.pipeline import Pipeline, process_images, save_results @@ -24,13 +26,10 @@ def process_source_images_async(pipeline_choice: str, endpoint_url: str, image_i images = SourceImage.objects.filter(pk__in=image_ids) pipeline = Pipeline.objects.get(slug=pipeline_choice) + project = pipeline.projects.first() + assert project, f"Pipeline {pipeline} must be associated with a project." - results = process_images( - pipeline=pipeline, - endpoint_url=endpoint_url, - images=images, - job_id=job_id, - ) + results = process_images(pipeline=pipeline, images=images, job_id=job_id, project_id=project.pk) try: save_results(results=results, job_id=job_id) @@ -55,7 +54,7 @@ def create_detection_images(source_image_ids: list[int]): logger.error(f"Error creating detection images for SourceImage {source_image.pk}: {str(e)}") total_time = time.time() - start_time - logger.info(f"Created detection images for {len(source_image_ids)} capture(s) in {total_time:.2f} seconds") + logger.info(f"Created detection images for {len(source_image_ids)} capture(s) in {total_time: .2f} seconds") @celery_app.task(soft_time_limit=default_soft_time_limit, time_limit=default_time_limit) @@ -106,3 +105,84 @@ def check_processing_services_online(): except Exception as e: logger.error(f"Error checking service {service}: {e}") continue + + +@celery_app.task() # TODO: add a time limit? stay active for as long as the ML job will take +def check_ml_job_status(ml_job_id: int): + """ + Check the status of a specific ML job's inprogress subtasks and update its status accordingly. + """ + from ami.jobs.models import Job, JobState, MLJob + + job = Job.objects.get(pk=ml_job_id) + assert job.job_type_key == MLJob.key, f"{ml_job_id} is not an ML job." 
+
+    try:
+        logger.info(f"Checking status for job {job}.")
+        logger.info(f"Job subtasks are: {job.ml_task_records.all()}.")
+        jobs_complete = job.check_inprogress_subtasks()
+        logger.info(f"Successfully checked status for job {job}.")
+        job.last_checked = datetime.datetime.now()
+        job.save(update_fields=["last_checked"])
+
+        if jobs_complete:
+            logger.info(f"ML Job {ml_job_id} is complete.")
+            job.logger.info(f"ML Job {ml_job_id} is complete.")
+        else:
+            from django.db import transaction
+
+            logger.info(f"ML Job {ml_job_id} still in progress. Checking again for completed tasks.")
+            transaction.on_commit(lambda: check_ml_job_status.apply_async([ml_job_id], countdown=5))
+    except Job.DoesNotExist:
+        raise ValueError(f"Job with ID {ml_job_id} does not exist.")
+    except Exception as e:
+        error_msg = f"Error checking status for job with ID {ml_job_id}: {e}"
+        job.logger.error(error_msg)
+        job.update_status(JobState.FAILURE)
+        job.finished_at = datetime.datetime.now()
+        job.save()
+
+        # Remove remaining tasks from the queue
+        for ml_task_record in job.ml_task_records.all():
+            ml_task_record.kill_task()
+
+        raise Exception(error_msg)
+
+
+@celery_app.task(soft_time_limit=600, time_limit=800)
+def check_dangling_ml_jobs():
+    """
+    An in-progress ML job is considered dangling if its last_checked time
+    was never set or is older than 24 hours.
+    """
+    import datetime
+
+    from ami.jobs.models import Job, JobState, MLJob
+
+    inprogress_jobs = Job.objects.filter(job_type_key=MLJob.key, status=JobState.STARTED.name)
+    logger.info(f"Found {inprogress_jobs.count()} in-progress ML jobs to check for dangling tasks.")
+
+    for job in inprogress_jobs:
+        last_checked = job.last_checked
+        if not last_checked:
+            logger.warning(f"Job {job.pk} has no last_checked time. Marking as dangling.")
+            seconds_since_checked = float("inf")
+        else:
+            seconds_since_checked = (datetime.datetime.now() - last_checked).total_seconds()
+        if last_checked is None or seconds_since_checked > 24 * 60 * 60:  # 24 hours
+            logger.warning(
+                f"Job {job.pk} appears to be dangling: last checked at {last_checked}, "
+                f"{seconds_since_checked} seconds ago. Marking as revoked."
+            )
+            job.logger.warning(
+                f"Job {job.pk} appears to be dangling: last checked at {last_checked}, "
+                f"{seconds_since_checked} seconds ago. Marking as revoked."
+            )
+            job.update_status(JobState.REVOKED)
+            job.finished_at = datetime.datetime.now()
+            job.save()
+
+            for ml_task_record in job.ml_task_records.all():
+                ml_task_record.kill_task()
+        else:
+            logger.info(f"Job {job.pk} is active. Last checked at {last_checked}.")
diff --git a/ami/tests/fixtures/main.py b/ami/tests/fixtures/main.py
index 398a27605..b47dccbf4 100644
--- a/ami/tests/fixtures/main.py
+++ b/ami/tests/fixtures/main.py
@@ -128,6 +128,7 @@ def setup_test_project(reuse=True) -> tuple[Project, Deployment]:
     deployment = Deployment.objects.filter(project=project).filter(name__contains=short_id).latest("created_at")
     assert deployment, f"No deployment found for project {project}. Recreate the project."
+ return project, deployment diff --git a/compose/local/django/Dockerfile b/compose/local/django/Dockerfile index 0e778f82b..29d222a25 100644 --- a/compose/local/django/Dockerfile +++ b/compose/local/django/Dockerfile @@ -62,10 +62,6 @@ RUN sed -i 's/\r$//g' /start RUN chmod +x /start -COPY ./compose/local/django/celery/worker/start /start-celeryworker -RUN sed -i 's/\r$//g' /start-celeryworker -RUN chmod +x /start-celeryworker - COPY ./compose/local/django/celery/beat/start /start-celerybeat RUN sed -i 's/\r$//g' /start-celerybeat RUN chmod +x /start-celerybeat diff --git a/compose/local/django/celery/worker/start b/compose/local/django/celery/worker/start index a3482f44e..c8eb02b13 100755 --- a/compose/local/django/celery/worker/start +++ b/compose/local/django/celery/worker/start @@ -18,10 +18,13 @@ set -o nounset MAX_TASKS_PER_CHILD=100 MAX_MEMORY_PER_CHILD=1048576 # 1 GiB in KB +# start the worker with antenna_celeryworker to ensure it's discoverable by ami.ml.signals.get_worker_name +python manage.py migrate + # Launch VS Code debug server if DEBUGGER environment variable is set to 1 # Note that auto reloading is disabled when debugging, manual restart required for code changes. if [ "${DEBUGGER:-0}" = "1" ]; then - exec python -Xfrozen_modules=off -m debugpy --listen 0.0.0.0:5679 -m celery -A config.celery_app worker --queues=antenna -l INFO --max-tasks-per-child=$MAX_TASKS_PER_CHILD --max-memory-per-child=$MAX_MEMORY_PER_CHILD + exec python -Xfrozen_modules=off -m debugpy --listen 0.0.0.0:5679 -m celery -A config.celery_app worker --queues=antenna -n antenna_celeryworker@%h -l INFO --max-tasks-per-child=$MAX_TASKS_PER_CHILD --max-memory-per-child=$MAX_MEMORY_PER_CHILD else - exec watchfiles --filter python celery.__main__.main --args '-A config.celery_app worker --queues=antenna -l INFO --max-tasks-per-child='$MAX_TASKS_PER_CHILD' --max-memory-per-child='$MAX_MEMORY_PER_CHILD + exec watchfiles --filter python celery.__main__.main --args '-A config.celery_app worker --queues=antenna -n antenna_celeryworker@%h -l INFO --max-tasks-per-child='$MAX_TASKS_PER_CHILD' --max-memory-per-child='$MAX_MEMORY_PER_CHILD fi diff --git a/config/settings/base.py b/config/settings/base.py index 03124d41a..17ceeef85 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -99,7 +99,7 @@ "ami.users", "ami.main", "ami.jobs", - "ami.ml", + "ami.ml.apps.MLConfig", # Use the custom config instead of "ami.ml", "ami.labelstudio", "ami.exports", ] @@ -334,6 +334,7 @@ CELERY_WORKER_SEND_TASK_EVENTS = True # https://docs.celeryq.dev/en/stable/userguide/configuration.html#std-setting-task_send_sent_event CELERY_TASK_SEND_SENT_EVENT = True +CELERY_TASK_DEFAULT_QUEUE = "antenna" # Health checking and retries if using Redis as results backend # https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis diff --git a/config/settings/local.py b/config/settings/local.py index c2f58afa0..0de65cc7c 100644 --- a/config/settings/local.py +++ b/config/settings/local.py @@ -88,5 +88,7 @@ # https://docs.celeryq.dev/en/stable/userguide/configuration.html#task-eager-propagates CELERY_TASK_EAGER_PROPAGATES = True +CELERY_TASK_DEFAULT_QUEUE = "antenna" + # Your stuff... 
# ------------------------------------------------------------------------------ diff --git a/docker-compose.ci.yml b/docker-compose.ci.yml index 8e93b684d..0f9b724c8 100644 --- a/docker-compose.ci.yml +++ b/docker-compose.ci.yml @@ -17,11 +17,18 @@ services: extra_hosts: - "host.docker.internal:host-gateway" depends_on: - - postgres - - redis - - minio-init - - ml_backend - - rabbitmq + postgres: + condition: service_started + redis: + condition: service_started + minio-init: + condition: service_started + rabbitmq: + condition: service_started + ml_backend: + condition: service_started + celeryworker: # required to subscribe the worker to the pipelines in the db + condition: service_healthy env_file: - ./.envs/.ci/.django - ./.envs/.ci/.postgres @@ -43,13 +50,27 @@ services: <<: *django depends_on: - rabbitmq - command: /start-celeryworker + # start the worker with antenna_celeryworker to ensure it's discoverable by ami.ml.signals.get_worker_name + command: + - sh + - -c + - | + python manage.py migrate && + python -m celery -A config.celery_app worker --queues=antenna -n antenna_celeryworker@%h -l INFO + healthcheck: + # make sure DATABASE_URL is inside the ./.envs/.ci/.postgres + test: ["CMD-SHELL", "celery -A config.celery_app inspect ping -d antenna_celeryworker@$(hostname) | grep -q pong"] + interval: 10s + timeout: 50s + retries: 5 + start_period: 10s rabbitmq: image: rabbitmq:3-management env_file: - ./.envs/.ci/.django + minio: image: minio/minio:RELEASE.2024-11-07T00-52-20Z command: minio server --console-address ":9001" /data @@ -79,6 +100,8 @@ services: context: ./processing_services/minimal volumes: - ./processing_services/minimal/:/app + depends_on: + - rabbitmq networks: default: aliases: diff --git a/docker-compose.yml b/docker-compose.yml index e2ad3a100..e8951df8f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -19,11 +19,18 @@ services: extra_hosts: - "host.docker.internal:host-gateway" depends_on: - - postgres - - redis - - minio-init - - ml_backend - - rabbitmq + postgres: + condition: service_started + redis: + condition: service_started + minio-init: + condition: service_started + ml_backend: + condition: service_started + rabbitmq: + condition: service_started + celeryworker: # required to subscribe the worker to the pipelines in the db + condition: service_healthy volumes: - .:/app:z env_file: @@ -83,15 +90,32 @@ services: redis: image: redis:6 container_name: ami_local_redis + networks: + - antenna_network + ports: + - "6379:6379" # expose redis port for setting celery task locks celeryworker: <<: *django image: ami_local_celeryworker scale: 1 ports: [] - command: /start-celeryworker + # start the worker with antenna_celeryworker to ensure it's discoverable by ami.ml.signals.get_worker_name + command: + - sh + - -c + - | + python manage.py migrate && + python -m debugpy --listen 0.0.0.0:5678 -m celery -A config.celery_app worker --queues=antenna -n antenna_celeryworker@%h -l INFO depends_on: - rabbitmq + healthcheck: + # make sure DATABASE_URL is inside the ./.envs/.local/.postgres + test: ["CMD-SHELL", "celery -A config.celery_app inspect ping -d antenna_celeryworker@$(hostname) | grep -q pong"] + interval: 10s + timeout: 50s + retries: 5 + start_period: 10s celerybeat: <<: *django @@ -149,8 +173,10 @@ services: env_file: - ./.envs/.local/.django depends_on: - - minio - - minio-proxy + minio: + condition: service_healthy + minio-proxy: + condition: service_started volumes: - ./compose/local/minio/init.sh:/etc/minio/init.sh entrypoint: 
/etc/minio/init.sh @@ -167,6 +193,7 @@ services: aliases: - processing_service + networks: antenna_network: name: antenna_network
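
Note: the new check_dangling_ml_jobs housekeeping task only has an effect if something runs it periodically, and this diff does not include a beat schedule entry. A minimal sketch of how it could be registered in config/settings/base.py follows; the entry name, 15-minute cadence, and queue option are illustrative assumptions, and only the task's dotted path comes from ami/ml/tasks.py above.

    # Hypothetical Celery beat entry for the dangling-job check (not part of this diff).
    from celery.schedules import crontab

    CELERY_BEAT_SCHEDULE = {
        "check-dangling-ml-jobs": {
            "task": "ami.ml.tasks.check_dangling_ml_jobs",
            "schedule": crontab(minute="*/15"),  # assumed cadence
            "options": {"queue": "antenna"},  # matches CELERY_TASK_DEFAULT_QUEUE
        },
    }

Similarly, the new PipelineResultsResponse.combine_with helper collates per-batch responses into a single result. A usage sketch, where batch_responses is a hypothetical list of responses collected from the per-batch ML tasks:

    from ami.ml.schemas import PipelineResultsResponse

    def merge_batch_results(batch_responses: list[PipelineResultsResponse]) -> PipelineResultsResponse:
        # Raises AssertionError if the batches used different pipelines or algorithm
        # configurations, or if any batch reported errors; total_time becomes the
        # mean across batches, while source_images and detections are concatenated.
        first, *rest = batch_responses
        return first.combine_with(rest)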