43 commits
ae02d2e
fix: Job serialization overhead
carlosgjs Sep 30, 2025
24a15af
syntax
carlosgjs Sep 30, 2025
0da97a6
fix syntax
carlosgjs Sep 30, 2025
2db7d66
Simplify diagram
carlosgjs Oct 1, 2025
8a714cd
Add RabbitMQ
carlosgjs Oct 3, 2025
700f594
WIP: Use NATS JetStream for queuing
carlosgjs Oct 8, 2025
3b42e08
Merge branch 'main' into carlosg/jobio
carlosgjs Oct 14, 2025
8ea5d7d
Saving of results
carlosgjs Oct 17, 2025
61fc2c5
Update progress
carlosgjs Oct 17, 2025
9af597c
Clean up and refactor task state mgmt
carlosgjs Oct 24, 2025
7ff8865
fix async use
carlosgjs Oct 24, 2025
0fbe899
Merge branch 'main' into carlosg/jobio
carlosgjs Oct 24, 2025
7899fc5
Fix circular dependency, jobset query by pipeline slug
carlosgjs Oct 24, 2025
d9f8ffd
GH review comments
carlosgjs Oct 24, 2025
edad552
Add feature flag, rename "job" to "task"
carlosgjs Oct 29, 2025
d254867
Code reorganization
carlosgjs Oct 31, 2025
1cc890e
Resolve circular deps
carlosgjs Oct 31, 2025
84ee5a2
Update ami/jobs/models.py
carlosgjs Oct 31, 2025
09fee92
cleanup
carlosgjs Oct 31, 2025
4480b0d
Consistent progress updates, single image job command
carlosgjs Nov 4, 2025
3032709
Fix typo
carlosgjs Nov 4, 2025
3e7ef3b
Merge branch 'main' into carlosg/jobio
carlosgjs Nov 5, 2025
04be994
Merge branch 'main' into carlosg/jobio
carlosgjs Nov 18, 2025
a8b94e3
Remove unnecessary file
carlosgjs Nov 18, 2025
1fc20b5
Merge branch 'main' into carlosg/jobio
carlosgjs Nov 21, 2025
0a5c89e
Remove diagram, fix flakes
carlosgjs Nov 21, 2025
344f883
Use async_to_sync
carlosgjs Nov 21, 2025
df7eaa3
CR feedback
carlosgjs Nov 21, 2025
0391642
clean up
carlosgjs Nov 21, 2025
4ae27b0
more cleanup
carlosgjs Nov 21, 2025
4f50b3d
Apply suggestions from code review
carlosgjs Nov 21, 2025
a8fc79a
Remove old comments
carlosgjs Nov 21, 2025
4efdf07
Fix processing error cases
carlosgjs Nov 21, 2025
f221a1a
updates
carlosgjs Nov 21, 2025
1a9b80a
Merge branch 'main' into carlosg/jobio
carlosgjs Dec 9, 2025
3657fd2
Fix merge bugs, back to working state
carlosgjs Dec 9, 2025
2483592
Use PipelineProcessingTask for the queue, other fixes
carlosgjs Dec 10, 2025
0ae9674
Update tests
carlosgjs Dec 10, 2025
3c034a9
General cleanup
carlosgjs Dec 10, 2025
e9d2a1c
Add nats to CI and prod
carlosgjs Dec 10, 2025
3d198d0
Unit tests for new classes
carlosgjs Dec 10, 2025
f9a1226
Add nats to staging, don't retry save results task
carlosgjs Dec 19, 2025
3a73329
fix formatting
carlosgjs Dec 19, 2025
4 changes: 4 additions & 0 deletions .envs/.ci/.django
@@ -26,3 +26,7 @@ CELERY_BROKER_URL=amqp://rabbituser:rabbitpass@rabbitmq:5672/
CELERY_RESULT_BACKEND=rpc:// # Use RabbitMQ for results backend
RABBITMQ_DEFAULT_USER=rabbituser
RABBITMQ_DEFAULT_PASS=rabbitpass

# NATS
# ------------------------------------------------------------------------------
NATS_URL=nats://nats:4222
3 changes: 3 additions & 0 deletions .envs/.local/.django
@@ -12,6 +12,9 @@ DJANGO_SUPERUSER_PASSWORD=localadmin
# Redis
REDIS_URL=redis://redis:6379/0

# NATS
NATS_URL=nats://nats:4222

# Celery / Flower
CELERY_FLOWER_USER=QSocnxapfMvzLqJXSsXtnEZqRkBtsmKT
CELERY_FLOWER_PASSWORD=BEQgmCtgyrFieKNoGTsux9YIye0I7P5Q7vEgfJD2C4jxmtHDetFaE2jhS7K7rxaf
4 changes: 4 additions & 0 deletions .envs/.production/.django-example
@@ -65,3 +65,7 @@ WEB_CONCURRENCY=4
DEFAULT_PROCESSING_SERVICE_NAME="AMI Data Companion"
DEFAULT_PROCESSING_SERVICE_ENDPOINT=https://ml.antenna.insectai.org/
DEFAULT_PIPELINES_ENABLED=global_moths_2024,quebec_vermont_moths_2023,panama_moths_2023,uk_denmark_moths_2023

# NATS
# ------------------------------------------------------------------------------
NATS_URL=nats://nats:4222
1 change: 1 addition & 0 deletions README.md
@@ -71,6 +71,7 @@ docker compose -f processing_services/example/docker-compose.yml up -d
- Django admin: http://localhost:8000/admin/
- OpenAPI / Swagger documentation: http://localhost:8000/api/v2/docs/
- Minio UI: http://minio:9001, Minio service: http://minio:9000
- NATS dashboard: https://natsdashboard.com/ (add `localhost` as the server address)

NOTE: If one of these services is not working properly, it could be because another process is already using the port. You can check for this with `lsof -i :<PORT_NUMBER>`.
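
For a quick connectivity check against the NATS service above, here is a minimal sketch using the `nats-py` client and the `NATS_URL` value from the env files (the client library choice is an assumption; the backend itself goes through its own `TaskQueueManager` wrapper):

```python
# Verify the local NATS server is reachable (assumes `pip install nats-py`).
import asyncio
import os

import nats


async def main() -> None:
    nc = await nats.connect(os.environ.get("NATS_URL", "nats://localhost:4222"))
    print(f"Connected to NATS at {nc.connected_url.netloc}")
    await nc.close()


asyncio.run(main())
```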

28 changes: 21 additions & 7 deletions ami/jobs/models.py
@@ -322,15 +322,13 @@ def run(cls, job: "Job"):
"""
Procedure for an ML pipeline as a job.
"""
from ami.ml.orchestration.jobs import queue_images_to_nats

job.update_status(JobState.STARTED)
job.started_at = datetime.datetime.now()
job.finished_at = None
job.save()

# Keep track of sub-tasks for saving results, pair with batch number
save_tasks: list[tuple[int, AsyncResult]] = []
save_tasks_completed: list[tuple[int, AsyncResult]] = []

if job.delay:
update_interval_seconds = 2
last_update = time.time()
@@ -365,7 +363,7 @@ def run(cls, job: "Job"):
progress=0,
)

images = list(
images: list[SourceImage] = list(
# @TODO return generator plus image count
# @TODO pass to celery group chain?
job.pipeline.collect_images(
@@ -389,8 +387,6 @@ def run(cls, job: "Job"):
images = images[: job.limit]
image_count = len(images)
job.progress.add_stage_param("collect", "Limit", image_count)
else:
image_count = source_image_count

job.progress.update_stage(
"collect",
@@ -401,6 +397,24 @@
# End image collection stage
job.save()

if job.project.feature_flags.async_pipeline_workers:
queued = queue_images_to_nats(job, images)
if not queued:
job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk)
job.progress.update_stage("collect", status=JobState.FAILURE)
job.update_status(JobState.FAILURE)
job.finished_at = datetime.datetime.now()
job.save()
return
else:
cls.process_images(job, images)
Comment on lines +400 to +410

⚠️ Potential issue | 🟠 Major

Async ML path leaves overall job lifecycle undefined (status/finished_at, Celery status mismatch)

When async_pipeline_workers is enabled, MLJob.run queues NATS tasks and returns immediately. The surrounding run_job Celery task then completes and update_job_status will set job.status to the Celery task status (SUCCESS), even though:

  • No images may yet have been processed by workers.
  • finished_at is never set for the async path.
  • The only subsequent updates come from process_pipeline_result, which adjust per-stage progress only (via _update_job_progress) and never touch job.status or finished_at.

Net effect: jobs can show SUCCESS while process/results stages are still <100%, and even once stages reach 100% there is no authoritative completion timestamp or explicit terminal status driven by the async pipeline.

Consider tightening this by:

  • Having the async pipeline be the source of truth for completion, e.g. in _update_job_progress when stage == "results" and progress_percentage >= 1.0, set job.status to JobState.SUCCESS, set finished_at, and (optionally) trigger NATS/Redis cleanup.
  • Optionally, for jobs with async_pipeline_workers=True, avoid overwriting job.status in update_job_status based solely on the run_job Celery task’s status, or treat that status as “queueing succeeded” only.

This will make the async path match the synchronous process_images semantics and avoid confusing “SUCCESS, 0% complete” states.

🤖 Prompt for AI Agents
In ami/jobs/models.py around lines 400-410, the async pipeline path only queues
work and returns, leaving job.status and finished_at unset and allowing the
Celery run_job task to mark the job SUCCESS prematurely; update the logic so
that for async_pipeline_workers you do not overwrite job.status based solely on
the run_job Celery task (treat that status as "queued" or leave unchanged), and
move authoritative completion handling into the async progress handler: in
_update_job_progress, when stage == "results" and progress_percentage >= 1.0 set
job.status = JobState.SUCCESS, set job.finished_at = now(), save the job (and
optionally perform NATS/Redis cleanup), and ensure any queued-path error
handling still sets FAILURE and finished_at as currently done.
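
A minimal sketch of the second suggestion, assuming `update_job_status` is the `task_postrun` handler for `run_job` and that jobs can be looked up by Celery task id (both assumptions about the existing signal wiring):

```python
from celery.signals import task_postrun


@task_postrun.connect(sender=run_job)
def update_job_status(sender, task_id, task, state=None, **kwargs):
    from ami.jobs.models import Job  # avoid circular import

    job = Job.objects.filter(task_id=task_id).first()  # lookup field is an assumption
    if job is None:
        return
    if job.project.feature_flags.async_pipeline_workers:
        # run_job only queued work to NATS, so Celery SUCCESS means
        # "queueing succeeded". Leave the terminal status and finished_at
        # to the async results path (_update_job_progress).
        return
    job.update_status(state)
    job.save()
```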


@classmethod
def process_images(cls, job, images):
image_count = len(images)
# Keep track of sub-tasks for saving results, pair with batch number
save_tasks: list[tuple[int, AsyncResult]] = []
save_tasks_completed: list[tuple[int, AsyncResult]] = []
total_captures = 0
total_detections = 0
total_classifications = 0
Expand Down
160 changes: 160 additions & 0 deletions ami/jobs/tasks.py
@@ -1,7 +1,16 @@
import datetime
import functools
import logging
import time
from collections.abc import Callable

from asgiref.sync import async_to_sync
from celery.signals import task_failure, task_postrun, task_prerun
from django.db import transaction

from ami.ml.orchestration.nats_queue import TaskQueueManager
from ami.ml.orchestration.task_state import TaskStateManager
from ami.ml.schemas import PipelineResultsResponse
from ami.tasks import default_soft_time_limit, default_time_limit
from config import celery_app

@@ -29,6 +38,132 @@ def run_job(self, job_id: int) -> None:
job.logger.info(f"Finished job {job}")


@celery_app.task(
bind=True,
max_retries=0, # don't retry since we already have retry logic in the NATS queue
soft_time_limit=300, # 5 minutes
time_limit=360, # 6 minutes
)
def process_pipeline_result(self, job_id: int, result_data: dict, reply_subject: str) -> None:
"""
Process a single pipeline result asynchronously.

This task:
1. Deserializes the pipeline result
2. Saves it to the database
3. Updates progress by removing processed image IDs from Redis
4. Acknowledges the task via NATS

Args:
job_id: The job ID
result_data: Dictionary containing the pipeline result
reply_subject: NATS reply subject for acknowledgment
"""
from ami.jobs.models import Job # avoid circular import

_, t = log_time()
error = result_data.get("error")
pipeline_result = None
if not error:
pipeline_result = PipelineResultsResponse(**result_data)
processed_image_ids = {str(img.id) for img in pipeline_result.source_images}
else:
image_id = result_data.get("image_id")
processed_image_ids = {str(image_id)} if image_id else set()
logger.error(f"Pipeline returned error for job {job_id}, image {image_id}: {error}")

state_manager = TaskStateManager(job_id)

progress_info = state_manager.update_state(processed_image_ids, stage="process", request_id=self.request.id)
if not progress_info:
logger.warning(
f"Another task is already processing results for job {job_id}. "
f"Retrying task {self.request.id} in 5 seconds..."
)
raise self.retry(countdown=5, max_retries=10)

try:
_update_job_progress(job_id, "process", progress_info.percentage)

_, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage progress to {progress_info.percentage*100}%")
job = Job.objects.get(pk=job_id)
job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}")
job.logger.info(
f" Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed "
f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {len(processed_image_ids)} just "
"processed"
)
except Job.DoesNotExist:
# ack and don't raise, so the message isn't retried, since the job doesn't exist
logger.error(f"Job {job_id} not found")
_ack_task_via_nats(reply_subject, logger)
return

try:
# Save to database (this is the slow operation)
if pipeline_result:
# pipeline should never be None here, otherwise we could not be processing results
assert job.pipeline is not None, "Job pipeline is None"
job.pipeline.save_results(results=pipeline_result, job_id=job.pk)
job.logger.info(f"Successfully saved results for job {job_id}")

_, t = t(
f"Saved pipeline results to database with {len(pipeline_result.detections)} detections"
f", percentage: {progress_info.percentage*100}%"
)

_ack_task_via_nats(reply_subject, job.logger)
# Update job stage with calculated progress
progress_info = state_manager.update_state(processed_image_ids, stage="results", request_id=self.request.id)

if not progress_info:
logger.warning(
f"Another task is already processing results for job {job_id}. "
f"Retrying task {self.request.id} in 5 seconds..."
)
raise self.retry(countdown=5, max_retries=10)
_update_job_progress(job_id, "results", progress_info.percentage)

except Exception as e:
job.logger.error(f"Failed to process pipeline result for job {job_id}: {e}. Retrying ...")

Comment on lines +102 to +129

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# Search for save_results method definition to check for idempotency handling
rg -n "def save_results" --type=py -A 20

Repository: RolnickLab/antenna

Length of output: 2730


🏁 Script executed:

# Also search for _ack_task_via_nats to understand what happens on ACK failure
rg -n "_ack_task_via_nats|def _ack_task_via_nats" --type=py -A 15

Repository: RolnickLab/antenna

Length of output: 3185


🏁 Script executed:

# Check the exception handler and retry logic in the task
rg -n "except Exception" ami/jobs/tasks.py -B 5 -A 10

Repository: RolnickLab/antenna

Length of output: 1999


🏁 Script executed:

# Get the full save_results function implementation
sed -n '814,900p' ami/ml/models/pipeline.py

Repository: RolnickLab/antenna

Length of output: 3415


🏁 Script executed:

# Look for Detection.objects.create calls in save_results
rg -n "Detection.*create|bulk_create" ami/ml/models/pipeline.py -B 2 -A 2 | head -50

Repository: RolnickLab/antenna

Length of output: 1139


Confirm: save_results lacks idempotency safeguards and will create duplicate Detection records on NATS redelivery.

The implementation at ami/ml/models/pipeline.py:506 uses Detection.objects.bulk_create() without checking for existing records. If ACK fails after save_results completes (line 107 in tasks.py), the NATS message remains unacknowledged. When JetStream redelivers the message after the ack-wait timeout, the entire task re-executes and calls save_results again with identical data, creating duplicate detections in the database.

Implement idempotency by one of the following (a sketch of the first option appears after this list):

  • Adding a check to skip detections already in the database before bulk_create
  • Using get_or_create logic per detection
  • Adding unique constraints on Detection fields and using upsert operations
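
A minimal sketch of the skip-existing approach, intended to replace the plain `bulk_create` call inside `save_results`. The `(source_image_id, bbox)` identity and the `ami.main.models` import path are assumptions about the antenna schema, not the actual model definition:

```python
from ami.main.models import Detection  # import path is an assumption

# `detections` is the list of unsaved Detection instances built by save_results.
existing = {
    (sid, tuple(bbox or ()))
    for sid, bbox in Detection.objects.filter(
        source_image_id__in={det.source_image_id for det in detections}
    ).values_list("source_image_id", "bbox")
}
fresh = [
    det
    for det in detections
    if (det.source_image_id, tuple(det.bbox or ())) not in existing
]
Detection.objects.bulk_create(fresh)
```

Alternatively, a database-level unique constraint on those fields combined with `bulk_create(..., ignore_conflicts=True)` pushes the deduplication into Postgres and also covers concurrent redeliveries.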

Minor: Clarify the exception handler log message.

Line 128 logs "Retrying..." but the task doesn't call self.retry(). Clarify to "Failed to process pipeline result for job {job_id}: {e}. NATS will redeliver the message." to accurately reflect the retry mechanism.

🧰 Tools
🪛 Ruff (0.14.8)

124-124: Abstract raise to an inner function

(TRY301)


127-127: Do not catch blind exception: Exception

(BLE001)


128-128: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None:
try:

async def ack_task():
async with TaskQueueManager() as manager:
return await manager.acknowledge_task(reply_subject)

ack_success = async_to_sync(ack_task)()

if ack_success:
job_logger.info(f"Successfully acknowledged task via NATS: {reply_subject}")
else:
job_logger.warning(f"Failed to acknowledge task via NATS: {reply_subject}")
except Exception as ack_error:
job_logger.error(f"Error acknowledging task via NATS: {ack_error}")
# Don't fail the task if ACK fails - data is already saved


def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None:
from ami.jobs.models import Job, JobState # avoid circular import

with transaction.atomic():
job = Job.objects.select_for_update().get(pk=job_id)
job.progress.update_stage(
stage,
status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED,
progress=progress_percentage,
)
if stage == "results" and progress_percentage >= 1.0:
job.status = JobState.SUCCESS
job.progress.summary.status = JobState.SUCCESS
job.finished_at = datetime.datetime.now() # Use naive datetime in local time
job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%")
job.save()


@task_prerun.connect(sender=run_job)
def pre_update_job_status(sender, task_id, task, **kwargs):
# in the prerun signal, set the job status to PENDING
@@ -65,3 +200,28 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs):
job.logger.error(f'Job #{job.pk} "{job.name}" failed: {exception}')

job.save()


def log_time(start: float = 0, msg: str | None = None) -> tuple[float, Callable]:
"""
Small helper to measure time between calls.

Returns: elapsed time since the last call, and a partial function to measure from the current call
Usage:

_, tlog = log_time()
# do something
_, tlog = tlog("Did something") # will log the time taken by 'something'
# do something else
t, tlog = tlog("Did something else") # will log the time taken by 'something else', returned as 't'
"""

end = time.perf_counter()
if start == 0:
dur = 0.0
else:
dur = end - start
if msg and start > 0:
logger.info(f"{msg}: {dur:.3f}s")
new_start = time.perf_counter()
return dur, functools.partial(log_time, new_start)
15 changes: 12 additions & 3 deletions ami/jobs/tests.py
@@ -11,6 +11,7 @@
from ami.jobs.models import Job, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob
from ami.main.models import Project, SourceImage, SourceImageCollection
from ami.ml.models import Pipeline
from ami.ml.orchestration.jobs import queue_images_to_nats
from ami.users.models import User

logger = logging.getLogger(__name__)
@@ -326,6 +327,15 @@ def test_search_jobs(self):
def _task_batch_helper(self, value: Any, expected_status: int):
pipeline = self._create_pipeline()
job = self._create_ml_job("Job for batch test", pipeline)
images = [
SourceImage.objects.create(
path=f"image_{i}.jpg",
public_base_url="http://example.com",
project=self.project,
)
for i in range(8) # more than 5 since we test with batch=5
]
queue_images_to_nats(job, images)

self.client.force_authenticate(user=self.user)
tasks_url = reverse_with_params(
@@ -390,10 +400,9 @@ def test_result_endpoint_stub(self):

self.assertEqual(resp.status_code, 200)
data = resp.json()
self.assertEqual(data["status"], "received")
self.assertEqual(data["status"], "accepted")
self.assertEqual(data["job_id"], job.pk)
self.assertEqual(data["results_received"], 1)
self.assertIn("message", data)
self.assertEqual(data["results_queued"], 1)

def test_result_endpoint_validation(self):
"""Test the result endpoint validates request data."""