diff --git a/ami/main/admin.py b/ami/main/admin.py
index 215d8a3be..301293957 100644
--- a/ami/main/admin.py
+++ b/ami/main/admin.py
@@ -5,13 +5,17 @@
from django.db.models.query import QuerySet
from django.http.request import HttpRequest
from django.template.defaultfilters import filesizeformat
+from django.urls import reverse
from django.utils.formats import number_format
+from django.utils.html import format_html
from guardian.admin import GuardedModelAdmin
import ami.utils
from ami import tasks
from ami.jobs.models import Job
+from ami.ml.models.algorithm import Algorithm
from ami.ml.models.project_pipeline_config import ProjectPipelineConfig
+from ami.ml.post_processing.class_masking import update_single_occurrence
from ami.ml.tasks import remove_duplicate_classifications
from .models import (
@@ -289,6 +293,7 @@ class ClassificationInline(admin.TabularInline):
model = Classification
extra = 0
fields = (
+ "classification_link",
"taxon",
"algorithm",
"timestamp",
@@ -296,6 +301,7 @@ class ClassificationInline(admin.TabularInline):
"created_at",
)
readonly_fields = (
+ "classification_link",
"taxon",
"algorithm",
"timestamp",
@@ -303,6 +309,13 @@ class ClassificationInline(admin.TabularInline):
"created_at",
)
+ @admin.display(description="Classification")
+ def classification_link(self, obj: Classification) -> str:
+ if obj.pk:
+ url = reverse("admin:main_classification_change", args=[obj.pk])
+ return format_html('{}', url, f"Classification #{obj.pk}")
+ return "-"
+
def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
qs = super().get_queryset(request)
return qs.select_related("taxon", "algorithm", "detection")
@@ -312,6 +325,7 @@ class DetectionInline(admin.TabularInline):
model = Detection
extra = 0
fields = (
+ "detection_link",
"detection_algorithm",
"source_image",
"timestamp",
@@ -319,6 +333,7 @@ class DetectionInline(admin.TabularInline):
"occurrence",
)
readonly_fields = (
+ "detection_link",
"detection_algorithm",
"source_image",
"timestamp",
@@ -326,6 +341,13 @@ class DetectionInline(admin.TabularInline):
"occurrence",
)
+ @admin.display(description="ID")
+ def detection_link(self, obj):
+ if obj.pk:
+ url = reverse("admin:main_detection_change", args=[obj.pk])
+            return format_html('<a href="{}">{}</a>', url, obj.pk)
+ return "-"
+
@admin.register(Detection)
class DetectionAdmin(admin.ModelAdmin[Detection]):
@@ -383,7 +405,7 @@ class OccurrenceAdmin(admin.ModelAdmin[Occurrence]):
"determination__rank",
"created_at",
)
- search_fields = ("determination__name", "determination__search_names")
+ search_fields = ("id", "determination__name", "determination__search_names")
def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
qs = super().get_queryset(request)
@@ -405,11 +427,60 @@ def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
def detections_count(self, obj) -> int:
return obj.detections_count
+ @admin.action(description="Update occurrence with Newfoundland species taxa list")
+ def update_with_newfoundland_species(self, request: HttpRequest, queryset: QuerySet[Occurrence]) -> None:
+ """
+        Update selected occurrences using the 'Newfoundland Species' taxa list
+        and the 'Quebec & Vermont Species Classifier - Apr 2024' algorithm.
+ """
+ try:
+ # Get the taxa list by name
+ taxa_list = TaxaList.objects.get(name="Newfoundland Species")
+ except TaxaList.DoesNotExist:
+ self.message_user(
+ request,
+ "Error: TaxaList 'Newfoundland species' not found.",
+ level="error",
+ )
+ return
+
+ try:
+ # Get the algorithm by name
+ algorithm = Algorithm.objects.get(name="Quebec & Vermont Species Classifier - Apr 2024")
+ except Algorithm.DoesNotExist:
+ self.message_user(
+ request,
+ "Error: Algorithm 'Quebec & Vermont Species Classifier - Apr 2024' not found.",
+ level="error",
+ )
+ return
+
+ # Process each occurrence
+ count = 0
+ for occurrence in queryset:
+ try:
+ update_single_occurrence(
+ occurrence=occurrence,
+ algorithm=algorithm,
+ taxa_list=taxa_list,
+ )
+ count += 1
+ except Exception as e:
+ self.message_user(
+ request,
+ f"Error processing occurrence {occurrence.pk}: {str(e)}",
+ level="error",
+ )
+
+ self.message_user(request, f"Successfully updated {count} occurrence(s).")
+
ordering = ("-created_at",)
# Add classifications as inline
inlines = [DetectionInline]
+ actions = [update_with_newfoundland_species]
+
@admin.register(Classification)
class ClassificationAdmin(admin.ModelAdmin[Classification]):
@@ -433,6 +504,8 @@ class ClassificationAdmin(admin.ModelAdmin[Classification]):
"taxon__rank",
)
+ autocomplete_fields = ("taxon",)
+
def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
qs = super().get_queryset(request)
return qs.select_related(
@@ -641,10 +714,32 @@ def run_small_size_filter(self, request: HttpRequest, queryset: QuerySet[SourceI
self.message_user(request, f"Queued Small Size Filter for {queryset.count()} collection(s). Jobs: {jobs}")
+ @admin.action(description="Run Rank Rollup post-processing task (async)")
+ def run_rank_rollup(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
+ """Trigger the Rank Rollup post-processing job asynchronously."""
+ jobs = []
+ DEFAULT_THRESHOLDS = {"species": 0.8, "genus": 0.6, "family": 0.4}
+
+ for collection in queryset:
+ job = Job.objects.create(
+ name=f"Post-processing: RankRollup on Collection {collection.pk}",
+ project=collection.project,
+ job_type_key="post_processing",
+ params={
+ "task": "rank_rollup",
+ "config": {"source_image_collection_id": collection.pk, "thresholds": DEFAULT_THRESHOLDS},
+ },
+ )
+ job.enqueue()
+ jobs.append(job.pk)
+
+ self.message_user(request, f"Queued Rank Rollup for {queryset.count()} collection(s). Jobs: {jobs}")
+
actions = [
populate_collection,
populate_collection_async,
run_small_size_filter,
+ run_rank_rollup,
]
# Hide images many-to-many field from form. This would list all source images in the database.
diff --git a/ami/ml/post_processing/__init__.py b/ami/ml/post_processing/__init__.py
index 3517ed47c..e69de29bb 100644
--- a/ami/ml/post_processing/__init__.py
+++ b/ami/ml/post_processing/__init__.py
@@ -1 +0,0 @@
-from . import small_size_filter # noqa: F401
diff --git a/ami/ml/post_processing/class_masking.py b/ami/ml/post_processing/class_masking.py
new file mode 100644
index 000000000..298688ba7
--- /dev/null
+++ b/ami/ml/post_processing/class_masking.py
@@ -0,0 +1,271 @@
+import logging
+
+import numpy as np
+
+from django.db.models import QuerySet
+from django.utils import timezone
+
+from ami.main.models import Classification, Occurrence, SourceImageCollection, TaxaList
+from ami.ml.models.algorithm import Algorithm, AlgorithmCategoryMap, AlgorithmTaskType
+from ami.ml.post_processing.base import BasePostProcessingTask
+
+logger = logging.getLogger(__name__)
+
+
+def update_single_occurrence(
+ occurrence: Occurrence,
+ algorithm: Algorithm,
+ taxa_list: TaxaList,
+ task_logger: logging.Logger = logger,
+):
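+    """Recalculate the terminal classification for a single occurrence by masking the
+    algorithm's scores to the given taxa list (see make_classifications_filtered_by_taxa_list).
+    """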
+ task_logger.info(f"Recalculating classifications for occurrence {occurrence.pk}.")
+
+ # Get the classifications for the occurrence in the collection
+ classifications = Classification.objects.filter(
+ detection__occurrence=occurrence,
+ terminal=True,
+ algorithm=algorithm,
+ scores__isnull=False,
+ ).distinct()
+
+ # Make a new Algorithm for the filtered classifications
+ new_algorithm, _ = Algorithm.objects.get_or_create(
+ name=f"{algorithm.name} (filtered by taxa list {taxa_list.name})",
+ key=f"{algorithm.key}_filtered_by_taxa_list_{taxa_list.pk}",
+ defaults={
+ "description": f"Classification algorithm {algorithm.name} filtered by taxa list {taxa_list.name}",
+ "task_type": AlgorithmTaskType.CLASSIFICATION.value,
+ "category_map": algorithm.category_map,
+ },
+ )
+
+ make_classifications_filtered_by_taxa_list(
+ classifications=classifications,
+ taxa_list=taxa_list,
+ algorithm=algorithm,
+ new_algorithm=new_algorithm,
+ )
+
+
+def update_occurrences_in_collection(
+ collection: SourceImageCollection,
+ taxa_list: TaxaList,
+ algorithm: Algorithm,
+ params: dict,
+ new_algorithm: Algorithm,
+ task_logger: logging.Logger = logger,
+ job=None,
+):
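+    """Recalculate terminal classifications for every occurrence in the collection,
+    masking the algorithm's scores to the given taxa list.
+    """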
+ task_logger.info(f"Recalculating classifications based on a taxa list. Params: {params}")
+
+ # Make new AlgorithmCategoryMap with the taxa in the list
+ # @TODO
+
+ classifications = Classification.objects.filter(
+ detection__source_image__collections=collection,
+ terminal=True,
+ # algorithm__task_type="classification",
+ algorithm=algorithm,
+ scores__isnull=False,
+ ).distinct()
+
+ make_classifications_filtered_by_taxa_list(
+ classifications=classifications,
+ taxa_list=taxa_list,
+ algorithm=algorithm,
+ new_algorithm=new_algorithm,
+ )
+
+
+def make_classifications_filtered_by_taxa_list(
+ classifications: QuerySet[Classification],
+ taxa_list: TaxaList,
+ algorithm: Algorithm,
+ new_algorithm: Algorithm,
+):
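+    """For each classification, mask out the categories whose taxa are not in the taxa
+    list, recompute the softmax over the remaining logits, and create a new terminal
+    Classification attributed to `new_algorithm`; the original is demoted to non-terminal.
+    """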
+ taxa_in_list = taxa_list.taxa.all()
+
+ occurrences_to_update: set[Occurrence] = set()
+ logger.info(f"Found {len(classifications)} terminal classifications with scores to update.")
+
+ if not classifications:
+ raise ValueError("No terminal classifications with scores found to update.")
+
+ if not algorithm.category_map:
+ raise ValueError(f"Algorithm {algorithm} does not have a category map.")
+ category_map: AlgorithmCategoryMap = algorithm.category_map
+
+ # Consider moving this to a method on the Classification model
+
+ # @TODO find a more efficient way to get the category map with taxa. This is slow!
+ logger.info(f"Retrieving category map with Taxa instances for algorithm {algorithm}")
+ category_map_with_taxa = category_map.with_taxa()
+    # Identify the categories whose taxa are NOT in the taxa list; these will be masked out
+    excluded_category_map_with_taxa = [
+        category for category in category_map_with_taxa if category["taxon"] not in taxa_in_list
+    ]
+
+ excluded_category_indices = [
+ int(category["index"]) for category in excluded_category_map_with_taxa # type: ignore
+ ]
+
+ # Log number of categories in the category map, num included, and num excluded, num classifications to update
+ logger.info(
+ f"Category map has {len(category_map_with_taxa)} categories, "
+ f"{len(excluded_category_map_with_taxa)} categories excluded, "
+ f"{len(classifications)} classifications to check"
+ )
+
+ classifications_to_add = []
+ classifications_to_update = []
+
+ timestamp = timezone.now()
+    for classification in classifications:
+        scores, logits = classification.scores, classification.logits
+
+        # Validate that all scores & logits are lists of numbers
+        if not isinstance(scores, list) or not all(isinstance(score, (int, float)) for score in scores):
+            raise ValueError(f"Scores for classification {classification.pk} are not a list of numbers: {scores}")
+        if not isinstance(logits, list) or not all(isinstance(logit, (int, float)) for logit in logits):
+            raise ValueError(f"Logits for classification {classification.pk} are not a list of numbers: {logits}")
+
+        logger.debug(f"Processing classification {classification.pk} with {len(scores)} scores")
+        logger.debug(f"Previous totals: {sum(scores)} scores, {sum(logits)} logits")
+
+        logits_np = np.array(logits)
+
+        # Mask the excluded categories by setting their logits to a large negative value
+        # @TODO delete the excluded categories from the scores and logits instead of masking them
+        logits_np[excluded_category_indices] = -100
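+        # NOTE: exp(-100) is ~3.7e-44, effectively zero after the softmax below, so the
+        # remaining categories renormalize to sum to 1. Setting the masked logits to 0.0
+        # (or np.nan) instead would NOT produce correct softmax probabilities.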
+
+ logits: list[float] = logits_np.tolist()
+
+        # Recalculate the softmax scores based on the masked logits
+        # @TODO add a test to verify this is correct, and needed!
+        scores_np: np.ndarray = np.exp(logits_np - np.max(logits_np))  # Subtract max for numerical stability
+        scores_np /= np.sum(scores_np)  # Normalize to get probabilities
+
+ scores: list = scores_np.tolist() # Convert back to list
+
+ logger.info(f"New totals: {sum(scores)} scores, {sum(logits)} logits")
+
+        # Get the taxon with the highest score using the index of the max score
+        top_index = scores.index(max(scores))
+        # @TODO this doesn't work if the taxon has never been classified
+        top_taxon = category_map_with_taxa[top_index]["taxon"]
+        logger.debug(f"Top taxon: {category_map_with_taxa[top_index]}, top index: {top_index}")
+
+ # check if needs updating
+ if classification.scores == scores and classification.logits == logits:
+ logger.debug(f"Classification {classification.pk} does not need updating")
+ continue
+
+ # Consider the existing classification as an intermediate classification
+ classification.terminal = False
+ classification.updated_at = timestamp
+
+ # Recalculate the top taxon and score
+ new_classification = Classification(
+ taxon=top_taxon,
+ algorithm=new_algorithm,
+ score=max(scores),
+ scores=scores,
+ logits=logits,
+ detection=classification.detection,
+ timestamp=classification.timestamp,
+ terminal=True,
+ category_map=None, # @TODO need a new category map with the filtered taxa
+ created_at=timestamp,
+ updated_at=timestamp,
+ )
+ if new_classification.taxon is None:
+            raise ValueError("New classification has no taxon. Aborting.")  # @TODO remove or fail gracefully
+
+ classifications_to_update.append(classification)
+ classifications_to_add.append(new_classification)
+
+ assert new_classification.detection is not None
+ assert new_classification.detection.occurrence is not None
+ occurrences_to_update.add(new_classification.detection.occurrence)
+
+        logger.debug(
+ f"Adding new classification for Taxon {top_taxon} to occurrence {new_classification.detection.occurrence}"
+ )
+
+ # Bulk update the existing classifications
+ if classifications_to_update:
+ logger.info(f"Bulk updating {len(classifications_to_update)} existing classifications")
+ Classification.objects.bulk_update(classifications_to_update, ["terminal", "updated_at"])
+ logger.info(f"Updated {len(classifications_to_update)} existing classifications")
+
+ if classifications_to_add:
+ # Bulk create the new classifications
+ logger.info(f"Bulk creating {len(classifications_to_add)} new classifications")
+ Classification.objects.bulk_create(classifications_to_add)
+ logger.info(f"Added {len(classifications_to_add)} new classifications")
+
+ # Update the occurrence determinations
+ logger.info(f"Updating the determinations for {len(occurrences_to_update)} occurrences")
+ for occurrence in occurrences_to_update:
+ occurrence.save(update_determination=True)
+ logger.info(f"Updated determinations for {len(occurrences_to_update)} occurrences")
+
+
+class ClassMaskingTask(BasePostProcessingTask):
+ key = "class_masking"
+ name = "Class masking"
+
+ def run(self) -> None:
+ """Apply class masking on a source image collection using a taxa list."""
+ job = self.job
+ self.logger.info(f"=== Starting {self.name} ===")
+
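+        # Expected job params (IDs are illustrative), e.g.:
+        # {"task": "class_masking",
+        #  "config": {"collection_id": 1, "taxa_list_id": 2, "algorithm_id": 3}}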
+ collection_id = self.config.get("collection_id")
+ taxa_list_id = self.config.get("taxa_list_id")
+ algorithm_id = self.config.get("algorithm_id")
+
+ # Validate config parameters
+ if not all([collection_id, taxa_list_id, algorithm_id]):
+ self.logger.error("Missing required configuration: collection_id, taxa_list_id, algorithm_id")
+ return
+
+ try:
+ collection = SourceImageCollection.objects.get(pk=collection_id)
+ taxa_list = TaxaList.objects.get(pk=taxa_list_id)
+ algorithm = Algorithm.objects.get(pk=algorithm_id)
+ except Exception as e:
+ self.logger.exception(f"Failed to load objects: {e}")
+ return
+
+ self.logger.info(f"Applying class masking on collection {collection_id} using taxa list {taxa_list_id}")
+
+        # @TODO temporary: do we need a new algorithm (and category map) for each class mask?
+        self.algorithm.category_map = algorithm.category_map  # Copy the source algorithm's category map
+
+ update_occurrences_in_collection(
+ collection=collection,
+ taxa_list=taxa_list,
+ algorithm=algorithm,
+ params=self.config,
+ task_logger=self.logger,
+ job=job,
+ new_algorithm=self.algorithm,
+ )
+
+ self.logger.info("Class masking completed successfully.")
+ self.logger.info(f"=== Completed {self.name} ===")
diff --git a/ami/ml/post_processing/rank_rollup.py b/ami/ml/post_processing/rank_rollup.py
new file mode 100644
index 000000000..da5177b4e
--- /dev/null
+++ b/ami/ml/post_processing/rank_rollup.py
@@ -0,0 +1,153 @@
+import logging
+from collections import defaultdict
+
+from django.db import transaction
+from django.utils import timezone
+
+from ami.main.models import Classification, Taxon
+from ami.ml.post_processing.base import BasePostProcessingTask, register_postprocessing_task
+
+logger = logging.getLogger(__name__)
+
+
+def find_ancestor_by_parent_chain(taxon, target_rank: str):
+ """Climb up parent relationships until a taxon with the target rank is found."""
+ if not taxon:
+ return None
+
+ target_rank = target_rank.upper()
+
+ current = taxon
+ while current:
+ if current.rank.upper() == target_rank:
+ return current
+ current = current.parent
+
+ return None
+
+
+@register_postprocessing_task
+class RankRollupTask(BasePostProcessingTask):
+ """Post-processing task that rolls up low-confidence classifications
+ to higher ranks using aggregated scores.
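+
+    For example (hypothetical numbers): if two species under genus G score 0.5 and 0.3,
+    neither passes a species threshold of 0.8, but their aggregated genus score of 0.8
+    passes a genus threshold of 0.6, so the classification is rolled up to genus G.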
+ """
+
+ key = "rank_rollup"
+ name = "Rank rollup"
+
+ DEFAULT_THRESHOLDS = {"SPECIES": 0.8, "GENUS": 0.6, "FAMILY": 0.4}
+ ROLLUP_ORDER = ["SPECIES", "GENUS", "FAMILY"]
+
+ def run(self) -> None:
+ job = self.job
+ self.logger.info(f"Starting {self.name} task for job {job.pk if job else 'N/A'}")
+
+ # ---- Read config parameters ----
+ config = self.config or {}
+ collection_id = config.get("source_image_collection_id")
+        thresholds = config.get("thresholds", self.DEFAULT_THRESHOLDS)
+        rollup_order = config.get("rollup_order", self.ROLLUP_ORDER)
+        # Normalize rank names to upper case so lowercase config values
+        # (e.g. {"species": 0.8}) match the upper-case rank comparisons below
+        thresholds = {str(rank).upper(): value for rank, value in thresholds.items()}
+        rollup_order = [str(rank).upper() for rank in rollup_order]
+
+ if not collection_id:
+ self.logger.info("No 'source_image_collection_id' provided in config. Aborting task.")
+ return
+
+ self.logger.info(
+ f"Config loaded: collection_id={collection_id}, thresholds={thresholds}, rollup_order={rollup_order}"
+ )
+
+ qs = Classification.objects.filter(
+ terminal=True,
+ taxon__isnull=False,
+ detection__source_image__collections__id=collection_id,
+ )
+
+ total = qs.count()
+ self.logger.info(f"Found {total} terminal classifications to process for collection {collection_id}")
+
+ updated_occurrences = []
+
+ with transaction.atomic():
+ for i, clf in enumerate(qs.iterator(), start=1):
+ self.logger.info(f"Processing classification #{clf.pk} (taxon={clf.taxon}, score={clf.score:.3f})")
+
+ if not clf.scores:
+ self.logger.info(f"Skipping classification #{clf.pk}: no scores available")
+ continue
+ if not clf.category_map:
+ self.logger.info(f"Skipping classification #{clf.pk}: no category_map assigned")
+ continue
+
+ taxon_scores = defaultdict(float)
+
+ for idx, score in enumerate(clf.scores):
+ label = clf.category_map.labels[idx]
+ if not label:
+ continue
+
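+                    # NOTE: this runs one query per label per classification; consider
+                    # caching label -> Taxon lookups for large collections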
+ taxon = Taxon.objects.filter(name=label).first()
+ if not taxon:
+ self.logger.info(f"Skipping label '{label}' (no matching Taxon Found)")
+ continue
+
+ for rank in rollup_order:
+ ancestor = find_ancestor_by_parent_chain(taxon, rank)
+ if ancestor:
+ taxon_scores[ancestor] += score
+ self.logger.debug(f" + Added {score:.3f} to ancestor {ancestor.name} ({rank})")
+
+ new_taxon = None
+ new_score = None
+ self.logger.info(f"Aggregated taxon scores: { {t.name: s for t, s in taxon_scores.items()} }")
+ for rank in rollup_order:
+ threshold = thresholds.get(rank, 1.0)
+                    candidates = {t: s for t, s in taxon_scores.items() if t.rank.upper() == rank}
+
+ if not candidates:
+ self.logger.info(f"No candidates found at rank {rank}")
+ continue
+
+ best_taxon, best_score = max(candidates.items(), key=lambda kv: kv[1])
+ self.logger.info(
+ f"Best at rank {rank}: {best_taxon.name} ({best_score:.3f}) [threshold={threshold}]"
+ )
+
+ if best_score >= threshold:
+ new_taxon, new_score = best_taxon, best_score
+ self.logger.info(f"Rollup decision: {new_taxon.name} ({rank}) with score {new_score:.3f}")
+ break
+
+ if new_taxon and new_taxon != clf.taxon:
+ self.logger.info(f"Rolling up {clf.taxon} => {new_taxon} ({new_taxon.rank})")
+
+ # Mark all classifications for this detection as non-terminal
+ Classification.objects.filter(detection=clf.detection).update(terminal=False)
+ Classification.objects.create(
+ detection=clf.detection,
+ taxon=new_taxon,
+ score=new_score,
+ terminal=True,
+ algorithm=self.algorithm,
+ timestamp=timezone.now(),
+ applied_to=clf,
+ )
+
+ occurrence = clf.detection.occurrence
+ updated_occurrences.append(occurrence)
+ self.logger.info(
+ f"Rolled up occurrence {occurrence.pk}: {clf.taxon} => {new_taxon} "
+ f"({new_taxon.rank}) with rolled-up score={new_score:.3f}"
+ )
+ else:
+ self.logger.info(f"No rollup applied for classification #{clf.pk} (taxon={clf.taxon})")
+
+ # Update progress every 10 iterations
+ if i % 10 == 0 or i == total:
+ progress = i / total if total > 0 else 1.0
+ self.update_progress(progress)
+
+ self.logger.info(f"Rank rollup completed. Updated {len(updated_occurrences)} occurrences.")
+ self.logger.info(f"{self.name} task finished for collection {collection_id}.")
diff --git a/ami/ml/post_processing/registry.py b/ami/ml/post_processing/registry.py
index c85f607f9..308be18ae 100644
--- a/ami/ml/post_processing/registry.py
+++ b/ami/ml/post_processing/registry.py
@@ -1,8 +1,10 @@
# Registry of available post-processing tasks
+from ami.ml.post_processing.class_masking import ClassMaskingTask
+from ami.ml.post_processing.rank_rollup import RankRollupTask  # noqa: F401  # registered via its decorator
from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask
POSTPROCESSING_TASKS = {
SmallSizeFilterTask.key: SmallSizeFilterTask,
+ ClassMaskingTask.key: ClassMaskingTask,
}
diff --git a/ami/ml/tests.py b/ami/ml/tests.py
index f88bfbbc0..6966567bc 100644
--- a/ami/ml/tests.py
+++ b/ami/ml/tests.py
@@ -10,14 +10,18 @@
from ami.main.models import (
Classification,
Detection,
+ Occurrence,
Project,
SourceImage,
SourceImageCollection,
Taxon,
+ TaxonRank,
group_images_into_events,
)
-from ami.ml.models import Algorithm, Pipeline, ProcessingService
+from ami.ml.models import Algorithm, AlgorithmCategoryMap, Pipeline, ProcessingService
+from ami.ml.models.algorithm import AlgorithmTaskType
from ami.ml.models.pipeline import collect_images, get_or_create_algorithm_and_category_map, save_results
+from ami.ml.post_processing.rank_rollup import RankRollupTask
from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask
from ami.ml.schemas import (
AlgorithmConfigResponse,
@@ -748,6 +752,13 @@ def setUp(self):
)
self.collection.populate_sample()
+ # Select example taxa
+ self.species_taxon = Taxon.objects.filter(rank=TaxonRank.SPECIES.name).first()
+ self.genus_taxon = self.species_taxon.parent if self.species_taxon else None
+ self.assertIsNotNone(self.species_taxon)
+ self.assertIsNotNone(self.genus_taxon)
+ self.algorithm = self._create_category_map_with_algorithm()
+
def _create_images_with_dimensions(
self,
deployment,
@@ -813,7 +824,9 @@ def test_small_size_filter_assigns_not_identifiable(self):
not_identifiable_taxon,
f"Detection {det.pk} should be classified as 'Not identifiable'",
)
+
occurrence = det.occurrence
+ assert occurrence
self.assertIsNotNone(occurrence, f"Detection {det.pk} should belong to an occurrence.")
occurrence.refresh_from_db()
self.assertEqual(
@@ -821,3 +834,95 @@ def test_small_size_filter_assigns_not_identifiable(self):
not_identifiable_taxon,
f"Occurrence {occurrence.pk} should have its determination set to 'Not identifiable'.",
)
+
+ def _create_occurrences_with_classifications(self, num=3):
+ """Helper to create occurrences and terminal classifications below species threshold."""
+ occurrences = []
+ now = datetime.datetime.now(datetime.timezone.utc)
+ for i in range(num):
+ det = Detection.objects.create(
+ source_image=self.collection.images.first(),
+ bbox=[0, 0, 200, 200],
+ )
+ occ = Occurrence.objects.create(project=self.project, event=self.deployment.events.first())
+ occ.detections.add(det)
+ classification = Classification.objects.create(
+ detection=det,
+ taxon=self.species_taxon,
+ score=0.5,
+ scores=[0.5, 0.3, 0.2],
+ terminal=True,
+ timestamp=now,
+                algorithm=self.algorithm,
+                category_map=self.algorithm.category_map,
+ )
+ occurrences.append((occ, classification))
+ return occurrences
+
+ def _create_category_map_with_algorithm(self):
+ """Create a simple AlgorithmCategoryMap and Algorithm to attach to classifications."""
+ species_taxa = list(self.project.taxa.filter(rank=TaxonRank.SPECIES.name)[:3])
+ assert species_taxa, "No species taxa found in project; run create_taxa() first."
+
+ data = [
+ {
+ "index": i,
+ "label": taxon.name,
+ "taxon_rank": taxon.rank,
+ "gbif_key": getattr(taxon, "gbif_key", None),
+ }
+ for i, taxon in enumerate(species_taxa)
+ ]
+ labels = [item["label"] for item in data]
+
+ category_map = AlgorithmCategoryMap.objects.create(
+ data=data,
+ labels=labels,
+ version="v1.0",
+ description="Species-level category map for testing RankRollupTask",
+ )
+
+ algorithm = Algorithm.objects.create(
+ name="Test Species Classifier",
+ task_type=AlgorithmTaskType.CLASSIFICATION.value,
+ category_map=category_map,
+ )
+
+ return algorithm
+
+ def test_rank_rollup_creates_new_terminal_classifications(self):
+ occurrences = self._create_occurrences_with_classifications(num=3)
+
+ task = RankRollupTask(
+ source_image_collection_id=self.collection.pk,
+ thresholds={"species": 0.8, "genus": 0.6, "family": 0.4},
+ )
+ task.run()
+
+ # Validate results
+ for occ, original_cls in occurrences:
+ detection = occ.detections.first()
+ original_cls.refresh_from_db(fields=["terminal"])
+ rolled_up_cls = Classification.objects.filter(detection=detection, terminal=True).first()
+
+ self.assertIsNotNone(
+ rolled_up_cls,
+ f"Expected a new rolled-up classification for original #{original_cls.pk}",
+ )
+ self.assertTrue(
+ rolled_up_cls.terminal,
+ "New rolled-up classification should be marked as terminal.",
+ )
+ self.assertFalse(
+ original_cls.terminal,
+ "Original classification should be marked as non-terminal after roll-up.",
+ )
+ self.assertEqual(
+ rolled_up_cls.taxon,
+ self.genus_taxon,
+ "Rolled-up classification should have genus-level taxon.",
+ )
+ self.assertEqual(
+ rolled_up_cls.applied_to,
+ original_cls,
+ "Rolled-up classification should reference the original classification.",
+ )