diff --git a/ami/main/admin.py b/ami/main/admin.py
index 215d8a3be..301293957 100644
--- a/ami/main/admin.py
+++ b/ami/main/admin.py
@@ -5,13 +5,17 @@
 from django.db.models.query import QuerySet
 from django.http.request import HttpRequest
 from django.template.defaultfilters import filesizeformat
+from django.urls import reverse
 from django.utils.formats import number_format
+from django.utils.html import format_html
 from guardian.admin import GuardedModelAdmin
 
 import ami.utils
 from ami import tasks
 from ami.jobs.models import Job
+from ami.ml.models.algorithm import Algorithm
 from ami.ml.models.project_pipeline_config import ProjectPipelineConfig
+from ami.ml.post_processing.class_masking import update_single_occurrence
 from ami.ml.tasks import remove_duplicate_classifications
 
 from .models import (
@@ -289,6 +293,7 @@ class ClassificationInline(admin.TabularInline):
     model = Classification
     extra = 0
     fields = (
+        "classification_link",
        "taxon",
        "algorithm",
        "timestamp",
@@ -296,6 +301,7 @@
         "created_at",
     )
     readonly_fields = (
+        "classification_link",
        "taxon",
        "algorithm",
        "timestamp",
@@ -303,6 +309,13 @@
         "created_at",
     )
 
+    @admin.display(description="Classification")
+    def classification_link(self, obj: Classification) -> str:
+        if obj.pk:
+            url = reverse("admin:main_classification_change", args=[obj.pk])
+            return format_html('<a href="{}">{}</a>', url, f"Classification #{obj.pk}")
+        return "-"
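+    # Note: format_html() escapes its interpolated arguments, so the URL and
+    # label above are rendered safely as an anchor tag in the admin.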
+
     def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
         qs = super().get_queryset(request)
         return qs.select_related("taxon", "algorithm", "detection")
@@ -312,6 +325,7 @@ class DetectionInline(admin.TabularInline):
     model = Detection
     extra = 0
     fields = (
+        "detection_link",
         "detection_algorithm",
         "source_image",
         "timestamp",
@@ -319,6 +333,7 @@
         "occurrence",
     )
     readonly_fields = (
+        "detection_link",
         "detection_algorithm",
         "source_image",
         "timestamp",
@@ -326,6 +341,13 @@
         "occurrence",
     )
 
+    @admin.display(description="ID")
+    def detection_link(self, obj):
+        if obj.pk:
+            url = reverse("admin:main_detection_change", args=[obj.pk])
+            return format_html('<a href="{}">{}</a>', url, obj.pk)
+        return "-"
+
 
 @admin.register(Detection)
 class DetectionAdmin(admin.ModelAdmin[Detection]):
@@ -383,7 +405,7 @@ class OccurrenceAdmin(admin.ModelAdmin[Occurrence]):
         "determination__rank",
         "created_at",
     )
-    search_fields = ("determination__name", "determination__search_names")
+    search_fields = ("id", "determination__name", "determination__search_names")
 
     def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
         qs = super().get_queryset(request)
@@ -405,11 +427,60 @@ def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
     def detections_count(self, obj) -> int:
         return obj.detections_count
 
+    @admin.action(description="Update occurrences with the 'Newfoundland Species' taxa list")
+    def update_with_newfoundland_species(self, request: HttpRequest, queryset: QuerySet[Occurrence]) -> None:
+        """
+        Update the selected occurrences using the 'Newfoundland Species' taxa list
+        and the 'Quebec & Vermont Species Classifier - Apr 2024' algorithm.
+        """
+        try:
+            # Look up the taxa list by name
+            taxa_list = TaxaList.objects.get(name="Newfoundland Species")
+        except TaxaList.DoesNotExist:
+            self.message_user(
+                request,
+                "Error: TaxaList 'Newfoundland Species' not found.",
+                level="error",
+            )
+            return
+
+        try:
+            # Look up the algorithm by name
+            algorithm = Algorithm.objects.get(name="Quebec & Vermont Species Classifier - Apr 2024")
+        except Algorithm.DoesNotExist:
+            self.message_user(
+                request,
+                "Error: Algorithm 'Quebec & Vermont Species Classifier - Apr 2024' not found.",
+                level="error",
+            )
+            return
+
+        # Process each occurrence
+        count = 0
+        for occurrence in queryset:
+            try:
+                update_single_occurrence(
+                    occurrence=occurrence,
+                    algorithm=algorithm,
+                    taxa_list=taxa_list,
+                )
+                count += 1
+            except Exception as e:
+                self.message_user(
+                    request,
+                    f"Error processing occurrence {occurrence.pk}: {e}",
+                    level="error",
+                )
+
+        self.message_user(request, f"Successfully updated {count} occurrence(s).")
+
     ordering = ("-created_at",)
 
     # Add classifications as inline
     inlines = [DetectionInline]
 
+    actions = [update_with_newfoundland_species]
+
 
 @admin.register(Classification)
 class ClassificationAdmin(admin.ModelAdmin[Classification]):
@@ -433,6 +504,8 @@ class ClassificationAdmin(admin.ModelAdmin[Classification]):
         "taxon__rank",
     )
 
+    autocomplete_fields = ("taxon",)
+
     def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
         qs = super().get_queryset(request)
         return qs.select_related(
@@ -641,10 +714,32 @@ def run_small_size_filter(self, request: HttpRequest, queryset: QuerySet[SourceI
         self.message_user(request, f"Queued Small Size Filter for {queryset.count()} collection(s). Jobs: {jobs}")
 
+    @admin.action(description="Run Rank Rollup post-processing task (async)")
+    def run_rank_rollup(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
+        """Trigger the Rank Rollup post-processing job asynchronously."""
+        jobs = []
+        DEFAULT_THRESHOLDS = {"species": 0.8, "genus": 0.6, "family": 0.4}
+
+        for collection in queryset:
+            job = Job.objects.create(
+                name=f"Post-processing: RankRollup on Collection {collection.pk}",
+                project=collection.project,
+                job_type_key="post_processing",
+                params={
+                    "task": "rank_rollup",
+                    "config": {"source_image_collection_id": collection.pk, "thresholds": DEFAULT_THRESHOLDS},
+                },
+            )
+            job.enqueue()
+            jobs.append(job.pk)
+
+        self.message_user(request, f"Queued Rank Rollup for {queryset.count()} collection(s). Jobs: {jobs}")
+
     actions = [
         populate_collection,
         populate_collection_async,
         run_small_size_filter,
+        run_rank_rollup,
     ]
 
     # Hide images many-to-many field from form. This would list all source images in the database.
diff --git a/ami/ml/post_processing/__init__.py b/ami/ml/post_processing/__init__.py
index 3517ed47c..e69de29bb 100644
--- a/ami/ml/post_processing/__init__.py
+++ b/ami/ml/post_processing/__init__.py
@@ -1 +0,0 @@
-from . import small_size_filter  # noqa: F401
diff --git a/ami/ml/post_processing/class_masking.py b/ami/ml/post_processing/class_masking.py
new file mode 100644
index 000000000..298688ba7
--- /dev/null
+++ b/ami/ml/post_processing/class_masking.py
@@ -0,0 +1,271 @@
+import logging
+
+import numpy as np
+from django.db.models import QuerySet
+from django.utils import timezone
+
+from ami.main.models import Classification, Occurrence, SourceImageCollection, TaxaList
+from ami.ml.models.algorithm import Algorithm, AlgorithmCategoryMap, AlgorithmTaskType
+from ami.ml.post_processing.base import BasePostProcessingTask
+
+logger = logging.getLogger(__name__)
+
+
+def update_single_occurrence(
+    occurrence: Occurrence,
+    algorithm: Algorithm,
+    taxa_list: TaxaList,
+    task_logger: logging.Logger = logger,
+):
+    task_logger.info(f"Recalculating classifications for occurrence {occurrence.pk}.")
+
+    # Get the terminal classifications for the occurrence
+    classifications = Classification.objects.filter(
+        detection__occurrence=occurrence,
+        terminal=True,
+        algorithm=algorithm,
+        scores__isnull=False,
+    ).distinct()
+
+    # Make a new Algorithm for the filtered classifications
+    new_algorithm, _ = Algorithm.objects.get_or_create(
+        name=f"{algorithm.name} (filtered by taxa list {taxa_list.name})",
+        key=f"{algorithm.key}_filtered_by_taxa_list_{taxa_list.pk}",
+        defaults={
+            "description": f"Classification algorithm {algorithm.name} filtered by taxa list {taxa_list.name}",
+            "task_type": AlgorithmTaskType.CLASSIFICATION.value,
+            "category_map": algorithm.category_map,
+        },
+    )
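+    # Example of the derived key (hypothetical values): an algorithm with key
+    # "qc-vt-apr2024" and a taxa list with pk 7 yields
+    # "qc-vt-apr2024_filtered_by_taxa_list_7".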
+
+    make_classifications_filtered_by_taxa_list(
+        classifications=classifications,
+        taxa_list=taxa_list,
+        algorithm=algorithm,
+        new_algorithm=new_algorithm,
+    )
+
+
+def update_occurrences_in_collection(
+    collection: SourceImageCollection,
+    taxa_list: TaxaList,
+    algorithm: Algorithm,
+    params: dict,
+    new_algorithm: Algorithm,
+    task_logger: logging.Logger = logger,
+    job=None,
+):
+    task_logger.info(f"Recalculating classifications based on a taxa list. Params: {params}")
+
+    # Make a new AlgorithmCategoryMap with the taxa in the list
+    # @TODO
+
+    classifications = Classification.objects.filter(
+        detection__source_image__collections=collection,
+        terminal=True,
+        # algorithm__task_type="classification",
+        algorithm=algorithm,
+        scores__isnull=False,
+    ).distinct()
+
+    make_classifications_filtered_by_taxa_list(
+        classifications=classifications,
+        taxa_list=taxa_list,
+        algorithm=algorithm,
+        new_algorithm=new_algorithm,
+    )
+
+
+def make_classifications_filtered_by_taxa_list(
+    classifications: QuerySet[Classification],
+    taxa_list: TaxaList,
+    algorithm: Algorithm,
+    new_algorithm: Algorithm,
+):
+    # Use a set for fast membership checks when filtering the category map
+    taxa_in_list = set(taxa_list.taxa.all())
+
+    occurrences_to_update: set[Occurrence] = set()
+    logger.info(f"Found {len(classifications)} terminal classifications with scores to update.")
+
+    if not classifications:
+        raise ValueError("No terminal classifications with scores found to update.")
+
+    if not algorithm.category_map:
+        raise ValueError(f"Algorithm {algorithm} does not have a category map.")
+    category_map: AlgorithmCategoryMap = algorithm.category_map
+
+    # Consider moving this to a method on the Classification model
+    # @TODO find a more efficient way to get the category map with taxa. This is slow!
+    logger.info(f"Retrieving category map with Taxa instances for algorithm {algorithm}")
+    category_map_with_taxa = category_map.with_taxa()
+    # Filter the category map down to the categories whose taxa are NOT in the taxa list
+    excluded_category_map_with_taxa = [
+        category for category in category_map_with_taxa if category["taxon"] not in taxa_in_list
+    ]
+    excluded_category_indices = [
+        int(category["index"]) for category in excluded_category_map_with_taxa  # type: ignore
+    ]
+
+    # Log the number of categories in the map, how many are excluded, and how many classifications to check
+    logger.info(
+        f"Category map has {len(category_map_with_taxa)} categories, "
+        f"{len(excluded_category_map_with_taxa)} categories excluded, "
+        f"{len(classifications)} classifications to check"
+    )
+
+    classifications_to_add = []
+    classifications_to_update = []
+
+    timestamp = timezone.now()
+    for classification in classifications:
+        scores, logits = classification.scores, classification.logits
+
+        # Check that all scores and logits are lists of numbers
+        if not isinstance(scores, list) or not all(isinstance(score, (int, float)) for score in scores):
+            raise ValueError(f"Scores for classification {classification.pk} are not a list of numbers: {scores}")
+        if not isinstance(logits, list) or not all(isinstance(logit, (int, float)) for logit in logits):
+            raise ValueError(f"Logits for classification {classification.pk} are not a list of numbers: {logits}")
+
+        logger.debug(f"Processing classification {classification.pk} with {len(scores)} scores")
+        logger.info(f"Previous totals: {sum(scores)} scores, {sum(logits)} logits")
+
+        logits_np = np.array(logits)
+
+        # Mask the excluded categories with a large negative logit so they end up
+        # with effectively zero probability after the softmax.
+        # @TODO can we use np.nan instead? Zero will NOT calculate correctly in softmax.
+        # @TODO consider deleting the excluded categories from the scores and logits instead of masking them.
+        logits_np[excluded_category_indices] = -100
+
+        logits: list[float] = logits_np.tolist()
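+
+        # Illustrative example (made-up numbers): for logits [2.0, 1.0, 0.0] with
+        # index 1 masked to -100, the softmax below yields roughly
+        # [0.881, 0.000, 0.119]; the masked class gets ~zero probability and the
+        # remaining mass is renormalized over the allowed classes.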
+        # @TODO add a test to see if this is correct, or needed!
+        # Recalculate the softmax scores based on the masked logits
+        scores_np: np.ndarray = np.exp(logits_np - np.max(logits_np))  # Subtract max for numerical stability
+        scores_np /= np.sum(scores_np)  # Normalize to get probabilities
+
+        scores: list = scores_np.tolist()  # Convert back to a list
+
+        logger.info(f"New totals: {sum(scores)} scores, {sum(logits)} logits")
+
+        # Get the taxon with the highest score using the index of the max score
+        top_index = scores.index(max(scores))
+        top_taxon = category_map_with_taxa[top_index][
+            "taxon"
+        ]  # @TODO this doesn't work if the taxon has never been classified
+        logger.debug(f"Top taxon: {category_map_with_taxa[top_index]}, top index: {top_index}")
+
+        # Check whether the classification actually changed
+        if classification.scores == scores and classification.logits == logits:
+            logger.debug(f"Classification {classification.pk} does not need updating")
+            continue
+
+        # Consider the existing classification an intermediate classification
+        classification.terminal = False
+        classification.updated_at = timestamp
+
+        # Create a new terminal classification with the recalculated top taxon and score
+        new_classification = Classification(
+            taxon=top_taxon,
+            algorithm=new_algorithm,
+            score=max(scores),
+            scores=scores,
+            logits=logits,
+            detection=classification.detection,
+            timestamp=classification.timestamp,
+            terminal=True,
+            category_map=None,  # @TODO need a new category map with the filtered taxa
+            created_at=timestamp,
+            updated_at=timestamp,
+        )
+        if new_classification.taxon is None:
+            raise ValueError("New classification has no taxon. Aborting.")  # @TODO remove or fail gracefully
+
+        classifications_to_update.append(classification)
+        classifications_to_add.append(new_classification)
+
+        assert new_classification.detection is not None
+        assert new_classification.detection.occurrence is not None
+        occurrences_to_update.add(new_classification.detection.occurrence)
+
+        logger.info(
+            f"Adding new classification for Taxon {top_taxon} to occurrence {new_classification.detection.occurrence}"
+        )
+
+    # Bulk update the existing classifications
+    if classifications_to_update:
+        logger.info(f"Bulk updating {len(classifications_to_update)} existing classifications")
+        Classification.objects.bulk_update(classifications_to_update, ["terminal", "updated_at"])
+        logger.info(f"Updated {len(classifications_to_update)} existing classifications")
+
+    if classifications_to_add:
+        # Bulk create the new classifications
+        logger.info(f"Bulk creating {len(classifications_to_add)} new classifications")
+        Classification.objects.bulk_create(classifications_to_add)
+        logger.info(f"Added {len(classifications_to_add)} new classifications")
+
+    # Update the occurrence determinations
+    logger.info(f"Updating the determinations for {len(occurrences_to_update)} occurrences")
+    for occurrence in occurrences_to_update:
+        occurrence.save(update_determination=True)
+    logger.info(f"Updated determinations for {len(occurrences_to_update)} occurrences")
+
+
+class ClassMaskingTask(BasePostProcessingTask):
+    key = "class_masking"
+    name = "Class masking"
+
+    def run(self) -> None:
+        """Apply class masking on a source image collection using a taxa list."""
+        job = self.job
+        self.logger.info(f"=== Starting {self.name} ===")
+
+        collection_id = self.config.get("collection_id")
+        taxa_list_id = self.config.get("taxa_list_id")
+        algorithm_id = self.config.get("algorithm_id")
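+        # Expected config shape (the IDs here are hypothetical), e.g.:
+        #   {"collection_id": 42, "taxa_list_id": 7, "algorithm_id": 3}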
+
+        # Validate config parameters
+        if not all([collection_id, taxa_list_id, algorithm_id]):
+            self.logger.error("Missing required configuration: collection_id, taxa_list_id, algorithm_id")
+            return
+
+        try:
+            collection = SourceImageCollection.objects.get(pk=collection_id)
+            taxa_list = TaxaList.objects.get(pk=taxa_list_id)
+            algorithm = Algorithm.objects.get(pk=algorithm_id)
+        except Exception as e:
+            self.logger.exception(f"Failed to load objects: {e}")
+            return
+
+        self.logger.info(f"Applying class masking on collection {collection_id} using taxa list {taxa_list_id}")
+
+        # @TODO temporary; do we need a new algorithm for each class mask?
+        self.algorithm.category_map = algorithm.category_map  # Ensure the algorithm has its category map loaded
+
+        update_occurrences_in_collection(
+            collection=collection,
+            taxa_list=taxa_list,
+            algorithm=algorithm,
+            params=self.config,
+            task_logger=self.logger,
+            job=job,
+            new_algorithm=self.algorithm,
+        )
+
+        self.logger.info("Class masking completed successfully.")
+        self.logger.info(f"=== Completed {self.name} ===")
diff --git a/ami/ml/post_processing/rank_rollup.py b/ami/ml/post_processing/rank_rollup.py
new file mode 100644
index 000000000..da5177b4e
--- /dev/null
+++ b/ami/ml/post_processing/rank_rollup.py
@@ -0,0 +1,153 @@
+import logging
+from collections import defaultdict
+
+from django.db import transaction
+from django.utils import timezone
+
+from ami.main.models import Classification, Taxon
+from ami.ml.post_processing.base import BasePostProcessingTask, register_postprocessing_task
+
+logger = logging.getLogger(__name__)
+
+
+def find_ancestor_by_parent_chain(taxon, target_rank: str):
+    """Climb up the parent relationships until a taxon with the target rank is found."""
+    if not taxon:
+        return None
+
+    target_rank = target_rank.upper()
+
+    current = taxon
+    while current:
+        if current.rank.upper() == target_rank:
+            return current
+        current = current.parent
+
+    return None
+
+
+@register_postprocessing_task
+class RankRollupTask(BasePostProcessingTask):
+    """Post-processing task that rolls up low-confidence classifications
+    to higher ranks using aggregated scores.
+    """
+
+    key = "rank_rollup"
+    name = "Rank rollup"
+
+    DEFAULT_THRESHOLDS = {"SPECIES": 0.8, "GENUS": 0.6, "FAMILY": 0.4}
+    ROLLUP_ORDER = ["SPECIES", "GENUS", "FAMILY"]
+
+    def run(self) -> None:
+        job = self.job
+        self.logger.info(f"Starting {self.name} task for job {job.pk if job else 'N/A'}")
+
+        # ---- Read config parameters ----
+        config = self.config or {}
+        collection_id = config.get("source_image_collection_id")
+        thresholds = config.get("thresholds", self.DEFAULT_THRESHOLDS)
+        # Normalize threshold keys so lowercase keys from the admin action
+        # ("species") match the uppercase rank names in ROLLUP_ORDER.
+        thresholds = {rank.upper(): value for rank, value in thresholds.items()}
+        rollup_order = config.get("rollup_order", self.ROLLUP_ORDER)
+
+        if not collection_id:
+            self.logger.info("No 'source_image_collection_id' provided in config. Aborting task.")
+            return
+
+        self.logger.info(
+            f"Config loaded: collection_id={collection_id}, thresholds={thresholds}, rollup_order={rollup_order}"
+        )
+
+        qs = Classification.objects.filter(
+            terminal=True,
+            taxon__isnull=False,
+            detection__source_image__collections__id=collection_id,
+        )
+
+        total = qs.count()
+        self.logger.info(f"Found {total} terminal classifications to process for collection {collection_id}")
+
+        updated_occurrences = []
+
+        with transaction.atomic():
+            for i, clf in enumerate(qs.iterator(), start=1):
+                self.logger.info(f"Processing classification #{clf.pk} (taxon={clf.taxon}, score={clf.score:.3f})")
+
+                if not clf.scores:
+                    self.logger.info(f"Skipping classification #{clf.pk}: no scores available")
+                    continue
+                if not clf.category_map:
+                    self.logger.info(f"Skipping classification #{clf.pk}: no category_map assigned")
+                    continue
+
+                taxon_scores = defaultdict(float)
+
+                # Aggregate each category score onto its ancestors at every rank
+                for idx, score in enumerate(clf.scores):
+                    label = clf.category_map.labels[idx]
+                    if not label:
+                        continue
+
+                    taxon = Taxon.objects.filter(name=label).first()
+                    if not taxon:
+                        self.logger.info(f"Skipping label '{label}' (no matching Taxon found)")
+                        continue
+
+                    for rank in rollup_order:
+                        ancestor = find_ancestor_by_parent_chain(taxon, rank)
+                        if ancestor:
+                            taxon_scores[ancestor] += score
+                            self.logger.debug(f"  + Added {score:.3f} to ancestor {ancestor.name} ({rank})")
+
+                new_taxon = None
+                new_score = None
+                self.logger.info(f"Aggregated taxon scores: { {t.name: s for t, s in taxon_scores.items()} }")
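+                # Worked example (made-up scores): if species A1 and A2 of genus A
+                # score 0.5 and 0.25 and species B1 of genus B scores 0.25, no
+                # species reaches the 0.8 threshold, but genus A aggregates to
+                # 0.75 >= 0.6, so the loop below rolls the classification up to
+                # genus A (using the default thresholds).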
+                for rank in rollup_order:
+                    threshold = thresholds.get(rank, 1.0)
+                    candidates = {t: s for t, s in taxon_scores.items() if t.rank.upper() == rank}
+
+                    if not candidates:
+                        self.logger.info(f"No candidates found at rank {rank}")
+                        continue
+
+                    best_taxon, best_score = max(candidates.items(), key=lambda kv: kv[1])
+                    self.logger.info(
+                        f"Best at rank {rank}: {best_taxon.name} ({best_score:.3f}) [threshold={threshold}]"
+                    )
+
+                    if best_score >= threshold:
+                        new_taxon, new_score = best_taxon, best_score
+                        self.logger.info(f"Rollup decision: {new_taxon.name} ({rank}) with score {new_score:.3f}")
+                        break
+
+                if new_taxon and new_taxon != clf.taxon:
+                    self.logger.info(f"Rolling up {clf.taxon} => {new_taxon} ({new_taxon.rank})")
+
+                    # Mark all classifications for this detection as non-terminal
+                    Classification.objects.filter(detection=clf.detection).update(terminal=False)
+                    Classification.objects.create(
+                        detection=clf.detection,
+                        taxon=new_taxon,
+                        score=new_score,
+                        terminal=True,
+                        algorithm=self.algorithm,
+                        timestamp=timezone.now(),
+                        applied_to=clf,
+                    )
+
+                    occurrence = clf.detection.occurrence
+                    updated_occurrences.append(occurrence)
+                    self.logger.info(
+                        f"Rolled up occurrence {occurrence.pk}: {clf.taxon} => {new_taxon} "
+                        f"({new_taxon.rank}) with rolled-up score={new_score:.3f}"
+                    )
+                else:
+                    self.logger.info(f"No rollup applied for classification #{clf.pk} (taxon={clf.taxon})")
+
+                # Update the progress every 10 iterations
+                if i % 10 == 0 or i == total:
+                    progress = i / total if total > 0 else 1.0
+                    self.update_progress(progress)
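+
+        # Note: the loop above runs inside a single transaction, so an exception
+        # on any classification rolls back every rollup from this run.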
+        self.logger.info(f"Rank rollup completed. Updated {len(updated_occurrences)} occurrences.")
+        self.logger.info(f"{self.name} task finished for collection {collection_id}.")
diff --git a/ami/ml/post_processing/registry.py b/ami/ml/post_processing/registry.py
index c85f607f9..308be18ae 100644
--- a/ami/ml/post_processing/registry.py
+++ b/ami/ml/post_processing/registry.py
@@ -1,8 +1,10 @@
 # Registry of available post-processing tasks
 
+from ami.ml.post_processing.class_masking import ClassMaskingTask
 from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask
 
 POSTPROCESSING_TASKS = {
     SmallSizeFilterTask.key: SmallSizeFilterTask,
+    ClassMaskingTask.key: ClassMaskingTask,
 }
diff --git a/ami/ml/tests.py b/ami/ml/tests.py
index f88bfbbc0..6966567bc 100644
--- a/ami/ml/tests.py
+++ b/ami/ml/tests.py
@@ -10,14 +10,18 @@
 from ami.main.models import (
     Classification,
     Detection,
+    Occurrence,
     Project,
     SourceImage,
     SourceImageCollection,
     Taxon,
+    TaxonRank,
     group_images_into_events,
 )
-from ami.ml.models import Algorithm, Pipeline, ProcessingService
+from ami.ml.models import Algorithm, AlgorithmCategoryMap, Pipeline, ProcessingService
+from ami.ml.models.algorithm import AlgorithmTaskType
 from ami.ml.models.pipeline import collect_images, get_or_create_algorithm_and_category_map, save_results
+from ami.ml.post_processing.rank_rollup import RankRollupTask
 from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask
 from ami.ml.schemas import (
     AlgorithmConfigResponse,
@@ -748,6 +752,13 @@ def setUp(self):
         )
         self.collection.populate_sample()
 
+        # Select example taxa
+        self.species_taxon = Taxon.objects.filter(rank=TaxonRank.SPECIES.name).first()
+        self.genus_taxon = self.species_taxon.parent if self.species_taxon else None
+        self.assertIsNotNone(self.species_taxon)
+        self.assertIsNotNone(self.genus_taxon)
+        self.algorithm = self._create_category_map_with_algorithm()
+
     def _create_images_with_dimensions(
         self,
         deployment,
@@ -813,7 +824,9 @@ def test_small_size_filter_assigns_not_identifiable(self):
                 not_identifiable_taxon,
                 f"Detection {det.pk} should be classified as 'Not identifiable'",
             )
+
             occurrence = det.occurrence
+            assert occurrence  # for type narrowing
             self.assertIsNotNone(occurrence, f"Detection {det.pk} should belong to an occurrence.")
             occurrence.refresh_from_db()
             self.assertEqual(
@@ -821,3 +834,95 @@
                 not_identifiable_taxon,
                 f"Occurrence {occurrence.pk} should have its determination set to 'Not identifiable'.",
             )
+
+    def _create_occurrences_with_classifications(self, num=3):
+        """Helper to create occurrences and terminal classifications below the species threshold."""
+        occurrences = []
+        now = datetime.datetime.now(datetime.timezone.utc)
+        for i in range(num):
+            det = Detection.objects.create(
+                source_image=self.collection.images.first(),
+                bbox=[0, 0, 200, 200],
+            )
+            occ = Occurrence.objects.create(project=self.project, event=self.deployment.events.first())
+            occ.detections.add(det)
+            classification = Classification.objects.create(
+                detection=det,
+                taxon=self.species_taxon,
+                score=0.5,
+                scores=[0.5, 0.3, 0.2],
+                terminal=True,
+                timestamp=now,
+                algorithm=self.algorithm,
+            )
+            occurrences.append((occ, classification))
+        return occurrences
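+
+    # Note: with scores [0.5, 0.3, 0.2] the best species score (0.5) is below the
+    # 0.8 species threshold used in the test below, so RankRollupTask is expected
+    # to roll these occurrences up to the genus level.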
+
+    def _create_category_map_with_algorithm(self):
+        """Create a simple AlgorithmCategoryMap and Algorithm to attach to classifications."""
+        species_taxa = list(self.project.taxa.filter(rank=TaxonRank.SPECIES.name)[:3])
+        assert species_taxa, "No species taxa found in project; run create_taxa() first."
+
+        data = [
+            {
+                "index": i,
+                "label": taxon.name,
+                "taxon_rank": taxon.rank,
+                "gbif_key": getattr(taxon, "gbif_key", None),
+            }
+            for i, taxon in enumerate(species_taxa)
+        ]
+        labels = [item["label"] for item in data]
+
+        category_map = AlgorithmCategoryMap.objects.create(
+            data=data,
+            labels=labels,
+            version="v1.0",
+            description="Species-level category map for testing RankRollupTask",
+        )
+
+        algorithm = Algorithm.objects.create(
+            name="Test Species Classifier",
+            task_type=AlgorithmTaskType.CLASSIFICATION.value,
+            category_map=category_map,
+        )
+
+        return algorithm
+
+    def test_rank_rollup_creates_new_terminal_classifications(self):
+        occurrences = self._create_occurrences_with_classifications(num=3)
+
+        task = RankRollupTask(
+            source_image_collection_id=self.collection.pk,
+            thresholds={"species": 0.8, "genus": 0.6, "family": 0.4},
+        )
+        task.run()
+
+        # Validate the results
+        for occ, original_cls in occurrences:
+            detection = occ.detections.first()
+            original_cls.refresh_from_db(fields=["terminal"])
+            rolled_up_cls = Classification.objects.filter(detection=detection, terminal=True).first()
+
+            self.assertIsNotNone(
+                rolled_up_cls,
+                f"Expected a new rolled-up classification for original #{original_cls.pk}",
+            )
+            self.assertTrue(
+                rolled_up_cls.terminal,
+                "New rolled-up classification should be marked as terminal.",
+            )
+            self.assertFalse(
+                original_cls.terminal,
+                "Original classification should be marked as non-terminal after roll-up.",
+            )
+            self.assertEqual(
+                rolled_up_cls.taxon,
+                self.genus_taxon,
+                "Rolled-up classification should have a genus-level taxon.",
+            )
+            self.assertEqual(
+                rolled_up_cls.applied_to,
+                original_cls,
+                "Rolled-up classification should reference the original classification.",
+            )