From b26a3136f12f18173896e2cdb9568e3e88200688 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 22 Jun 2026 16:59:35 -0700 Subject: [PATCH 1/8] feat(post-processing): implement class masking + rank rollup on the admin framework Rework of #999 onto the #1289 post-processing framework. The branch was cut from an old main with its own hand-rolled admin action; ClassMaskingTask and RankRollupTask now subclass BasePostProcessingTask with pydantic config schemas and are triggered through make_post_processing_action (collection scope on SourceImageCollection, single-occurrence scope on Occurrence for class masking). Correctness fixes from the review threads: - Class masking selects the top class from an -inf-masked softmax, so a class excluded by the taxa list can never win even when it had the highest logit; raises when the taxa list excludes every class in the category map. Stored logits stay raw (JSON-safe) and the mask is captured in scores (excluded -> 0). - The masked-output Algorithm is one per (source algorithm, taxa list) and its category map is persisted (previously set in memory only, so masked classifications referenced a null map). - applied_to is populated on new masked classifications (the provenance the API exposes was left blank). - Rank rollup preloads category-map labels in two queries and select_relates the per-row relations instead of dereferencing category_map per classification. Surfaces provenance in the API: applied_to is added to the Classification serializers, and applied_to__algorithm is prefetched in the occurrence list/detail prefetch and the classification viewset to avoid an N+1 on render. Tests: pydantic config validation, admin trigger for both scopes, the masking maths (including the excluded-class guarantee and the all-excluded error), ClassMaskingTask.run() end to end for both scopes, and rank rollup. 20 new tests; full post_processing suite and occurrence query-count tests pass. Co-Authored-By: Claude --- ami/main/admin.py | 30 +- ami/main/api/serializers.py | 21 ++ ami/main/api/views.py | 2 +- ami/main/models_future/occurrence.py | 5 +- .../management/commands/run_class_masking.py | 83 +++++ ami/ml/post_processing/__init__.py | 2 + .../admin/class_masking_form.py | 38 +++ .../post_processing/admin/rank_rollup_form.py | 13 + ami/ml/post_processing/class_masking.py | 267 ++++++++++++++++ ami/ml/post_processing/rank_rollup.py | 205 ++++++++++++ ami/ml/post_processing/registry.py | 4 + .../tests/test_class_masking.py | 302 ++++++++++++++++++ .../tests/test_class_masking_admin.py | 150 +++++++++ 13 files changed, 1119 insertions(+), 3 deletions(-) create mode 100644 ami/ml/management/commands/run_class_masking.py create mode 100644 ami/ml/post_processing/admin/class_masking_form.py create mode 100644 ami/ml/post_processing/admin/rank_rollup_form.py create mode 100644 ami/ml/post_processing/class_masking.py create mode 100644 ami/ml/post_processing/rank_rollup.py create mode 100644 ami/ml/post_processing/tests/test_class_masking.py create mode 100644 ami/ml/post_processing/tests/test_class_masking_admin.py diff --git a/ami/main/admin.py b/ami/main/admin.py index 404325a93..253224154 100644 --- a/ami/main/admin.py +++ b/ami/main/admin.py @@ -14,7 +14,11 @@ from ami.jobs.models import Job from ami.ml.models.project_pipeline_config import ProjectPipelineConfig from ami.ml.post_processing.admin.actions import make_post_processing_action +from ami.ml.post_processing.admin.class_masking_form import ClassMaskingActionForm +from ami.ml.post_processing.admin.rank_rollup_form import RankRollupActionForm from ami.ml.post_processing.admin.small_size_filter_form import SmallSizeFilterActionForm +from ami.ml.post_processing.class_masking import ClassMaskingTask +from ami.ml.post_processing.rank_rollup import RankRollupTask from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask from ami.ml.tasks import remove_duplicate_classifications @@ -552,6 +556,12 @@ def detections_count(self, obj) -> int: scope_resolver=lambda occurrence: {"occurrence_id": occurrence.pk}, name_resolver=lambda task_cls, occurrence: (f"Post-processing: {task_cls.name} on Occurrence {occurrence.pk}"), ) + run_class_masking = make_post_processing_action( + ClassMaskingTask, + ClassMaskingActionForm, + scope_resolver=lambda occurrence: {"occurrence_id": occurrence.pk}, + name_resolver=lambda task_cls, occurrence: (f"Post-processing: {task_cls.name} on Occurrence {occurrence.pk}"), + ) @admin.action(description="Recompute determination from current classifications and identifications") def recompute_determination(self, request: HttpRequest, queryset: QuerySet[Any]) -> None: @@ -568,7 +578,7 @@ def recompute_determination(self, request: HttpRequest, queryset: QuerySet[Any]) count += 1 self.message_user(request, f"Recomputed determination for {count} occurrence(s).") - actions = [run_small_size_filter, recompute_determination] + actions = [run_small_size_filter, run_class_masking, recompute_determination] # Order by -id (the indexed primary key) rather than -created_at, which has no # index and would force a full sort of the table to find the newest page. id @@ -850,11 +860,29 @@ def populate_collection_async(self, request: HttpRequest, queryset: QuerySet[Sou f"Post-processing: {task_cls.name} on Capture Set {collection.pk}" ), ) + run_class_masking = make_post_processing_action( + ClassMaskingTask, + ClassMaskingActionForm, + scope_resolver=lambda collection: {"source_image_collection_id": collection.pk}, + name_resolver=lambda task_cls, collection: ( + f"Post-processing: {task_cls.name} on Capture Set {collection.pk}" + ), + ) + run_rank_rollup = make_post_processing_action( + RankRollupTask, + RankRollupActionForm, + scope_resolver=lambda collection: {"source_image_collection_id": collection.pk}, + name_resolver=lambda task_cls, collection: ( + f"Post-processing: {task_cls.name} on Capture Set {collection.pk}" + ), + ) actions = [ populate_collection, populate_collection_async, run_small_size_filter, + run_class_masking, + run_rank_rollup, ] # Hide images many-to-many field from form. This would list all source images in the database. diff --git a/ami/main/api/serializers.py b/ami/main/api/serializers.py index 2855633e3..6c899c2f7 100644 --- a/ami/main/api/serializers.py +++ b/ami/main/api/serializers.py @@ -941,10 +941,26 @@ class ClassificationPredictionItemSerializer(serializers.Serializer): logit = serializers.FloatField(read_only=True) +class ClassificationAppliedToSerializer(serializers.ModelSerializer): + """Lightweight nested representation of the parent classification this was derived from. + + Post-processing tasks (class masking, rank rollup) record provenance via + ``Classification.applied_to``; this exposes just enough to show what a result + was derived from without recursing back into the full classification. + """ + + algorithm = AlgorithmSerializer(read_only=True) + + class Meta: + model = Classification + fields = ["id", "created_at", "algorithm"] + + class ClassificationSerializer(DefaultSerializer): taxon = TaxonNestedSerializer(read_only=True) algorithm = AlgorithmSerializer(read_only=True) top_n = ClassificationPredictionItemSerializer(many=True, read_only=True) + applied_to = ClassificationAppliedToSerializer(read_only=True) class Meta: model = Classification @@ -957,6 +973,7 @@ class Meta: "scores", "logits", "top_n", + "applied_to", "created_at", "updated_at", ] @@ -979,6 +996,8 @@ class Meta(ClassificationSerializer.Meta): class ClassificationListSerializer(DefaultSerializer): + applied_to = ClassificationAppliedToSerializer(read_only=True) + class Meta: model = Classification fields = [ @@ -987,6 +1006,7 @@ class Meta: "taxon", "score", "algorithm", + "applied_to", "created_at", "updated_at", ] @@ -1006,6 +1026,7 @@ class Meta: "score", "terminal", "algorithm", + "applied_to", "created_at", ] diff --git a/ami/main/api/views.py b/ami/main/api/views.py index 6ad39a2e5..591a4a000 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -2060,7 +2060,7 @@ class ClassificationViewSet(DefaultViewSet, ProjectMixin): """ require_project_for_list = True # Unfiltered list scans are too expensive on this table - queryset = Classification.objects.all().select_related("taxon", "algorithm") # , "detection") + queryset = Classification.objects.all().select_related("taxon", "algorithm", "applied_to__algorithm") serializer_class = ClassificationSerializer filterset_fields = [ # Docs about slow loading API browser because of large choice fields diff --git a/ami/main/models_future/occurrence.py b/ami/main/models_future/occurrence.py index 012554169..ad5cc0ac6 100644 --- a/ami/main/models_future/occurrence.py +++ b/ami/main/models_future/occurrence.py @@ -58,7 +58,10 @@ def _detections_prefetch(*, ordering: tuple[str, ...], with_source_image: bool) qs = Detection.objects.prefetch_related( Prefetch( "classifications", - queryset=Classification.objects.select_related("taxon", "algorithm"), + # applied_to__algorithm: post-processed classifications (class masking, + # rank rollup) serialize their provenance parent; pull it here so the + # nested applied_to render doesn't issue a query per classification. + queryset=Classification.objects.select_related("taxon", "algorithm", "applied_to__algorithm"), ) ).order_by(*ordering) if with_source_image: diff --git a/ami/ml/management/commands/run_class_masking.py b/ami/ml/management/commands/run_class_masking.py new file mode 100644 index 000000000..a99c0cb41 --- /dev/null +++ b/ami/ml/management/commands/run_class_masking.py @@ -0,0 +1,83 @@ +from django.core.management.base import BaseCommand, CommandError + +from ami.main.models import SourceImageCollection, TaxaList +from ami.ml.models.algorithm import Algorithm +from ami.ml.post_processing.class_masking import ClassMaskingTask + + +class Command(BaseCommand): + help = ( + "Run class masking post-processing on a source image collection. " + "Masks classifier logits for species not in the given taxa list and recalculates softmax scores." + ) + + def add_arguments(self, parser): + parser.add_argument("--collection-id", type=int, required=True, help="SourceImageCollection ID to process") + parser.add_argument("--taxa-list-id", type=int, required=True, help="TaxaList ID to use as the species mask") + parser.add_argument( + "--algorithm-id", type=int, required=True, help="Algorithm ID whose classifications to mask" + ) + parser.add_argument("--dry-run", action="store_true", help="Show what would be done without making changes") + + def handle(self, *args, **options): + collection_id = options["collection_id"] + taxa_list_id = options["taxa_list_id"] + algorithm_id = options["algorithm_id"] + dry_run = options["dry_run"] + + # Validate inputs + try: + collection = SourceImageCollection.objects.get(pk=collection_id) + except SourceImageCollection.DoesNotExist: + raise CommandError(f"SourceImageCollection {collection_id} does not exist.") + + try: + taxa_list = TaxaList.objects.get(pk=taxa_list_id) + except TaxaList.DoesNotExist: + raise CommandError(f"TaxaList {taxa_list_id} does not exist.") + + try: + algorithm = Algorithm.objects.get(pk=algorithm_id) + except Algorithm.DoesNotExist: + raise CommandError(f"Algorithm {algorithm_id} does not exist.") + + if not algorithm.category_map: + raise CommandError(f"Algorithm '{algorithm.name}' does not have a category map.") + + from ami.main.models import Classification + + classification_count = ( + Classification.objects.filter( + detection__source_image__collections=collection, + terminal=True, + algorithm=algorithm, + scores__isnull=False, + ) + .distinct() + .count() + ) + + taxa_count = taxa_list.taxa.count() + + self.stdout.write( + f"Collection: {collection.name} (id={collection.pk})\n" + f"Taxa list: {taxa_list.name} (id={taxa_list.pk}, {taxa_count} taxa)\n" + f"Algorithm: {algorithm.name} (id={algorithm.pk})\n" + f"Classifications to process: {classification_count}" + ) + + if classification_count == 0: + raise CommandError("No terminal classifications with scores found for this collection/algorithm.") + + if dry_run: + self.stdout.write(self.style.WARNING("Dry run — no changes made.")) + return + + self.stdout.write("Running class masking...") + task = ClassMaskingTask( + source_image_collection_id=collection_id, + taxa_list_id=taxa_list_id, + algorithm_id=algorithm_id, + ) + task.run() + self.stdout.write(self.style.SUCCESS("Class masking completed.")) diff --git a/ami/ml/post_processing/__init__.py b/ami/ml/post_processing/__init__.py index 3517ed47c..8837973c4 100644 --- a/ami/ml/post_processing/__init__.py +++ b/ami/ml/post_processing/__init__.py @@ -1 +1,3 @@ +from . import class_masking # noqa: F401 +from . import rank_rollup # noqa: F401 from . import small_size_filter # noqa: F401 diff --git a/ami/ml/post_processing/admin/class_masking_form.py b/ami/ml/post_processing/admin/class_masking_form.py new file mode 100644 index 000000000..a7077e122 --- /dev/null +++ b/ami/ml/post_processing/admin/class_masking_form.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from django import forms + +from ami.main.models import TaxaList +from ami.ml.models import Algorithm +from ami.ml.models.algorithm import AlgorithmTaskType +from ami.ml.post_processing.admin.forms import BasePostProcessingActionForm + + +class ClassMaskingActionForm(BasePostProcessingActionForm): + """Knobs surfaced when an admin triggers Class masking. + + The operator picks the source classifier and the taxa list to keep; the + scope (which collection or occurrence) is supplied by the admin entry point, + not the form. Selections are model instances, so ``to_config`` hands the + schema their primary keys (``ClassMaskingConfig`` expects ``*_id`` ints). + """ + + algorithm_id = forms.ModelChoiceField( + queryset=Algorithm.objects.filter(task_type=AlgorithmTaskType.CLASSIFICATION.value).order_by("name"), + label="Source classifier", + help_text="The classification algorithm whose terminal predictions will be re-scored.", + ) + taxa_list_id = forms.ModelChoiceField( + queryset=TaxaList.objects.all().order_by("name"), + label="Taxa list to keep", + help_text=( + "Classes whose taxon is not in this list are masked out; each " + "classification's softmax is renormalised over the classes that remain." + ), + ) + + def to_config(self) -> dict: + return { + "algorithm_id": self.cleaned_data["algorithm_id"].pk, + "taxa_list_id": self.cleaned_data["taxa_list_id"].pk, + } diff --git a/ami/ml/post_processing/admin/rank_rollup_form.py b/ami/ml/post_processing/admin/rank_rollup_form.py new file mode 100644 index 000000000..1d072a9b6 --- /dev/null +++ b/ami/ml/post_processing/admin/rank_rollup_form.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from ami.ml.post_processing.admin.forms import BasePostProcessingActionForm + + +class RankRollupActionForm(BasePostProcessingActionForm): + """Knob form for Rank rollup. + + Rank rollup runs with the per-rank score thresholds and rollup order defined + on ``RankRollupConfig``. There are no per-run knobs yet, so the form only + confirms the selected capture set(s); the empty ``cleaned_data`` lets the + schema apply its defaults. Threshold overrides can be added here later. + """ diff --git a/ami/ml/post_processing/class_masking.py b/ami/ml/post_processing/class_masking.py new file mode 100644 index 000000000..119cda191 --- /dev/null +++ b/ami/ml/post_processing/class_masking.py @@ -0,0 +1,267 @@ +import logging +from collections.abc import Callable + +import numpy as np +import pydantic +from django.db import transaction +from django.db.models import QuerySet +from django.utils import timezone + +from ami.main.models import Classification, Occurrence, SourceImageCollection, TaxaList +from ami.ml.models.algorithm import Algorithm, AlgorithmTaskType +from ami.ml.post_processing.base import BasePostProcessingTask + +logger = logging.getLogger(__name__) + + +class ClassMaskingConfig(pydantic.BaseModel): + # Scope: exactly one of these identifies which classifications to re-score. A + # capture set is the bulk path; a single occurrence is the spot/dev path (fast + # feedback while tuning a taxa list). This mirrors SmallSizeFilterConfig's + # discriminated-scope shape — the shared pattern for per-occurrence triggers. + source_image_collection_id: int | None = None + occurrence_id: int | None = None + # The taxa list to keep: classes whose taxon is not in this list are masked out. + taxa_list_id: int + # The source classifier whose terminal classifications are re-scored. + algorithm_id: int + + @pydantic.root_validator(skip_on_failure=True) + def _exactly_one_scope(cls, values: dict) -> dict: + scopes = [values.get("source_image_collection_id"), values.get("occurrence_id")] + if sum(s is not None for s in scopes) != 1: + raise ValueError("Provide exactly one of source_image_collection_id or occurrence_id") + return values + + class Config: + extra = "forbid" + + +def make_classifications_filtered_by_taxa_list( + classifications: QuerySet[Classification], + taxa_list: TaxaList, + algorithm: Algorithm, + new_algorithm: Algorithm, + *, + task_logger: logging.Logger = logger, + progress_callback: Callable[[int, int], None] | None = None, +) -> dict[str, int]: + """Re-score ``classifications`` by masking out classes absent from ``taxa_list``. + + For each terminal classification produced by ``algorithm``, the logits of + classes whose taxon is not in ``taxa_list`` are masked, the softmax is + renormalised over the remaining classes, and a new terminal classification + (attributed to ``new_algorithm``, linked back via ``applied_to``) records the + masked prediction. The original classification is demoted to non-terminal. + + Returns counters (checked / masked / occurrences updated) for stage metrics. + """ + taxa_in_list = set(taxa_list.taxa.all()) + + total = classifications.count() + task_logger.info(f"Found {total} terminal classifications with scores to re-score.") + + if not algorithm.category_map: + raise ValueError(f"Algorithm {algorithm} does not have a category map.") + category_map = algorithm.category_map + + # Resolve each category's taxon once. Indices absent from this map, or whose + # taxon is not in the taxa list, are masked. Building included from the taxa + # list (rather than excluded from the map) means a class with no resolvable + # taxon is masked too, never silently kept. + task_logger.info(f"Retrieving category map with Taxa instances for algorithm {algorithm}") + category_map_with_taxa = category_map.with_taxa() + index_to_taxon = {int(category["index"]): category["taxon"] for category in category_map_with_taxa} + num_categories = len(category_map.labels) + included_indices = [i for i in range(num_categories) if index_to_taxon.get(i) in taxa_in_list] + excluded_indices = [i for i in range(num_categories) if i not in set(included_indices)] + + if not included_indices: + raise ValueError( + f"Taxa list '{taxa_list.name}' excludes every class in algorithm '{algorithm.name}'s " + "category map; there is nothing to keep." + ) + + task_logger.info( + f"Category map has {num_categories} classes, " + f"{len(excluded_indices)} masked, {len(included_indices)} kept, " + f"{total} classifications to check" + ) + + classifications_to_demote: list[Classification] = [] + classifications_to_add: list[Classification] = [] + occurrences_to_update: set[Occurrence] = set() + + timestamp = timezone.now() + masked_count = 0 + for i, classification in enumerate(classifications.iterator(), start=1): + scores, logits = classification.scores, classification.logits + if not isinstance(logits, list) or not all(isinstance(x, (int, float)) for x in logits): + raise ValueError(f"Logits for classification {classification.pk} are not a list of numbers: {logits}") + if len(logits) != num_categories: + task_logger.warning( + f"Classification {classification.pk}: {len(logits)} logits != {num_categories} categories; skipping" + ) + continue + + # Mask excluded classes with -inf on a working copy so the renormalised + # softmax assigns them exactly zero probability — an excluded class can + # never win argmax. (-inf is compute-only; it is never stored, since it + # is not valid JSON. The stored vectors stay finite: see below.) + working = np.asarray(logits, dtype=float) + working[excluded_indices] = -np.inf + working -= working.max() # max is over kept classes (finite); stabilises exp + exp = np.exp(working) # exp(-inf) == 0 for masked classes + new_scores_np = exp / exp.sum() # sum > 0: at least one class is kept + top_index = int(np.argmax(new_scores_np)) + new_scores = new_scores_np.tolist() + + # No-change short-circuit: if masking shifted no probability (the classes + # this taxa list drops carried ~zero score here), leave the row untouched. + if isinstance(scores, list) and np.allclose(scores, new_scores, atol=1e-9): + task_logger.debug(f"Classification {classification.pk} unchanged by masking; skipping") + continue + + top_taxon = index_to_taxon.get(top_index) # guaranteed in taxa_in_list (top_index is kept) + + classification.terminal = False + classification.updated_at = timestamp + + new_classification = Classification( + detection=classification.detection, + taxon=top_taxon, + algorithm=new_algorithm, + category_map=new_algorithm.category_map, + score=float(new_scores_np[top_index]), + scores=new_scores, + # Store the raw logits unchanged (JSON-safe): the mask is fully captured + # by ``scores`` (dropped classes -> 0) and the ``applied_to`` lineage. + logits=logits, + terminal=True, + timestamp=classification.timestamp, + applied_to=classification, + created_at=timestamp, + updated_at=timestamp, + ) + classifications_to_demote.append(classification) + classifications_to_add.append(new_classification) + masked_count += 1 + + detection = classification.detection + if detection is not None and detection.occurrence is not None: + occurrences_to_update.add(detection.occurrence) + + if progress_callback is not None and (i % 100 == 0 or i == total): + progress_callback(i, total) + + with transaction.atomic(): + if classifications_to_demote: + Classification.objects.bulk_update(classifications_to_demote, ["terminal", "updated_at"]) + if classifications_to_add: + Classification.objects.bulk_create(classifications_to_add) + # Recompute each affected occurrence's determination from its new terminal + # classification. + for occurrence in occurrences_to_update: + occurrence.save(update_determination=True) + + task_logger.info( + f"Re-scored {masked_count} of {total} classifications; updated {len(occurrences_to_update)} occurrences." + ) + return { + "classifications_checked": total, + "classifications_masked": masked_count, + "occurrences_updated": len(occurrences_to_update), + } + + +class ClassMaskingTask(BasePostProcessingTask): + key = "class_masking" + name = "Class masking" + config_schema = ClassMaskingConfig + + def _get_or_create_masking_algorithm(self, source_algorithm: Algorithm, taxa_list: TaxaList) -> Algorithm: + """Get or create the output algorithm for this (source algorithm, taxa list). + + One masking algorithm per pair keeps provenance reproducible: re-running + the same mask reuses the same Algorithm row. Its category map is the + source map (indices still align with the masked score vector) and is + persisted — earlier code set it in memory only, so masked classifications + referenced a null map. + """ + algorithm, created = Algorithm.objects.get_or_create( + key=f"{source_algorithm.key}_filtered_by_taxa_list_{taxa_list.pk}", + defaults={ + "name": f"{source_algorithm.name} (filtered by taxa list {taxa_list.name})", + "description": ( + f"Classifications from {source_algorithm.name} re-scored against taxa list {taxa_list.name}" + ), + "task_type": AlgorithmTaskType.CLASSIFICATION.value, + "category_map": source_algorithm.category_map, + }, + ) + if not created and algorithm.category_map_id != source_algorithm.category_map_id: + algorithm.category_map = source_algorithm.category_map + algorithm.save(update_fields=["category_map"]) + return algorithm + + def _scoped_classifications( + self, config: ClassMaskingConfig, source_algorithm: Algorithm + ) -> tuple[QuerySet[Classification], str]: + """Resolve the terminal classifications to re-score from the config's scope. + + ``config_schema`` guarantees exactly one scope id is set, so the single + ``else`` branch is sound. + """ + base = Classification.objects.filter( + terminal=True, + algorithm=source_algorithm, + scores__isnull=False, + logits__isnull=False, + ).select_related("detection", "detection__occurrence") + + if config.occurrence_id is not None: + if not Occurrence.objects.filter(pk=config.occurrence_id).exists(): + raise ValueError(f"Occurrence {config.occurrence_id} not found") + return ( + base.filter(detection__occurrence_id=config.occurrence_id).distinct(), + f"occurrence {config.occurrence_id}", + ) + + try: + collection = SourceImageCollection.objects.get(pk=config.source_image_collection_id) + except SourceImageCollection.DoesNotExist: + raise ValueError(f"SourceImageCollection {config.source_image_collection_id} not found") + return ( + base.filter(detection__source_image__collections=collection).distinct(), + f"collection {collection.pk}", + ) + + def run(self) -> None: + config: ClassMaskingConfig = self.config # type: ignore[assignment] + self.logger.info(f"=== Starting {self.name} ===") + + try: + source_algorithm = Algorithm.objects.get(pk=config.algorithm_id) + except Algorithm.DoesNotExist: + raise ValueError(f"Algorithm {config.algorithm_id} not found") + try: + taxa_list = TaxaList.objects.get(pk=config.taxa_list_id) + except TaxaList.DoesNotExist: + raise ValueError(f"TaxaList {config.taxa_list_id} not found") + if not source_algorithm.category_map: + raise ValueError(f"Algorithm '{source_algorithm.name}' has no category map; cannot mask classes.") + + masking_algorithm = self._get_or_create_masking_algorithm(source_algorithm, taxa_list) + classifications, scope_desc = self._scoped_classifications(config, source_algorithm) + self.logger.info(f"Applying class masking on {scope_desc} using taxa list {taxa_list.pk}") + + metrics = make_classifications_filtered_by_taxa_list( + classifications=classifications, + taxa_list=taxa_list, + algorithm=source_algorithm, + new_algorithm=masking_algorithm, + task_logger=self.logger, + progress_callback=lambda i, total: self.update_progress(i / total if total else 1.0), + ) + self.report_stage_metrics(metrics) + self.logger.info(f"=== Completed {self.name} ===") diff --git a/ami/ml/post_processing/rank_rollup.py b/ami/ml/post_processing/rank_rollup.py new file mode 100644 index 000000000..7e8e2764e --- /dev/null +++ b/ami/ml/post_processing/rank_rollup.py @@ -0,0 +1,205 @@ +import logging +from collections import defaultdict + +import pydantic +from django.db import transaction +from django.utils import timezone + +from ami.main.models import Classification, Taxon +from ami.ml.models.algorithm import AlgorithmCategoryMap +from ami.ml.post_processing.base import BasePostProcessingTask + +logger = logging.getLogger(__name__) + +DEFAULT_THRESHOLDS = {"SPECIES": 0.8, "GENUS": 0.6, "FAMILY": 0.4} +DEFAULT_ROLLUP_ORDER = ["SPECIES", "GENUS", "FAMILY"] + + +class RankRollupConfig(pydantic.BaseModel): + source_image_collection_id: int + # Minimum aggregated score required to roll a classification up to each rank. + thresholds: dict[str, float] = DEFAULT_THRESHOLDS + # Ranks to try, finest first; the first whose aggregated score clears its + # threshold wins. + rollup_order: list[str] = DEFAULT_ROLLUP_ORDER + + @pydantic.validator("thresholds") + def _validate_thresholds(cls, value: dict[str, float]) -> dict[str, float]: + normalised: dict[str, float] = {} + for rank, threshold in value.items(): + if not (0.0 < threshold <= 1.0): + raise ValueError(f"Threshold for {rank} must be in (0, 1]") + normalised[rank.upper()] = threshold + return normalised + + @pydantic.validator("rollup_order") + def _uppercase_order(cls, value: list[str]) -> list[str]: + return [rank.upper() for rank in value] + + class Config: + extra = "forbid" + + +def find_ancestor_by_parent_chain(taxon, target_rank: str): + """Climb up parent relationships until a taxon with the target rank is found.""" + if not taxon: + return None + + target_rank = target_rank.upper() + + current = taxon + while current: + if current.rank.upper() == target_rank: + return current + current = current.parent + + return None + + +class RankRollupTask(BasePostProcessingTask): + """Post-processing task that rolls up low-confidence classifications + to higher ranks using aggregated scores. + """ + + key = "rank_rollup" + name = "Rank rollup" + config_schema = RankRollupConfig + + def run(self) -> None: + config: RankRollupConfig = self.config # type: ignore[assignment] + job = self.job + self.logger.info(f"Starting {self.name} task for job {job.pk if job else 'N/A'}") + + collection_id = config.source_image_collection_id + thresholds = config.thresholds + rollup_order = config.rollup_order + + self.logger.info( + f"Config loaded: collection_id={collection_id}, thresholds={thresholds}, rollup_order={rollup_order}" + ) + + # select_related the per-row relations the loop touches (category_map for + # labels, detection/occurrence for the rollup write) so the body issues no + # per-classification queries. + qs = ( + Classification.objects.filter( + terminal=True, + taxon__isnull=False, + detection__source_image__collections__id=collection_id, + ) + .select_related("category_map", "taxon", "detection", "detection__occurrence") + .distinct() + ) + + total = qs.count() + self.logger.info(f"Found {total} terminal classifications to process for collection {collection_id}") + + # Pre-load every label across the distinct category maps in one pass (two + # queries total), instead of dereferencing clf.category_map per row. + category_map_ids = list(qs.values_list("category_map_id", flat=True).distinct()) + all_labels: set[str] = set() + for category_map in AlgorithmCategoryMap.objects.filter(pk__in=category_map_ids): + if category_map.labels: + all_labels.update(label for label in category_map.labels if label) + + label_to_taxon = {} + if all_labels: + for taxon in Taxon.objects.filter(name__in=all_labels).select_related("parent"): + label_to_taxon[taxon.name] = taxon + self.logger.info(f"Pre-loaded {len(label_to_taxon)} taxa from {len(all_labels)} unique labels") + + updated_occurrences = [] + + with transaction.atomic(): + for i, clf in enumerate(qs.iterator(), start=1): + score_str = f"{clf.score:.3f}" if clf.score is not None else "N/A" + self.logger.info(f"Processing classification #{clf.pk} (taxon={clf.taxon}, score={score_str})") + + if not clf.scores: + self.logger.info(f"Skipping classification #{clf.pk}: no scores available") + continue + if not clf.category_map: + self.logger.info(f"Skipping classification #{clf.pk}: no category_map assigned") + continue + + taxon_scores = defaultdict(float) + + for idx, score in enumerate(clf.scores): + label = clf.category_map.labels[idx] + if not label: + continue + + taxon = label_to_taxon.get(label) + if not taxon: + self.logger.debug(f"Skipping label '{label}' (no matching Taxon found)") + continue + + for rank in rollup_order: + ancestor = find_ancestor_by_parent_chain(taxon, rank) + if ancestor: + taxon_scores[ancestor] += score + self.logger.debug(f" + Added {score:.3f} to ancestor {ancestor.name} ({rank})") + + new_taxon = None + new_score = None + scores_str = {t.name: s for t, s in taxon_scores.items()} + self.logger.info(f"Aggregated taxon scores: {scores_str}") + for rank in rollup_order: + threshold = thresholds.get(rank, 1.0) + candidates = {t: s for t, s in taxon_scores.items() if t.rank == rank} + + if not candidates: + self.logger.info(f"No candidates found at rank {rank}") + continue + + best_taxon, best_score = max(candidates.items(), key=lambda kv: kv[1]) + self.logger.info( + f"Best at rank {rank}: {best_taxon.name} ({best_score:.3f}) [threshold={threshold}]" + ) + + if best_score >= threshold: + new_taxon, new_score = best_taxon, best_score + self.logger.info(f"Rollup decision: {new_taxon.name} ({rank}) with score {new_score:.3f}") + break + + if new_taxon and new_taxon != clf.taxon: + self.logger.info(f"Rolling up {clf.taxon} => {new_taxon} ({new_taxon.rank})") + + # Mark all classifications for this detection as non-terminal + Classification.objects.filter(detection=clf.detection).update(terminal=False) + Classification.objects.create( + detection=clf.detection, + taxon=new_taxon, + score=new_score, + terminal=True, + algorithm=self.algorithm, + timestamp=timezone.now(), + applied_to=clf, + ) + + occurrence = clf.detection.occurrence + if occurrence: + occurrence.save(update_determination=True) + updated_occurrences.append(occurrence) + self.logger.info( + f"Rolled up occurrence {occurrence.pk}: {clf.taxon} => {new_taxon} " + f"({new_taxon.rank}) with rolled-up score={new_score:.3f}" + ) + else: + self.logger.warning(f"Detection #{clf.detection.pk} has no occurrence; skipping.") + else: + self.logger.info(f"No rollup applied for classification #{clf.pk} (taxon={clf.taxon})") + + # Update progress every 10 iterations + if i % 10 == 0 or i == total: + progress = i / total if total > 0 else 1.0 + self.update_progress(progress) + + self.report_stage_metrics( + { + "classifications_checked": total, + "occurrences_rolled_up": len(updated_occurrences), + } + ) + self.logger.info(f"Rank rollup completed. Updated {len(updated_occurrences)} occurrences.") + self.logger.info(f"{self.name} task finished for collection {collection_id}.") diff --git a/ami/ml/post_processing/registry.py b/ami/ml/post_processing/registry.py index c85f607f9..28fa7fb2f 100644 --- a/ami/ml/post_processing/registry.py +++ b/ami/ml/post_processing/registry.py @@ -1,8 +1,12 @@ # Registry of available post-processing tasks +from ami.ml.post_processing.class_masking import ClassMaskingTask +from ami.ml.post_processing.rank_rollup import RankRollupTask from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask POSTPROCESSING_TASKS = { SmallSizeFilterTask.key: SmallSizeFilterTask, + ClassMaskingTask.key: ClassMaskingTask, + RankRollupTask.key: RankRollupTask, } diff --git a/ami/ml/post_processing/tests/test_class_masking.py b/ami/ml/post_processing/tests/test_class_masking.py new file mode 100644 index 000000000..4503a3d49 --- /dev/null +++ b/ami/ml/post_processing/tests/test_class_masking.py @@ -0,0 +1,302 @@ +"""Domain tests for class masking and rank rollup post-processing tasks. + +Class masking re-scores a classifier's terminal predictions against a taxa list: +classes whose taxon is not in the list are masked, the softmax is renormalised +over the rest, and a new terminal classification (linked back via ``applied_to``) +records the masked result. These tests cover the masking maths (including that an +excluded class can never win even when it had the highest logit), the provenance +link, the persisted output algorithm, and both admin scopes (collection / single +occurrence) end to end through ``ClassMaskingTask.run()``. +""" +import datetime +import math +import pathlib +import uuid + +from django.test import TestCase + +from ami.main.models import ( + Classification, + Detection, + Occurrence, + SourceImage, + SourceImageCollection, + TaxaList, + Taxon, + TaxonRank, + group_images_into_events, +) +from ami.ml.models import Algorithm, AlgorithmCategoryMap +from ami.ml.models.algorithm import AlgorithmTaskType +from ami.ml.post_processing.class_masking import ClassMaskingTask, make_classifications_filtered_by_taxa_list +from ami.ml.post_processing.rank_rollup import RankRollupTask +from ami.tests.fixtures.main import create_taxa, setup_test_project + + +def _softmax(logits: list[float]) -> list[float]: + shifted = [x - max(logits) for x in logits] + exp = [math.exp(x) for x in shifted] + total = sum(exp) + return [e / total for e in exp] + + +class TestPostProcessingClassMasking(TestCase): + def setUp(self): + self.project, self.deployment = setup_test_project() + create_taxa(project=self.project) + self._create_images_with_dimensions(deployment=self.deployment) + group_images_into_events(deployment=self.deployment) + + self.collection = SourceImageCollection.objects.create( + name="Test PostProcessing Collection", + project=self.project, + method="manual", + kwargs={"image_ids": list(self.deployment.captures.values_list("pk", flat=True))}, + ) + self.collection.populate_sample() + + self.species_taxon = Taxon.objects.filter(rank=TaxonRank.SPECIES.name).first() + self.genus_taxon = self.species_taxon.parent if self.species_taxon else None + self.assertIsNotNone(self.species_taxon) + self.assertIsNotNone(self.genus_taxon) + self.algorithm = self._create_category_map_with_algorithm() + self.species_taxa = list(self.project.taxa.filter(rank=TaxonRank.SPECIES.name).order_by("name")[:3]) + + # ----- fixtures ------------------------------------------------------- + + def _create_images_with_dimensions(self, deployment, num_images=5, width=640, height=480): + base_time = datetime.datetime.now(datetime.timezone.utc) + for i in range(num_images): + path = pathlib.Path("test") / f"{uuid.uuid4().hex[:8]}_{i}.jpg" + SourceImage.objects.create( + deployment=deployment, + project=deployment.project, + timestamp=base_time + datetime.timedelta(minutes=i * 5), + path=path, + width=width, + height=height, + ) + deployment.save(update_calculated_fields=True, regroup_async=False) + + def _create_category_map_with_algorithm(self) -> Algorithm: + species_taxa = list(self.project.taxa.filter(rank=TaxonRank.SPECIES.name).order_by("name")[:3]) + assert species_taxa, "No species taxa found in project; run create_taxa() first." + data = [{"index": i, "label": taxon.name} for i, taxon in enumerate(species_taxa)] + category_map = AlgorithmCategoryMap.objects.create( + data=data, + labels=[item["label"] for item in data], + version="v1.0", + description="Species-level category map for testing", + ) + return Algorithm.objects.create( + name="Test Species Classifier", + task_type=AlgorithmTaskType.CLASSIFICATION.value, + category_map=category_map, + ) + + def _create_classification_with_logits(self, detection, taxon, scores, logits) -> Classification: + return Classification.objects.create( + detection=detection, + taxon=taxon, + score=max(scores), + scores=scores, + logits=logits, + terminal=True, + timestamp=datetime.datetime.now(datetime.timezone.utc), + algorithm=self.algorithm, + ) + + def _detection_with_occurrence(self) -> tuple[Detection, Occurrence]: + det = Detection.objects.create(source_image=self.collection.images.first(), bbox=[0, 0, 200, 200]) + occ = Occurrence.objects.create(project=self.project, event=self.deployment.events.first()) + occ.detections.add(det) + return det, occ + + # ----- make_classifications_filtered_by_taxa_list --------------------- + + def test_excluded_class_never_wins_even_with_highest_logit(self): + """The core guarantee: a masked class cannot be selected, even if it had + the single highest logit before masking.""" + # index 2 (excluded) has the highest logit, so it is the original top. + logits = [2.0, 1.0, 5.0] + scores = _softmax(logits) + self.assertEqual(scores.index(max(scores)), 2) + + taxa_list = TaxaList.objects.create(name="Keep first two") + taxa_list.taxa.set(self.species_taxa[:2]) # excludes species_taxa[2] + + det, _ = self._detection_with_occurrence() + original = self._create_classification_with_logits(det, self.species_taxa[2], scores, logits) + + new_algorithm = Algorithm.objects.create( + name="masked", + key="masked_test", + task_type=AlgorithmTaskType.CLASSIFICATION.value, + category_map=self.algorithm.category_map, + ) + metrics = make_classifications_filtered_by_taxa_list( + classifications=Classification.objects.filter(pk=original.pk), + taxa_list=taxa_list, + algorithm=self.algorithm, + new_algorithm=new_algorithm, + ) + + original.refresh_from_db() + self.assertFalse(original.terminal, "Source classification is demoted to non-terminal") + new_clf = Classification.objects.get(detection=det, terminal=True) + # Highest logit among the kept classes is index 0. + self.assertEqual(new_clf.taxon, self.species_taxa[0]) + self.assertAlmostEqual(new_clf.scores[2], 0.0, places=10, msg="Masked class score is exactly zero") + self.assertAlmostEqual(sum(new_clf.scores), 1.0, places=5) + self.assertEqual(new_clf.applied_to, original, "Provenance links back to the source classification") + self.assertEqual(metrics["classifications_masked"], 1) + + def test_single_allowed_class_gets_all_probability(self): + logits = [2.0, 3.0, 4.0] + taxa_list = TaxaList.objects.create(name="Keep one") + taxa_list.taxa.set([self.species_taxa[0]]) + + det, _ = self._detection_with_occurrence() + self._create_classification_with_logits(det, self.species_taxa[2], _softmax(logits), logits) + + new_algorithm = Algorithm.objects.create( + name="masked2", + key="masked_test2", + task_type=AlgorithmTaskType.CLASSIFICATION.value, + category_map=self.algorithm.category_map, + ) + make_classifications_filtered_by_taxa_list( + classifications=Classification.objects.filter(detection=det, terminal=True), + taxa_list=taxa_list, + algorithm=self.algorithm, + new_algorithm=new_algorithm, + ) + new_clf = Classification.objects.get(detection=det, terminal=True) + self.assertAlmostEqual(new_clf.scores[0], 1.0, places=5) + self.assertAlmostEqual(new_clf.scores[1], 0.0, places=10) + self.assertAlmostEqual(new_clf.scores[2], 0.0, places=10) + + def test_no_change_when_all_classes_in_list(self): + logits = [3.0, 1.0, 0.5] + taxa_list = TaxaList.objects.create(name="Keep all") + taxa_list.taxa.set(self.species_taxa) + + det, _ = self._detection_with_occurrence() + original = self._create_classification_with_logits(det, self.species_taxa[0], _softmax(logits), logits) + + new_algorithm = Algorithm.objects.create( + name="masked3", + key="masked_test3", + task_type=AlgorithmTaskType.CLASSIFICATION.value, + category_map=self.algorithm.category_map, + ) + make_classifications_filtered_by_taxa_list( + classifications=Classification.objects.filter(pk=original.pk), + taxa_list=taxa_list, + algorithm=self.algorithm, + new_algorithm=new_algorithm, + ) + original.refresh_from_db() + self.assertTrue(original.terminal, "Nothing masked, so the source stays terminal") + self.assertEqual(Classification.objects.filter(detection=det).count(), 1, "No new classification created") + + def test_all_classes_excluded_raises(self): + # A taxa list sharing nothing with the category map leaves no class to keep. + taxa_list = TaxaList.objects.create(name="Unrelated") + taxa_list.taxa.set([self.genus_taxon]) # genus name is not a category-map label + + det, _ = self._detection_with_occurrence() + logits = [2.0, 1.0, 5.0] + self._create_classification_with_logits(det, self.species_taxa[2], _softmax(logits), logits) + + new_algorithm = Algorithm.objects.create( + name="masked4", + key="masked_test4", + task_type=AlgorithmTaskType.CLASSIFICATION.value, + category_map=self.algorithm.category_map, + ) + with self.assertRaises(ValueError): + make_classifications_filtered_by_taxa_list( + classifications=Classification.objects.filter(detection=det, terminal=True), + taxa_list=taxa_list, + algorithm=self.algorithm, + new_algorithm=new_algorithm, + ) + + # ----- ClassMaskingTask.run() end to end ------------------------------ + + def test_task_run_collection_scope_persists_masking_algorithm(self): + logits = [0.5, 3.0, 3.5] # excluded index 2 is top; index 1 is the in-list winner + taxa_list = TaxaList.objects.create(name="Regional list") + taxa_list.taxa.set(self.species_taxa[:2]) + + det, occ = self._detection_with_occurrence() + original = self._create_classification_with_logits(det, self.species_taxa[2], _softmax(logits), logits) + + ClassMaskingTask( + source_image_collection_id=self.collection.pk, + taxa_list_id=taxa_list.pk, + algorithm_id=self.algorithm.pk, + ).run() + + # The per-(source algorithm, taxa list) masking algorithm exists and kept + # its category map (the bug being guarded: it used to be set in memory only). + masking_algo = Algorithm.objects.get(key=f"{self.algorithm.key}_filtered_by_taxa_list_{taxa_list.pk}") + self.assertIsNotNone(masking_algo.category_map_id) + self.assertEqual(masking_algo.category_map_id, self.algorithm.category_map_id) + + new_clf = Classification.objects.get(detection=det, terminal=True, algorithm=masking_algo) + self.assertEqual(new_clf.taxon, self.species_taxa[1]) + self.assertEqual(new_clf.applied_to, original) + occ.refresh_from_db() + self.assertEqual(occ.determination, self.species_taxa[1], "Occurrence determination follows the masked result") + + def test_task_run_occurrence_scope(self): + logits = [2.0, 1.0, 5.0] + taxa_list = TaxaList.objects.create(name="Occ scope list") + taxa_list.taxa.set(self.species_taxa[:2]) + + det, occ = self._detection_with_occurrence() + self._create_classification_with_logits(det, self.species_taxa[2], _softmax(logits), logits) + + ClassMaskingTask( + occurrence_id=occ.pk, + taxa_list_id=taxa_list.pk, + algorithm_id=self.algorithm.pk, + ).run() + + new_clf = Classification.objects.filter(detection=det, terminal=True).exclude(algorithm=self.algorithm).first() + self.assertIsNotNone(new_clf) + self.assertEqual(new_clf.taxon, self.species_taxa[0]) + + # ----- rank rollup ---------------------------------------------------- + + def test_rank_rollup_creates_genus_terminal_classification(self): + now = datetime.datetime.now(datetime.timezone.utc) + originals = [] + for _ in range(3): + det, _occ = self._detection_with_occurrence() + originals.append( + Classification.objects.create( + detection=det, + taxon=self.species_taxon, + score=0.5, + scores=[0.5, 0.2, 0.1], + terminal=True, + timestamp=now, + algorithm=self.algorithm, + ) + ) + + RankRollupTask( + source_image_collection_id=self.collection.pk, + thresholds={"SPECIES": 0.8, "GENUS": 0.6, "FAMILY": 0.4}, + ).run() + + for original in originals: + original.refresh_from_db(fields=["terminal"]) + self.assertFalse(original.terminal) + rolled = Classification.objects.filter(detection=original.detection, terminal=True).first() + self.assertIsNotNone(rolled) + self.assertEqual(rolled.taxon, self.genus_taxon) + self.assertEqual(rolled.applied_to, original) diff --git a/ami/ml/post_processing/tests/test_class_masking_admin.py b/ami/ml/post_processing/tests/test_class_masking_admin.py new file mode 100644 index 000000000..171d588f1 --- /dev/null +++ b/ami/ml/post_processing/tests/test_class_masking_admin.py @@ -0,0 +1,150 @@ +"""Schema validation + admin-action wiring tests for class masking and rank rollup. + +These are deliberately lightweight: they exercise the pydantic config contracts +and the admin trigger flow (intermediate page -> Job creation with the right +config payload) without the full project fixture. The masking maths is covered +in ``test_class_masking``. +""" +import pydantic +from django.contrib import admin as django_admin +from django.test import Client, TestCase +from django.urls import reverse + +from ami.jobs.models import Job +from ami.main.models import Occurrence, Project, SourceImageCollection, TaxaList +from ami.ml.models import Algorithm +from ami.ml.models.algorithm import AlgorithmTaskType +from ami.ml.post_processing.class_masking import ClassMaskingConfig +from ami.ml.post_processing.rank_rollup import RankRollupConfig +from ami.users.models import User + + +class TestClassMaskingConfig(TestCase): + def test_collection_scope_is_valid(self): + config = ClassMaskingConfig(source_image_collection_id=1, taxa_list_id=2, algorithm_id=3) + self.assertEqual(config.source_image_collection_id, 1) + self.assertIsNone(config.occurrence_id) + + def test_occurrence_scope_is_valid(self): + config = ClassMaskingConfig(occurrence_id=5, taxa_list_id=2, algorithm_id=3) + self.assertEqual(config.occurrence_id, 5) + + def test_both_scopes_is_invalid(self): + with self.assertRaises(pydantic.ValidationError): + ClassMaskingConfig(source_image_collection_id=1, occurrence_id=5, taxa_list_id=2, algorithm_id=3) + + def test_no_scope_is_invalid(self): + with self.assertRaises(pydantic.ValidationError): + ClassMaskingConfig(taxa_list_id=2, algorithm_id=3) + + def test_missing_required_fields_is_invalid(self): + with self.assertRaises(pydantic.ValidationError): + ClassMaskingConfig(source_image_collection_id=1) # no taxa_list_id / algorithm_id + + def test_extra_field_is_forbidden(self): + with self.assertRaises(pydantic.ValidationError): + ClassMaskingConfig(source_image_collection_id=1, taxa_list_id=2, algorithm_id=3, bogus=1) + + +class TestRankRollupConfig(TestCase): + def test_defaults_applied(self): + config = RankRollupConfig(source_image_collection_id=1) + self.assertEqual(config.thresholds["SPECIES"], 0.8) + self.assertEqual(config.rollup_order, ["SPECIES", "GENUS", "FAMILY"]) + + def test_threshold_out_of_range_is_invalid(self): + with self.assertRaises(pydantic.ValidationError): + RankRollupConfig(source_image_collection_id=1, thresholds={"SPECIES": 1.5}) + + def test_threshold_and_order_are_uppercased(self): + config = RankRollupConfig( + source_image_collection_id=1, thresholds={"species": 0.7}, rollup_order=["species", "genus"] + ) + self.assertIn("SPECIES", config.thresholds) + self.assertEqual(config.rollup_order, ["SPECIES", "GENUS"]) + + +class _PostProcessingAdminCase(TestCase): + @classmethod + def setUpTestData(cls) -> None: + cls.superuser = User.objects.create_superuser(email=f"ppadmin+{cls.__name__}@example.com", password="x") + cls.project = Project.objects.create(name=f"PP admin test ({cls.__name__})") + cls.collection = SourceImageCollection.objects.create(project=cls.project, name="PP admin collection") + cls.occurrence = Occurrence.objects.create(project=cls.project) + cls.taxa_list = TaxaList.objects.create(name="PP admin taxa list") + cls.algorithm = Algorithm.objects.create( + name="PP admin classifier", task_type=AlgorithmTaskType.CLASSIFICATION.value + ) + + def setUp(self) -> None: + self.client = Client() + self.client.force_login(self.superuser) + + +class TestClassMaskingAdmin(_PostProcessingAdminCase): + def _post_collection(self, data: dict): + url = reverse("admin:main_sourceimagecollection_changelist") + return self.client.post( + url, + data={ + "action": "run_class_masking", + django_admin.helpers.ACTION_CHECKBOX_NAME: [str(self.collection.pk)], + **data, + }, + ) + + def test_renders_intermediate_page_without_confirm(self): + response = self._post_collection({}) + self.assertEqual(response.status_code, 200) + self.assertIn(b"Run Class masking", response.content) + self.assertIn(b'name="taxa_list_id"', response.content) + self.assertIn(b'name="algorithm_id"', response.content) + self.assertEqual(Job.objects.filter(project=self.project).count(), 0) + + def test_valid_post_creates_collection_scoped_job(self): + response = self._post_collection( + {"confirm": "yes", "taxa_list_id": str(self.taxa_list.pk), "algorithm_id": str(self.algorithm.pk)} + ) + self.assertEqual(response.status_code, 302) + job = Job.objects.get(project=self.project, job_type_key="post_processing") + self.assertEqual(job.params["task"], "class_masking") + self.assertEqual(job.params["config"]["source_image_collection_id"], self.collection.pk) + self.assertEqual(job.params["config"]["taxa_list_id"], self.taxa_list.pk) + self.assertEqual(job.params["config"]["algorithm_id"], self.algorithm.pk) + self.assertIsNone(job.params["config"].get("occurrence_id")) + + def test_valid_post_on_occurrence_creates_occurrence_scoped_job(self): + url = reverse("admin:main_occurrence_changelist") + response = self.client.post( + url, + data={ + "action": "run_class_masking", + django_admin.helpers.ACTION_CHECKBOX_NAME: [str(self.occurrence.pk)], + "confirm": "yes", + "taxa_list_id": str(self.taxa_list.pk), + "algorithm_id": str(self.algorithm.pk), + }, + ) + self.assertEqual(response.status_code, 302) + job = Job.objects.get(project=self.project, job_type_key="post_processing") + self.assertEqual(job.params["task"], "class_masking") + self.assertEqual(job.params["config"]["occurrence_id"], self.occurrence.pk) + self.assertIsNone(job.params["config"].get("source_image_collection_id")) + + +class TestRankRollupAdmin(_PostProcessingAdminCase): + def test_valid_post_creates_rank_rollup_job_with_defaults(self): + url = reverse("admin:main_sourceimagecollection_changelist") + response = self.client.post( + url, + data={ + "action": "run_rank_rollup", + django_admin.helpers.ACTION_CHECKBOX_NAME: [str(self.collection.pk)], + "confirm": "yes", + }, + ) + self.assertEqual(response.status_code, 302) + job = Job.objects.get(project=self.project, job_type_key="post_processing") + self.assertEqual(job.params["task"], "rank_rollup") + self.assertEqual(job.params["config"]["source_image_collection_id"], self.collection.pk) + self.assertEqual(job.params["config"]["thresholds"]["SPECIES"], 0.8) From 2dbffabeb4cc4be1e611f1c0a0268f186a90eac0 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 25 Jun 2026 11:23:56 -0700 Subject: [PATCH 2/8] docs(post-processing): note rank_rollup is not lineage-constrained The roll-up picks the global argmax over every taxon at each rank, not just ancestors of the source classification's taxon. On a diffuse, low-confidence distribution this can reparent a detection to a family outside its own lineage. Document the behavior at the candidate-pick loop and record the open design choice (lineage-constrained vs distribution roll-up) as a TODO; no behavior change. Co-Authored-By: Claude --- ami/ml/post_processing/rank_rollup.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/ami/ml/post_processing/rank_rollup.py b/ami/ml/post_processing/rank_rollup.py index 7e8e2764e..179d515d7 100644 --- a/ami/ml/post_processing/rank_rollup.py +++ b/ami/ml/post_processing/rank_rollup.py @@ -144,6 +144,20 @@ def run(self) -> None: new_score = None scores_str = {t.name: s for t, s in taxon_scores.items()} self.logger.info(f"Aggregated taxon scores: {scores_str}") + # The candidates at each rank are every taxon that accumulated + # score there — the global argmax across the whole distribution, + # not only ancestors of clf.taxon. For a confident classification + # the winner is its own lineage, but a diffuse, low-confidence + # distribution can spread enough mass across unrelated branches + # that the top taxon at a rank is not an ancestor of clf.taxon, so + # the roll-up reparents the detection to an unrelated family. + # TODO: decide the intended semantics: + # - lineage-constrained: restrict candidates to ancestors of + # clf.taxon (find_ancestor_by_parent_chain already yields them) + # so a roll-up only ever generalizes the original prediction; or + # - distribution roll-up: keep the global argmax but document it, + # and reconsider whether applied_to -> this single clf is the + # right provenance when the result is outside its lineage. for rank in rollup_order: threshold = thresholds.get(rank, 1.0) candidates = {t: s for t, s in taxon_scores.items() if t.rank == rank} From ba796c127ceac5316d690ce7ffa37bb4d983dc9c Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 25 Jun 2026 16:39:29 -0700 Subject: [PATCH 3/8] feat(post-processing): scope the class-mask algorithm list to the selection When an operator triggers class masking on an occurrence (or collection), the "Source classifier" dropdown now lists only the classification algorithms that actually produced classifications within the selected scope. Masking any other algorithm would be a no-op for those rows, so offering every classifier was misleading. The admin action hands the knob form the selected queryset (a small generic seam on the form base, ignored by forms that don't need it), and ClassMaskingActionForm uses it to filter the algorithm field by classifications__detection__occurrence / source_image__collections. Co-Authored-By: Claude --- ami/ml/post_processing/admin/actions.py | 6 +- .../admin/class_masking_form.py | 23 +++++++- ami/ml/post_processing/admin/forms.py | 10 ++++ .../tests/test_class_masking_admin.py | 56 ++++++++++++++++++- 4 files changed, 91 insertions(+), 4 deletions(-) diff --git a/ami/ml/post_processing/admin/actions.py b/ami/ml/post_processing/admin/actions.py index 07a299c1f..4e7729256 100644 --- a/ami/ml/post_processing/admin/actions.py +++ b/ami/ml/post_processing/admin/actions.py @@ -247,10 +247,12 @@ def _render(form: BasePostProcessingActionForm) -> TemplateResponse: ) return None + # Hand the form the selected rows so it can scope its fields to the + # selection (e.g. only offer algorithms that ran on the chosen occurrence). if not request.POST.get("confirm"): - return _render(form_class()) + return _render(form_class(scope_queryset=queryset)) - form = form_class(request.POST) + form = form_class(request.POST, scope_queryset=queryset) if not form.is_valid(): return _render(form) diff --git a/ami/ml/post_processing/admin/class_masking_form.py b/ami/ml/post_processing/admin/class_masking_form.py index a7077e122..dafd7467b 100644 --- a/ami/ml/post_processing/admin/class_masking_form.py +++ b/ami/ml/post_processing/admin/class_masking_form.py @@ -2,7 +2,7 @@ from django import forms -from ami.main.models import TaxaList +from ami.main.models import Occurrence, SourceImageCollection, TaxaList from ami.ml.models import Algorithm from ami.ml.models.algorithm import AlgorithmTaskType from ami.ml.post_processing.admin.forms import BasePostProcessingActionForm @@ -31,6 +31,27 @@ class ClassMaskingActionForm(BasePostProcessingActionForm): ), ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # When the admin hands us the selected scope, only offer classifiers that + # actually produced classifications there — masking any other algorithm is + # a no-op for the selected rows. Without a scope (e.g. used standalone) the + # field keeps its full classifier list. + if self.scope_queryset is not None: + self.fields["algorithm_id"].queryset = self._algorithms_for_scope(self.scope_queryset) + + @staticmethod + def _algorithms_for_scope(scope_queryset): + """Classification algorithms that produced classifications within the + selected scope (the chosen occurrences or collections).""" + algorithms = Algorithm.objects.filter(task_type=AlgorithmTaskType.CLASSIFICATION.value) + model = scope_queryset.model + if model is Occurrence: + algorithms = algorithms.filter(classifications__detection__occurrence__in=scope_queryset) + elif model is SourceImageCollection: + algorithms = algorithms.filter(classifications__detection__source_image__collections__in=scope_queryset) + return algorithms.distinct().order_by("name") + def to_config(self) -> dict: return { "algorithm_id": self.cleaned_data["algorithm_id"].pk, diff --git a/ami/ml/post_processing/admin/forms.py b/ami/ml/post_processing/admin/forms.py index 7a2dbf5c9..c9c162808 100644 --- a/ami/ml/post_processing/admin/forms.py +++ b/ami/ml/post_processing/admin/forms.py @@ -20,6 +20,16 @@ class BasePostProcessingActionForm(forms.Form): optional fields, derive computed values, rename keys). """ + def __init__(self, *args, scope_queryset=None, **kwargs): + """Capture the admin selection the action will run on. + + ``scope_queryset`` is the queryset of rows the operator picked (e.g. the + chosen occurrences or collections). Subclasses may use it to constrain + their fields to that selection; forms that don't need it ignore it. + """ + self.scope_queryset = scope_queryset + super().__init__(*args, **kwargs) + def to_config(self) -> dict: """Return ``cleaned_data`` shaped for ``Job.params['config']``.""" return dict(self.cleaned_data) diff --git a/ami/ml/post_processing/tests/test_class_masking_admin.py b/ami/ml/post_processing/tests/test_class_masking_admin.py index 171d588f1..5880099af 100644 --- a/ami/ml/post_processing/tests/test_class_masking_admin.py +++ b/ami/ml/post_processing/tests/test_class_masking_admin.py @@ -9,11 +9,22 @@ from django.contrib import admin as django_admin from django.test import Client, TestCase from django.urls import reverse +from django.utils import timezone from ami.jobs.models import Job -from ami.main.models import Occurrence, Project, SourceImageCollection, TaxaList +from ami.main.models import ( + Classification, + Deployment, + Detection, + Occurrence, + Project, + SourceImage, + SourceImageCollection, + TaxaList, +) from ami.ml.models import Algorithm from ami.ml.models.algorithm import AlgorithmTaskType +from ami.ml.post_processing.admin.class_masking_form import ClassMaskingActionForm from ami.ml.post_processing.class_masking import ClassMaskingConfig from ami.ml.post_processing.rank_rollup import RankRollupConfig from ami.users.models import User @@ -75,6 +86,14 @@ def setUpTestData(cls) -> None: cls.algorithm = Algorithm.objects.create( name="PP admin classifier", task_type=AlgorithmTaskType.CLASSIFICATION.value ) + # Wire the classifier to both scopes (the collection's image and the + # occurrence) so the class-mask form offers it — it only lists algorithms + # that actually produced classifications within the selection. + cls.deployment = Deployment.objects.create(project=cls.project, name="PP admin dep") + source_image = SourceImage.objects.create(deployment=cls.deployment, project=cls.project, path="pp-admin.jpg") + cls.collection.images.add(source_image) + detection = Detection.objects.create(source_image=source_image, bbox=[0, 0, 1, 1], occurrence=cls.occurrence) + Classification.objects.create(detection=detection, algorithm=cls.algorithm, timestamp=timezone.now()) def setUp(self) -> None: self.client = Client() @@ -148,3 +167,38 @@ def test_valid_post_creates_rank_rollup_job_with_defaults(self): self.assertEqual(job.params["task"], "rank_rollup") self.assertEqual(job.params["config"]["source_image_collection_id"], self.collection.pk) self.assertEqual(job.params["config"]["thresholds"]["SPECIES"], 0.8) + + +class TestClassMaskingFormScopeFiltering(TestCase): + """The class-mask form offers only classifiers that actually produced + classifications within the selected scope, so an operator cannot pick an + algorithm whose masking would be a no-op for the chosen occurrence.""" + + @classmethod + def setUpTestData(cls) -> None: + cls.project = Project.objects.create(name="CM scope filter project") + cls.deployment = Deployment.objects.create(project=cls.project, name="dep") + cls.source_image = SourceImage.objects.create( + deployment=cls.deployment, project=cls.project, path="cm-scope.jpg" + ) + cls.used = Algorithm.objects.create(name="used classifier", task_type=AlgorithmTaskType.CLASSIFICATION.value) + cls.unused = Algorithm.objects.create( + name="unused classifier", task_type=AlgorithmTaskType.CLASSIFICATION.value + ) + + cls.occurrence = Occurrence.objects.create(project=cls.project, deployment=cls.deployment) + detection = Detection.objects.create( + source_image=cls.source_image, bbox=[0, 0, 1, 1], occurrence=cls.occurrence + ) + Classification.objects.create(detection=detection, algorithm=cls.used, timestamp=timezone.now()) + + def test_form_offers_only_algorithms_used_on_the_occurrence(self): + form = ClassMaskingActionForm(scope_queryset=Occurrence.objects.filter(pk=self.occurrence.pk)) + offered = set(form.fields["algorithm_id"].queryset.values_list("pk", flat=True)) + self.assertEqual(offered, {self.used.pk}) + + def test_form_without_scope_offers_all_classifiers(self): + form = ClassMaskingActionForm() + offered = set(form.fields["algorithm_id"].queryset.values_list("pk", flat=True)) + self.assertIn(self.used.pk, offered) + self.assertIn(self.unused.pk, offered) From 7f97cd502fe4246532a42079a4aec22cf9b8d6b4 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 30 Jun 2026 11:45:52 -0700 Subject: [PATCH 4/8] fix(post-processing): stop the class-mask form timing out on large collections MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The class-mask admin form narrows its "Source classifier" dropdown to the algorithms that actually produced classifications in the selected scope, so an operator cannot pick a classifier whose masking would be a no-op. On the occurrence path that lookup is cheap (a handful of classifications), but on the collection path it became an unbounded DISTINCT over every classification in the collection. On a large collection that join runs for tens of seconds or times out while the intermediate form is still rendering, before the operator can do anything. Scope the dropdown only for the occurrence path. A collection scope now keeps the full classifier list. Offering a classifier that produced nothing in the collection is harmless — masking it changes no rows — so the narrowing was a convenience, not a correctness guard, and is not worth a query that can hang the page. Co-Authored-By: Claude --- .../admin/class_masking_form.py | 35 +++++++++++-------- .../tests/test_class_masking_admin.py | 14 ++++++++ 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/ami/ml/post_processing/admin/class_masking_form.py b/ami/ml/post_processing/admin/class_masking_form.py index dafd7467b..a290350ce 100644 --- a/ami/ml/post_processing/admin/class_masking_form.py +++ b/ami/ml/post_processing/admin/class_masking_form.py @@ -2,7 +2,7 @@ from django import forms -from ami.main.models import Occurrence, SourceImageCollection, TaxaList +from ami.main.models import Occurrence, TaxaList from ami.ml.models import Algorithm from ami.ml.models.algorithm import AlgorithmTaskType from ami.ml.post_processing.admin.forms import BasePostProcessingActionForm @@ -33,24 +33,31 @@ class ClassMaskingActionForm(BasePostProcessingActionForm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # When the admin hands us the selected scope, only offer classifiers that - # actually produced classifications there — masking any other algorithm is - # a no-op for the selected rows. Without a scope (e.g. used standalone) the - # field keeps its full classifier list. - if self.scope_queryset is not None: + # Narrow the classifier dropdown to algorithms that actually produced + # classifications in the selected scope, so the operator cannot pick a + # classifier whose masking would be a no-op for the chosen rows. This is + # only done for an occurrence scope, where the lookup touches the handful + # of classifications under the picked occurrences. A collection scope + # keeps the full classifier list on purpose: the equivalent lookup is an + # unbounded DISTINCT over every classification in the collection (hundreds + # of thousands of rows on a large collection) and can time out while the + # form renders. An over-broad option is harmless — masking a classifier + # that produced nothing in scope changes nothing. + if self.scope_queryset is not None and self.scope_queryset.model is Occurrence: self.fields["algorithm_id"].queryset = self._algorithms_for_scope(self.scope_queryset) @staticmethod def _algorithms_for_scope(scope_queryset): """Classification algorithms that produced classifications within the - selected scope (the chosen occurrences or collections).""" - algorithms = Algorithm.objects.filter(task_type=AlgorithmTaskType.CLASSIFICATION.value) - model = scope_queryset.model - if model is Occurrence: - algorithms = algorithms.filter(classifications__detection__occurrence__in=scope_queryset) - elif model is SourceImageCollection: - algorithms = algorithms.filter(classifications__detection__source_image__collections__in=scope_queryset) - return algorithms.distinct().order_by("name") + selected occurrences.""" + return ( + Algorithm.objects.filter( + task_type=AlgorithmTaskType.CLASSIFICATION.value, + classifications__detection__occurrence__in=scope_queryset, + ) + .distinct() + .order_by("name") + ) def to_config(self) -> dict: return { diff --git a/ami/ml/post_processing/tests/test_class_masking_admin.py b/ami/ml/post_processing/tests/test_class_masking_admin.py index 5880099af..92fa76029 100644 --- a/ami/ml/post_processing/tests/test_class_masking_admin.py +++ b/ami/ml/post_processing/tests/test_class_masking_admin.py @@ -202,3 +202,17 @@ def test_form_without_scope_offers_all_classifiers(self): offered = set(form.fields["algorithm_id"].queryset.values_list("pk", flat=True)) self.assertIn(self.used.pk, offered) self.assertIn(self.unused.pk, offered) + + def test_collection_scope_offers_all_classifiers(self): + """A collection scope intentionally keeps the full classifier list rather + than narrowing to the classifiers used in the collection. The narrowing + lookup is an unbounded DISTINCT over every classification in the + collection, which can time out while rendering the form on a large + collection. This pins that the collection path stays unfiltered so the + expensive filter is not re-added there by mistake.""" + collection = SourceImageCollection.objects.create(project=self.project, name="scope coll") + collection.images.add(self.source_image) + form = ClassMaskingActionForm(scope_queryset=SourceImageCollection.objects.filter(pk=collection.pk)) + offered = set(form.fields["algorithm_id"].queryset.values_list("pk", flat=True)) + self.assertIn(self.used.pk, offered) + self.assertIn(self.unused.pk, offered) From c5aa3b6638cbed3d5bdd9b79c53a3aaeea2b7055 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 30 Jun 2026 12:10:47 -0700 Subject: [PATCH 5/8] refactor(post-processing): split rank roll-up into a follow-up PR Rank roll-up moves out of this PR so class masking can merge on its own. Class masking is validated end to end against a real classifier and is well covered by tests; rank roll-up still has an open design question (whether the per-rank pick should be constrained to the source taxon's ancestors or remain a global distribution roll-up) and thinner tests. Keeping them together would force reviewers to either block ready class-masking work or sign off on the less-settled roll-up. This removes the rank-roll-up task, its admin form and action, its registry entry, and its tests. The roll-up feature returns in its own PR, stacked on this one, where the lineage decision and fuller tests can land before it merges. Co-Authored-By: Claude --- ami/main/admin.py | 12 - ami/ml/post_processing/__init__.py | 1 - .../post_processing/admin/rank_rollup_form.py | 13 -- ami/ml/post_processing/rank_rollup.py | 219 ------------------ ami/ml/post_processing/registry.py | 2 - .../tests/test_class_masking.py | 35 +-- .../tests/test_class_masking_admin.py | 39 +--- 7 files changed, 2 insertions(+), 319 deletions(-) delete mode 100644 ami/ml/post_processing/admin/rank_rollup_form.py delete mode 100644 ami/ml/post_processing/rank_rollup.py diff --git a/ami/main/admin.py b/ami/main/admin.py index 253224154..b25860507 100644 --- a/ami/main/admin.py +++ b/ami/main/admin.py @@ -15,10 +15,8 @@ from ami.ml.models.project_pipeline_config import ProjectPipelineConfig from ami.ml.post_processing.admin.actions import make_post_processing_action from ami.ml.post_processing.admin.class_masking_form import ClassMaskingActionForm -from ami.ml.post_processing.admin.rank_rollup_form import RankRollupActionForm from ami.ml.post_processing.admin.small_size_filter_form import SmallSizeFilterActionForm from ami.ml.post_processing.class_masking import ClassMaskingTask -from ami.ml.post_processing.rank_rollup import RankRollupTask from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask from ami.ml.tasks import remove_duplicate_classifications @@ -868,21 +866,11 @@ def populate_collection_async(self, request: HttpRequest, queryset: QuerySet[Sou f"Post-processing: {task_cls.name} on Capture Set {collection.pk}" ), ) - run_rank_rollup = make_post_processing_action( - RankRollupTask, - RankRollupActionForm, - scope_resolver=lambda collection: {"source_image_collection_id": collection.pk}, - name_resolver=lambda task_cls, collection: ( - f"Post-processing: {task_cls.name} on Capture Set {collection.pk}" - ), - ) - actions = [ populate_collection, populate_collection_async, run_small_size_filter, run_class_masking, - run_rank_rollup, ] # Hide images many-to-many field from form. This would list all source images in the database. diff --git a/ami/ml/post_processing/__init__.py b/ami/ml/post_processing/__init__.py index 8837973c4..c94be9ae9 100644 --- a/ami/ml/post_processing/__init__.py +++ b/ami/ml/post_processing/__init__.py @@ -1,3 +1,2 @@ from . import class_masking # noqa: F401 -from . import rank_rollup # noqa: F401 from . import small_size_filter # noqa: F401 diff --git a/ami/ml/post_processing/admin/rank_rollup_form.py b/ami/ml/post_processing/admin/rank_rollup_form.py deleted file mode 100644 index 1d072a9b6..000000000 --- a/ami/ml/post_processing/admin/rank_rollup_form.py +++ /dev/null @@ -1,13 +0,0 @@ -from __future__ import annotations - -from ami.ml.post_processing.admin.forms import BasePostProcessingActionForm - - -class RankRollupActionForm(BasePostProcessingActionForm): - """Knob form for Rank rollup. - - Rank rollup runs with the per-rank score thresholds and rollup order defined - on ``RankRollupConfig``. There are no per-run knobs yet, so the form only - confirms the selected capture set(s); the empty ``cleaned_data`` lets the - schema apply its defaults. Threshold overrides can be added here later. - """ diff --git a/ami/ml/post_processing/rank_rollup.py b/ami/ml/post_processing/rank_rollup.py deleted file mode 100644 index 179d515d7..000000000 --- a/ami/ml/post_processing/rank_rollup.py +++ /dev/null @@ -1,219 +0,0 @@ -import logging -from collections import defaultdict - -import pydantic -from django.db import transaction -from django.utils import timezone - -from ami.main.models import Classification, Taxon -from ami.ml.models.algorithm import AlgorithmCategoryMap -from ami.ml.post_processing.base import BasePostProcessingTask - -logger = logging.getLogger(__name__) - -DEFAULT_THRESHOLDS = {"SPECIES": 0.8, "GENUS": 0.6, "FAMILY": 0.4} -DEFAULT_ROLLUP_ORDER = ["SPECIES", "GENUS", "FAMILY"] - - -class RankRollupConfig(pydantic.BaseModel): - source_image_collection_id: int - # Minimum aggregated score required to roll a classification up to each rank. - thresholds: dict[str, float] = DEFAULT_THRESHOLDS - # Ranks to try, finest first; the first whose aggregated score clears its - # threshold wins. - rollup_order: list[str] = DEFAULT_ROLLUP_ORDER - - @pydantic.validator("thresholds") - def _validate_thresholds(cls, value: dict[str, float]) -> dict[str, float]: - normalised: dict[str, float] = {} - for rank, threshold in value.items(): - if not (0.0 < threshold <= 1.0): - raise ValueError(f"Threshold for {rank} must be in (0, 1]") - normalised[rank.upper()] = threshold - return normalised - - @pydantic.validator("rollup_order") - def _uppercase_order(cls, value: list[str]) -> list[str]: - return [rank.upper() for rank in value] - - class Config: - extra = "forbid" - - -def find_ancestor_by_parent_chain(taxon, target_rank: str): - """Climb up parent relationships until a taxon with the target rank is found.""" - if not taxon: - return None - - target_rank = target_rank.upper() - - current = taxon - while current: - if current.rank.upper() == target_rank: - return current - current = current.parent - - return None - - -class RankRollupTask(BasePostProcessingTask): - """Post-processing task that rolls up low-confidence classifications - to higher ranks using aggregated scores. - """ - - key = "rank_rollup" - name = "Rank rollup" - config_schema = RankRollupConfig - - def run(self) -> None: - config: RankRollupConfig = self.config # type: ignore[assignment] - job = self.job - self.logger.info(f"Starting {self.name} task for job {job.pk if job else 'N/A'}") - - collection_id = config.source_image_collection_id - thresholds = config.thresholds - rollup_order = config.rollup_order - - self.logger.info( - f"Config loaded: collection_id={collection_id}, thresholds={thresholds}, rollup_order={rollup_order}" - ) - - # select_related the per-row relations the loop touches (category_map for - # labels, detection/occurrence for the rollup write) so the body issues no - # per-classification queries. - qs = ( - Classification.objects.filter( - terminal=True, - taxon__isnull=False, - detection__source_image__collections__id=collection_id, - ) - .select_related("category_map", "taxon", "detection", "detection__occurrence") - .distinct() - ) - - total = qs.count() - self.logger.info(f"Found {total} terminal classifications to process for collection {collection_id}") - - # Pre-load every label across the distinct category maps in one pass (two - # queries total), instead of dereferencing clf.category_map per row. - category_map_ids = list(qs.values_list("category_map_id", flat=True).distinct()) - all_labels: set[str] = set() - for category_map in AlgorithmCategoryMap.objects.filter(pk__in=category_map_ids): - if category_map.labels: - all_labels.update(label for label in category_map.labels if label) - - label_to_taxon = {} - if all_labels: - for taxon in Taxon.objects.filter(name__in=all_labels).select_related("parent"): - label_to_taxon[taxon.name] = taxon - self.logger.info(f"Pre-loaded {len(label_to_taxon)} taxa from {len(all_labels)} unique labels") - - updated_occurrences = [] - - with transaction.atomic(): - for i, clf in enumerate(qs.iterator(), start=1): - score_str = f"{clf.score:.3f}" if clf.score is not None else "N/A" - self.logger.info(f"Processing classification #{clf.pk} (taxon={clf.taxon}, score={score_str})") - - if not clf.scores: - self.logger.info(f"Skipping classification #{clf.pk}: no scores available") - continue - if not clf.category_map: - self.logger.info(f"Skipping classification #{clf.pk}: no category_map assigned") - continue - - taxon_scores = defaultdict(float) - - for idx, score in enumerate(clf.scores): - label = clf.category_map.labels[idx] - if not label: - continue - - taxon = label_to_taxon.get(label) - if not taxon: - self.logger.debug(f"Skipping label '{label}' (no matching Taxon found)") - continue - - for rank in rollup_order: - ancestor = find_ancestor_by_parent_chain(taxon, rank) - if ancestor: - taxon_scores[ancestor] += score - self.logger.debug(f" + Added {score:.3f} to ancestor {ancestor.name} ({rank})") - - new_taxon = None - new_score = None - scores_str = {t.name: s for t, s in taxon_scores.items()} - self.logger.info(f"Aggregated taxon scores: {scores_str}") - # The candidates at each rank are every taxon that accumulated - # score there — the global argmax across the whole distribution, - # not only ancestors of clf.taxon. For a confident classification - # the winner is its own lineage, but a diffuse, low-confidence - # distribution can spread enough mass across unrelated branches - # that the top taxon at a rank is not an ancestor of clf.taxon, so - # the roll-up reparents the detection to an unrelated family. - # TODO: decide the intended semantics: - # - lineage-constrained: restrict candidates to ancestors of - # clf.taxon (find_ancestor_by_parent_chain already yields them) - # so a roll-up only ever generalizes the original prediction; or - # - distribution roll-up: keep the global argmax but document it, - # and reconsider whether applied_to -> this single clf is the - # right provenance when the result is outside its lineage. - for rank in rollup_order: - threshold = thresholds.get(rank, 1.0) - candidates = {t: s for t, s in taxon_scores.items() if t.rank == rank} - - if not candidates: - self.logger.info(f"No candidates found at rank {rank}") - continue - - best_taxon, best_score = max(candidates.items(), key=lambda kv: kv[1]) - self.logger.info( - f"Best at rank {rank}: {best_taxon.name} ({best_score:.3f}) [threshold={threshold}]" - ) - - if best_score >= threshold: - new_taxon, new_score = best_taxon, best_score - self.logger.info(f"Rollup decision: {new_taxon.name} ({rank}) with score {new_score:.3f}") - break - - if new_taxon and new_taxon != clf.taxon: - self.logger.info(f"Rolling up {clf.taxon} => {new_taxon} ({new_taxon.rank})") - - # Mark all classifications for this detection as non-terminal - Classification.objects.filter(detection=clf.detection).update(terminal=False) - Classification.objects.create( - detection=clf.detection, - taxon=new_taxon, - score=new_score, - terminal=True, - algorithm=self.algorithm, - timestamp=timezone.now(), - applied_to=clf, - ) - - occurrence = clf.detection.occurrence - if occurrence: - occurrence.save(update_determination=True) - updated_occurrences.append(occurrence) - self.logger.info( - f"Rolled up occurrence {occurrence.pk}: {clf.taxon} => {new_taxon} " - f"({new_taxon.rank}) with rolled-up score={new_score:.3f}" - ) - else: - self.logger.warning(f"Detection #{clf.detection.pk} has no occurrence; skipping.") - else: - self.logger.info(f"No rollup applied for classification #{clf.pk} (taxon={clf.taxon})") - - # Update progress every 10 iterations - if i % 10 == 0 or i == total: - progress = i / total if total > 0 else 1.0 - self.update_progress(progress) - - self.report_stage_metrics( - { - "classifications_checked": total, - "occurrences_rolled_up": len(updated_occurrences), - } - ) - self.logger.info(f"Rank rollup completed. Updated {len(updated_occurrences)} occurrences.") - self.logger.info(f"{self.name} task finished for collection {collection_id}.") diff --git a/ami/ml/post_processing/registry.py b/ami/ml/post_processing/registry.py index 28fa7fb2f..308be18ae 100644 --- a/ami/ml/post_processing/registry.py +++ b/ami/ml/post_processing/registry.py @@ -1,12 +1,10 @@ # Registry of available post-processing tasks from ami.ml.post_processing.class_masking import ClassMaskingTask -from ami.ml.post_processing.rank_rollup import RankRollupTask from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask POSTPROCESSING_TASKS = { SmallSizeFilterTask.key: SmallSizeFilterTask, ClassMaskingTask.key: ClassMaskingTask, - RankRollupTask.key: RankRollupTask, } diff --git a/ami/ml/post_processing/tests/test_class_masking.py b/ami/ml/post_processing/tests/test_class_masking.py index 4503a3d49..b56e881f4 100644 --- a/ami/ml/post_processing/tests/test_class_masking.py +++ b/ami/ml/post_processing/tests/test_class_masking.py @@ -1,4 +1,4 @@ -"""Domain tests for class masking and rank rollup post-processing tasks. +"""Domain tests for the class masking post-processing task. Class masking re-scores a classifier's terminal predictions against a taxa list: classes whose taxon is not in the list are masked, the softmax is renormalised @@ -29,7 +29,6 @@ from ami.ml.models import Algorithm, AlgorithmCategoryMap from ami.ml.models.algorithm import AlgorithmTaskType from ami.ml.post_processing.class_masking import ClassMaskingTask, make_classifications_filtered_by_taxa_list -from ami.ml.post_processing.rank_rollup import RankRollupTask from ami.tests.fixtures.main import create_taxa, setup_test_project @@ -268,35 +267,3 @@ def test_task_run_occurrence_scope(self): new_clf = Classification.objects.filter(detection=det, terminal=True).exclude(algorithm=self.algorithm).first() self.assertIsNotNone(new_clf) self.assertEqual(new_clf.taxon, self.species_taxa[0]) - - # ----- rank rollup ---------------------------------------------------- - - def test_rank_rollup_creates_genus_terminal_classification(self): - now = datetime.datetime.now(datetime.timezone.utc) - originals = [] - for _ in range(3): - det, _occ = self._detection_with_occurrence() - originals.append( - Classification.objects.create( - detection=det, - taxon=self.species_taxon, - score=0.5, - scores=[0.5, 0.2, 0.1], - terminal=True, - timestamp=now, - algorithm=self.algorithm, - ) - ) - - RankRollupTask( - source_image_collection_id=self.collection.pk, - thresholds={"SPECIES": 0.8, "GENUS": 0.6, "FAMILY": 0.4}, - ).run() - - for original in originals: - original.refresh_from_db(fields=["terminal"]) - self.assertFalse(original.terminal) - rolled = Classification.objects.filter(detection=original.detection, terminal=True).first() - self.assertIsNotNone(rolled) - self.assertEqual(rolled.taxon, self.genus_taxon) - self.assertEqual(rolled.applied_to, original) diff --git a/ami/ml/post_processing/tests/test_class_masking_admin.py b/ami/ml/post_processing/tests/test_class_masking_admin.py index 92fa76029..903248cf5 100644 --- a/ami/ml/post_processing/tests/test_class_masking_admin.py +++ b/ami/ml/post_processing/tests/test_class_masking_admin.py @@ -1,4 +1,4 @@ -"""Schema validation + admin-action wiring tests for class masking and rank rollup. +"""Schema validation + admin-action wiring tests for class masking. These are deliberately lightweight: they exercise the pydantic config contracts and the admin trigger flow (intermediate page -> Job creation with the right @@ -26,7 +26,6 @@ from ami.ml.models.algorithm import AlgorithmTaskType from ami.ml.post_processing.admin.class_masking_form import ClassMaskingActionForm from ami.ml.post_processing.class_masking import ClassMaskingConfig -from ami.ml.post_processing.rank_rollup import RankRollupConfig from ami.users.models import User @@ -57,24 +56,6 @@ def test_extra_field_is_forbidden(self): ClassMaskingConfig(source_image_collection_id=1, taxa_list_id=2, algorithm_id=3, bogus=1) -class TestRankRollupConfig(TestCase): - def test_defaults_applied(self): - config = RankRollupConfig(source_image_collection_id=1) - self.assertEqual(config.thresholds["SPECIES"], 0.8) - self.assertEqual(config.rollup_order, ["SPECIES", "GENUS", "FAMILY"]) - - def test_threshold_out_of_range_is_invalid(self): - with self.assertRaises(pydantic.ValidationError): - RankRollupConfig(source_image_collection_id=1, thresholds={"SPECIES": 1.5}) - - def test_threshold_and_order_are_uppercased(self): - config = RankRollupConfig( - source_image_collection_id=1, thresholds={"species": 0.7}, rollup_order=["species", "genus"] - ) - self.assertIn("SPECIES", config.thresholds) - self.assertEqual(config.rollup_order, ["SPECIES", "GENUS"]) - - class _PostProcessingAdminCase(TestCase): @classmethod def setUpTestData(cls) -> None: @@ -151,24 +132,6 @@ def test_valid_post_on_occurrence_creates_occurrence_scoped_job(self): self.assertIsNone(job.params["config"].get("source_image_collection_id")) -class TestRankRollupAdmin(_PostProcessingAdminCase): - def test_valid_post_creates_rank_rollup_job_with_defaults(self): - url = reverse("admin:main_sourceimagecollection_changelist") - response = self.client.post( - url, - data={ - "action": "run_rank_rollup", - django_admin.helpers.ACTION_CHECKBOX_NAME: [str(self.collection.pk)], - "confirm": "yes", - }, - ) - self.assertEqual(response.status_code, 302) - job = Job.objects.get(project=self.project, job_type_key="post_processing") - self.assertEqual(job.params["task"], "rank_rollup") - self.assertEqual(job.params["config"]["source_image_collection_id"], self.collection.pk) - self.assertEqual(job.params["config"]["thresholds"]["SPECIES"], 0.8) - - class TestClassMaskingFormScopeFiltering(TestCase): """The class-mask form offers only classifiers that actually produced classifications within the selected scope, so an operator cannot pick an From 194efcad5842a6ca2965e63389e48fc56dd6be10 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 1 Jul 2026 08:37:46 -0700 Subject: [PATCH 6/8] feat(post-processing): make class masking survive the Job path at scale, add reweight toggle MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Class masking now commits in batches and reports progress and stage metrics per batch, instead of building the whole scope in memory and committing once at the end. On a large classifier (tens of thousands of classes, thousands of classifications) the single final commit could exhaust memory, and run through the admin/Job path the task was revoked by the job health-check reaper before it committed — so nothing persisted and the occurrences-changed metrics never reached the job page. It now mirrors SmallSizeFilterTask: process in batches, commit each batch in its own transaction, and call update_progress() plus report_stage_metrics() per batch. Memory stays bounded (the source cursor and the per-batch buffers are capped at batch_size), the heartbeat keeps the reaper satisfied, and the live counts appear on the job page. occurrences_updated now counts only occurrences whose determination actually changed, matching the size filter. Also add a reweight toggle (default on) to the config and the admin form. Renormalising the kept classes to sum to 1 does not change which species is chosen — argmax over the kept classes is identical either way — so it affects only the stored confidence. With reweighting off, the kept classes keep their original absolute scores, which is useful for evaluating the model's raw confidence. Co-Authored-By: Claude --- .../admin/class_masking_form.py | 10 + ami/ml/post_processing/class_masking.py | 199 ++++++++++++------ .../tests/test_class_masking.py | 175 +++++++++++++++ .../tests/test_class_masking_admin.py | 43 ++++ 4 files changed, 358 insertions(+), 69 deletions(-) diff --git a/ami/ml/post_processing/admin/class_masking_form.py b/ami/ml/post_processing/admin/class_masking_form.py index a290350ce..6e2f5f884 100644 --- a/ami/ml/post_processing/admin/class_masking_form.py +++ b/ami/ml/post_processing/admin/class_masking_form.py @@ -30,6 +30,15 @@ class ClassMaskingActionForm(BasePostProcessingActionForm): "classification's softmax is renormalised over the classes that remain." ), ) + reweight = forms.BooleanField( + required=False, + initial=True, + label="Reweight (renormalise) scores", + help_text=( + "Renormalise the kept classes to sum to 1. " + "Off = keep the model's raw absolute scores; the chosen species is unchanged either way." + ), + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -63,4 +72,5 @@ def to_config(self) -> dict: return { "algorithm_id": self.cleaned_data["algorithm_id"].pk, "taxa_list_id": self.cleaned_data["taxa_list_id"].pk, + "reweight": self.cleaned_data["reweight"], } diff --git a/ami/ml/post_processing/class_masking.py b/ami/ml/post_processing/class_masking.py index 119cda191..2043fa728 100644 --- a/ami/ml/post_processing/class_masking.py +++ b/ami/ml/post_processing/class_masking.py @@ -25,6 +25,10 @@ class ClassMaskingConfig(pydantic.BaseModel): taxa_list_id: int # The source classifier whose terminal classifications are re-scored. algorithm_id: int + # When True (default), renormalise the kept classes' scores to sum to 1 after + # masking. When False, the kept classes retain their original absolute scores and + # the excluded classes are zeroed; the chosen species is identical either way. + reweight: bool = True @pydantic.root_validator(skip_on_failure=True) def _exactly_one_scope(cls, values: dict) -> dict: @@ -43,18 +47,30 @@ def make_classifications_filtered_by_taxa_list( algorithm: Algorithm, new_algorithm: Algorithm, *, + batch_size: int = 200, + reweight: bool = True, task_logger: logging.Logger = logger, - progress_callback: Callable[[int, int], None] | None = None, + on_batch: Callable[[dict], None] | None = None, ) -> dict[str, int]: """Re-score ``classifications`` by masking out classes absent from ``taxa_list``. For each terminal classification produced by ``algorithm``, the logits of classes whose taxon is not in ``taxa_list`` are masked, the softmax is - renormalised over the remaining classes, and a new terminal classification + renormalised over the remaining classes (or, when ``reweight=False``, the kept + classes retain their original absolute scores), and a new terminal classification (attributed to ``new_algorithm``, linked back via ``applied_to``) records the masked prediction. The original classification is demoted to non-terminal. - Returns counters (checked / masked / occurrences updated) for stage metrics. + Commits in batches of ``batch_size`` so memory stays bounded and the job + health-check reaper sees regular heartbeats. ``on_batch`` is called after every + flush with running counters: + ``{"classifications_checked": i, "classifications_total": total, + "classifications_masked": masked_count, "occurrences_updated": n_changed}``. + + ``occurrences_updated`` counts only occurrences whose determination actually + changed (not just any occurrence touched), matching the size-filter convention. + + Returns final counters (checked / masked / occurrences updated) for stage metrics. """ taxa_in_list = set(taxa_list.taxa.all()) @@ -91,86 +107,119 @@ def make_classifications_filtered_by_taxa_list( classifications_to_demote: list[Classification] = [] classifications_to_add: list[Classification] = [] occurrences_to_update: set[Occurrence] = set() + # Tracks occurrences whose determination actually changed across all batches. + changed_occurrence_ids: set[int] = set() timestamp = timezone.now() masked_count = 0 - for i, classification in enumerate(classifications.iterator(), start=1): + + for i, classification in enumerate(classifications.iterator(chunk_size=batch_size), start=1): scores, logits = classification.scores, classification.logits if not isinstance(logits, list) or not all(isinstance(x, (int, float)) for x in logits): raise ValueError(f"Logits for classification {classification.pk} are not a list of numbers: {logits}") - if len(logits) != num_categories: + elif len(logits) != num_categories: task_logger.warning( f"Classification {classification.pk}: {len(logits)} logits != {num_categories} categories; skipping" ) - continue - - # Mask excluded classes with -inf on a working copy so the renormalised - # softmax assigns them exactly zero probability — an excluded class can - # never win argmax. (-inf is compute-only; it is never stored, since it - # is not valid JSON. The stored vectors stay finite: see below.) - working = np.asarray(logits, dtype=float) - working[excluded_indices] = -np.inf - working -= working.max() # max is over kept classes (finite); stabilises exp - exp = np.exp(working) # exp(-inf) == 0 for masked classes - new_scores_np = exp / exp.sum() # sum > 0: at least one class is kept - top_index = int(np.argmax(new_scores_np)) - new_scores = new_scores_np.tolist() - - # No-change short-circuit: if masking shifted no probability (the classes - # this taxa list drops carried ~zero score here), leave the row untouched. - if isinstance(scores, list) and np.allclose(scores, new_scores, atol=1e-9): - task_logger.debug(f"Classification {classification.pk} unchanged by masking; skipping") - continue - - top_taxon = index_to_taxon.get(top_index) # guaranteed in taxa_in_list (top_index is kept) - - classification.terminal = False - classification.updated_at = timestamp - - new_classification = Classification( - detection=classification.detection, - taxon=top_taxon, - algorithm=new_algorithm, - category_map=new_algorithm.category_map, - score=float(new_scores_np[top_index]), - scores=new_scores, - # Store the raw logits unchanged (JSON-safe): the mask is fully captured - # by ``scores`` (dropped classes -> 0) and the ``applied_to`` lineage. - logits=logits, - terminal=True, - timestamp=classification.timestamp, - applied_to=classification, - created_at=timestamp, - updated_at=timestamp, - ) - classifications_to_demote.append(classification) - classifications_to_add.append(new_classification) - masked_count += 1 - - detection = classification.detection - if detection is not None and detection.occurrence is not None: - occurrences_to_update.add(detection.occurrence) - - if progress_callback is not None and (i % 100 == 0 or i == total): - progress_callback(i, total) - - with transaction.atomic(): - if classifications_to_demote: - Classification.objects.bulk_update(classifications_to_demote, ["terminal", "updated_at"]) - if classifications_to_add: - Classification.objects.bulk_create(classifications_to_add) - # Recompute each affected occurrence's determination from its new terminal - # classification. - for occurrence in occurrences_to_update: - occurrence.save(update_determination=True) + else: + # Mask excluded classes with -inf on a working copy so the renormalised + # softmax assigns them exactly zero probability — an excluded class can + # never win argmax. (-inf is compute-only; it is never stored, since it + # is not valid JSON. The stored vectors stay finite: see below.) + working = np.asarray(logits, dtype=float) + working[excluded_indices] = -np.inf + working -= working.max() # max is over kept classes (finite); stabilises exp + exp = np.exp(working) # exp(-inf) == 0 for masked classes + new_scores_np = exp / exp.sum() # sum > 0: at least one class is kept + top_index = int(np.argmax(new_scores_np)) + + if reweight: + # Renormalise: excluded classes become exactly 0; kept classes sum to 1. + new_scores = new_scores_np.tolist() + score = float(new_scores_np[top_index]) + else: + # No renormalisation: kept classes keep their original absolute scores; + # excluded classes are zeroed. The winner is the same (argmax is + # computed on the renormalised distribution before this branch). + new_scores = list(scores) + for idx in excluded_indices: + new_scores[idx] = 0.0 + score = float(new_scores[top_index]) + + # No-change short-circuit: if masking shifted no probability (the classes + # this taxa list drops carried ~zero score here), leave the row untouched. + if isinstance(scores, list) and np.allclose(scores, new_scores, atol=1e-9): + task_logger.debug(f"Classification {classification.pk} unchanged by masking; skipping") + else: + top_taxon = index_to_taxon.get(top_index) # guaranteed in taxa_in_list (top_index is kept) + + classification.terminal = False + classification.updated_at = timestamp + + new_classification = Classification( + detection=classification.detection, + taxon=top_taxon, + algorithm=new_algorithm, + category_map=new_algorithm.category_map, + score=score, + scores=new_scores, + # Store the raw logits unchanged (JSON-safe): the mask is fully captured + # by ``scores`` (dropped classes -> 0) and the ``applied_to`` lineage. + logits=logits, + terminal=True, + timestamp=classification.timestamp, + applied_to=classification, + created_at=timestamp, + updated_at=timestamp, + ) + classifications_to_demote.append(classification) + classifications_to_add.append(new_classification) + masked_count += 1 + + detection = classification.detection + if detection is not None and detection.occurrence is not None: + occurrences_to_update.add(detection.occurrence) + + # Flush every batch_size items and at the final item. The flush fires even + # when nothing was accumulated so the job health-check sees a heartbeat during + # stretches where masking is a no-op (all short-circuited). + if i % batch_size == 0 or i == total: + with transaction.atomic(): + if classifications_to_demote: + Classification.objects.bulk_update(classifications_to_demote, ["terminal", "updated_at"]) + if classifications_to_add: + Classification.objects.bulk_create(classifications_to_add) + # Count an occurrence only when saving its new terminal classification + # actually changes the determination. Re-saving recomputes it in place, + # so an occurrence pinned to a human identification keeps its taxon + # and must not inflate the metric. + for occurrence in occurrences_to_update: + prev = occurrence.determination_id + occurrence.save(update_determination=True) + if occurrence.pk is not None and occurrence.determination_id != prev: + changed_occurrence_ids.add(occurrence.pk) + + classifications_to_demote.clear() + classifications_to_add.clear() + occurrences_to_update.clear() + + if on_batch is not None: + on_batch( + { + "classifications_checked": i, + "classifications_total": total, + "classifications_masked": masked_count, + "occurrences_updated": len(changed_occurrence_ids), + } + ) task_logger.info( - f"Re-scored {masked_count} of {total} classifications; updated {len(occurrences_to_update)} occurrences." + f"Re-scored {masked_count} of {total} classifications; updated {len(changed_occurrence_ids)} occurrences." ) return { "classifications_checked": total, "classifications_masked": masked_count, - "occurrences_updated": len(occurrences_to_update), + "occurrences_updated": len(changed_occurrence_ids), } @@ -255,13 +304,25 @@ def run(self) -> None: classifications, scope_desc = self._scoped_classifications(config, source_algorithm) self.logger.info(f"Applying class masking on {scope_desc} using taxa list {taxa_list.pk}") + def _on_batch(m: dict) -> None: + total = m["classifications_total"] + self.update_progress(m["classifications_checked"] / total if total else 1.0) + self.report_stage_metrics( + { + "classifications_checked": m["classifications_checked"], + "classifications_masked": m["classifications_masked"], + "occurrences_updated": m["occurrences_updated"], + } + ) + metrics = make_classifications_filtered_by_taxa_list( classifications=classifications, taxa_list=taxa_list, algorithm=source_algorithm, new_algorithm=masking_algorithm, + reweight=config.reweight, task_logger=self.logger, - progress_callback=lambda i, total: self.update_progress(i / total if total else 1.0), + on_batch=_on_batch, ) self.report_stage_metrics(metrics) self.logger.info(f"=== Completed {self.name} ===") diff --git a/ami/ml/post_processing/tests/test_class_masking.py b/ami/ml/post_processing/tests/test_class_masking.py index b56e881f4..887cf71c9 100644 --- a/ami/ml/post_processing/tests/test_class_masking.py +++ b/ami/ml/post_processing/tests/test_class_masking.py @@ -267,3 +267,178 @@ def test_task_run_occurrence_scope(self): new_clf = Classification.objects.filter(detection=det, terminal=True).exclude(algorithm=self.algorithm).first() self.assertIsNotNone(new_clf) self.assertEqual(new_clf.taxon, self.species_taxa[0]) + + # ----- batched commit + heartbeat ------------------------------------- + + def test_batched_commit_flushes_multiple_times(self): + """With batch_size=2 and 5 classifications, on_batch fires at i=2, 4, 5 — more than once. + + Verifies that all classifications are correctly masked across flush boundaries + (correctness preserved) and that the on_batch callback receives a call per batch + (heartbeat wiring works).""" + taxa_list = TaxaList.objects.create(name="Batch flush test") + taxa_list.taxa.set(self.species_taxa[:1]) # keep only index 0; indices 1 and 2 excluded + + new_algorithm = Algorithm.objects.create( + name="masked_batch", + key="masked_batch_test", + task_type=AlgorithmTaskType.CLASSIFICATION.value, + category_map=self.algorithm.category_map, + ) + + # Create 5 detections, each with a classification whose top logit is index 2 + # (excluded), so every one should be masked and re-assigned to index 0. + logits = [2.0, 1.0, 5.0] + pks = [] + for _ in range(5): + det, _ = self._detection_with_occurrence() + clf = self._create_classification_with_logits(det, self.species_taxa[2], _softmax(logits), logits) + pks.append(clf.pk) + + batch_calls: list[dict] = [] + metrics = make_classifications_filtered_by_taxa_list( + classifications=Classification.objects.filter(pk__in=pks), + taxa_list=taxa_list, + algorithm=self.algorithm, + new_algorithm=new_algorithm, + batch_size=2, + on_batch=batch_calls.append, + ) + + self.assertGreater(len(batch_calls), 1, "on_batch must fire more than once with batch_size=2 over 5 items") + # i=2 and i=4 fire on the batch boundary; i=5 fires on the total boundary. + self.assertEqual(len(batch_calls), 3) + self.assertEqual(metrics["classifications_masked"], 5, "All 5 classifications must be masked across batches") + self.assertEqual( + Classification.objects.filter(algorithm=new_algorithm, terminal=True).count(), + 5, + "A new terminal classification must exist for every masked row", + ) + + # ----- changed-only occurrence count ---------------------------------- + + def test_occurrences_updated_counts_only_changed_determinations(self): + """``occurrences_updated`` counts only occurrences whose determination changed, + not every occurrence whose detection was touched. + + occ1: original winner is index 2 (excluded) — masking flips determination to index 0. + occ2: original winner is index 0 (kept) — masking reassigns scores but determination stays index 0. + Only occ1 should count.""" + taxa_list = TaxaList.objects.create(name="Changed-only count test") + taxa_list.taxa.set(self.species_taxa[:2]) # excludes species_taxa[2] (index 2) + + new_algorithm = Algorithm.objects.create( + name="masked_changed", + key="masked_changed_test", + task_type=AlgorithmTaskType.CLASSIFICATION.value, + category_map=self.algorithm.category_map, + ) + + # occ1: index 2 has the highest logit, so masking forces the winner to index 0. + logits_changes = [2.0, 1.0, 5.0] + det1, occ1 = self._detection_with_occurrence() + clf1 = self._create_classification_with_logits( + det1, self.species_taxa[2], _softmax(logits_changes), logits_changes + ) + # Persist the pre-masking determination so the loaded instance has it set. + occ1.save(update_determination=True) # determination = species_taxa[2] + + # occ2: index 0 has the highest logit, so after masking it remains the winner. + # Masking still changes the scores (index 2 drops to 0, softmax shifts), so the + # classification IS demoted — but occ2's determination stays species_taxa[0]. + logits_stays = [5.0, 1.0, 3.0] + det2, occ2 = self._detection_with_occurrence() + clf2 = self._create_classification_with_logits( + det2, self.species_taxa[0], _softmax(logits_stays), logits_stays + ) + occ2.save(update_determination=True) # determination = species_taxa[0] + + metrics = make_classifications_filtered_by_taxa_list( + classifications=Classification.objects.filter(pk__in=[clf1.pk, clf2.pk]), + taxa_list=taxa_list, + algorithm=self.algorithm, + new_algorithm=new_algorithm, + ) + + self.assertEqual(metrics["classifications_masked"], 2, "Both classifications are modified by masking") + self.assertEqual( + metrics["occurrences_updated"], + 1, + "Only the occurrence whose determination changed (occ1) counts", + ) + + # ----- reweight toggle ------------------------------------------------ + + def test_reweight_false_winner_identical_scores_differ(self): + """With reweight=False the winning taxon is identical to reweight=True, but the + stored score equals the winner's original absolute probability (not renormalised) + and the kept-class scores do not sum to 1. + + logits = [2, 1, 5]: index 2 is excluded, so index 0 wins in both modes. + reweight=True: new_scores renormalised — sums to 1, score = renormalised p(index 0). + reweight=False: new_scores = original with index 2 zeroed — sums < 1, score = original p(index 0).""" + logits = [2.0, 1.0, 5.0] + scores = _softmax(logits) + taxa_list = TaxaList.objects.create(name="Reweight compare") + taxa_list.taxa.set(self.species_taxa[:2]) # excludes index 2 + + # --- reweight=True (default) --- + det_t, _ = self._detection_with_occurrence() + clf_t = self._create_classification_with_logits(det_t, self.species_taxa[2], scores, logits) + new_alg_true = Algorithm.objects.create( + name="masked_rw_true", + key="masked_rw_true_test", + task_type=AlgorithmTaskType.CLASSIFICATION.value, + category_map=self.algorithm.category_map, + ) + make_classifications_filtered_by_taxa_list( + classifications=Classification.objects.filter(pk=clf_t.pk), + taxa_list=taxa_list, + algorithm=self.algorithm, + new_algorithm=new_alg_true, + reweight=True, + ) + + # --- reweight=False --- + det_f, _ = self._detection_with_occurrence() + clf_f = self._create_classification_with_logits(det_f, self.species_taxa[2], scores, logits) + new_alg_false = Algorithm.objects.create( + name="masked_rw_false", + key="masked_rw_false_test", + task_type=AlgorithmTaskType.CLASSIFICATION.value, + category_map=self.algorithm.category_map, + ) + make_classifications_filtered_by_taxa_list( + classifications=Classification.objects.filter(pk=clf_f.pk), + taxa_list=taxa_list, + algorithm=self.algorithm, + new_algorithm=new_alg_false, + reweight=False, + ) + + new_clf_t = Classification.objects.get(detection=det_t, terminal=True) + new_clf_f = Classification.objects.get(detection=det_f, terminal=True) + + # Winner is the same in both modes. + self.assertEqual(new_clf_t.taxon, self.species_taxa[0]) + self.assertEqual(new_clf_f.taxon, self.species_taxa[0], "Winner is identical with reweight=False") + + # Excluded class is zeroed in both modes. + self.assertAlmostEqual(new_clf_t.scores[2], 0.0, places=10) + self.assertAlmostEqual(new_clf_f.scores[2], 0.0, places=10) + + # reweight=True: kept-class scores are renormalised and sum to 1. + self.assertAlmostEqual(sum(new_clf_t.scores), 1.0, places=5) + + # reweight=False: kept classes retain original absolute values; sum is < 1. + self.assertAlmostEqual( + new_clf_f.scores[0], scores[0], places=6, msg="Kept class retains original score with reweight=False" + ) + self.assertAlmostEqual(new_clf_f.scores[1], scores[1], places=6) + self.assertNotAlmostEqual( + sum(new_clf_f.scores), 1.0, places=2, msg="Scores must not sum to 1 with reweight=False" + ) + + # Stored confidence: reweight=False uses the original pre-mask probability. + self.assertAlmostEqual(new_clf_f.score, scores[0], places=6) + self.assertNotAlmostEqual(new_clf_t.score, scores[0], places=2) diff --git a/ami/ml/post_processing/tests/test_class_masking_admin.py b/ami/ml/post_processing/tests/test_class_masking_admin.py index 903248cf5..4664f52d3 100644 --- a/ami/ml/post_processing/tests/test_class_masking_admin.py +++ b/ami/ml/post_processing/tests/test_class_masking_admin.py @@ -55,6 +55,14 @@ def test_extra_field_is_forbidden(self): with self.assertRaises(pydantic.ValidationError): ClassMaskingConfig(source_image_collection_id=1, taxa_list_id=2, algorithm_id=3, bogus=1) + def test_reweight_defaults_to_true(self): + config = ClassMaskingConfig(source_image_collection_id=1, taxa_list_id=2, algorithm_id=3) + self.assertTrue(config.reweight) + + def test_reweight_can_be_set_false(self): + config = ClassMaskingConfig(source_image_collection_id=1, taxa_list_id=2, algorithm_id=3, reweight=False) + self.assertFalse(config.reweight) + class _PostProcessingAdminCase(TestCase): @classmethod @@ -132,6 +140,41 @@ def test_valid_post_on_occurrence_creates_occurrence_scoped_job(self): self.assertIsNone(job.params["config"].get("source_image_collection_id")) +class TestClassMaskingFormReweight(_PostProcessingAdminCase): + """The admin form exposes a reweight toggle and passes it through to the job config.""" + + def _post_collection_reweight(self, include_reweight: bool): + data = { + "action": "run_class_masking", + django_admin.helpers.ACTION_CHECKBOX_NAME: [str(self.collection.pk)], + "confirm": "yes", + "taxa_list_id": str(self.taxa_list.pk), + "algorithm_id": str(self.algorithm.pk), + } + if include_reweight: + data["reweight"] = "on" + return self.client.post(reverse("admin:main_sourceimagecollection_changelist"), data=data) + + def test_form_has_reweight_field(self): + form = ClassMaskingActionForm() + self.assertIn("reweight", form.fields) + self.assertTrue(form.fields["reweight"].initial, "reweight must default to True (checked)") + + def test_to_config_includes_reweight_true_when_checked(self): + """When the operator checks the reweight box, the job config carries reweight=True.""" + response = self._post_collection_reweight(include_reweight=True) + self.assertEqual(response.status_code, 302) + job = Job.objects.get(project=self.project, job_type_key="post_processing") + self.assertTrue(job.params["config"]["reweight"]) + + def test_to_config_includes_reweight_false_when_unchecked(self): + """An unchecked reweight box (no value in POST) yields reweight=False in the job config.""" + response = self._post_collection_reweight(include_reweight=False) + self.assertEqual(response.status_code, 302) + job = Job.objects.get(project=self.project, job_type_key="post_processing") + self.assertFalse(job.params["config"]["reweight"]) + + class TestClassMaskingFormScopeFiltering(TestCase): """The class-mask form offers only classifiers that actually produced classifications within the selected scope, so an operator cannot pick an From 3072dae1564bb0cceb661e8edf304c6aa35a1ee2 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 1 Jul 2026 09:19:04 -0700 Subject: [PATCH 7/8] fix(post-processing): bump updated_at on progress saves so the reaper spares long runs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The stale-job reaper (check_stale_jobs) revokes running jobs whose updated_at is older than STALLED_JOBS_MAX_MINUTES (10). The post-processing progress heartbeat saved with update_fields=["progress"], and Django does not auto-add auto_now fields to a narrowed update_fields, so updated_at never moved. Any post-processing run longer than 10 minutes looked frozen and was revoked mid-flight even while streaming progress and stage metrics — observed on a large class-masking run (revoked at 42%, ~1600/3814 classifications). Add updated_at to both save paths in BasePostProcessingTask.update_progress and report_stage_metrics, mirroring the async-job progress save in jobs/models.py that already does this for the same reason. Add a regression test that freezes updated_at past the cutoff and asserts each heartbeat path drags it forward. Co-Authored-By: Claude --- ami/ml/post_processing/base.py | 10 +++++-- ami/ml/tests.py | 48 ++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/ami/ml/post_processing/base.py b/ami/ml/post_processing/base.py index 863312344..8f197f192 100644 --- a/ami/ml/post_processing/base.py +++ b/ami/ml/post_processing/base.py @@ -75,7 +75,11 @@ def update_progress(self, progress: float): if self.job: self.job.progress.update_stage(self.job.job_type_key, progress=progress) - self.job.save(update_fields=["progress"]) + # Bump updated_at alongside progress: the stale-job reaper + # (check_stale_jobs) revokes running jobs whose updated_at is older + # than STALLED_JOBS_MAX_MINUTES. A long post-processing run that only + # touched "progress" would look frozen and be reaped mid-flight. + self.job.save(update_fields=["progress", "updated_at"]) else: # No job object — fallback to plain logging @@ -99,7 +103,9 @@ def report_stage_metrics(self, metrics: dict[str, Any]): stage_key = self.job.job_type_key for label, value in metrics.items(): self.job.progress.add_or_update_stage_param(stage_key, label, value) - self.job.save(update_fields=["progress"]) + # Bump updated_at so the stale-job reaper sees an actively-progressing + # run; see update_progress for the full reasoning. + self.job.save(update_fields=["progress", "updated_at"]) @abc.abstractmethod def run(self) -> None: diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 6781f80ee..cb88ee6fd 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -1603,6 +1603,54 @@ def test_run_reports_stage_metrics_on_job(self): # count equals the detection count. self.assertEqual(params.get("occurrences_updated"), total) + def test_progress_save_bumps_updated_at_for_reaper(self): + """A progress heartbeat bumps ``Job.updated_at`` so the stale-job reaper + leaves an actively-running post-processing job alone. + + ``check_stale_jobs`` revokes running jobs whose ``updated_at`` is older + than ``STALLED_JOBS_MAX_MINUTES``. The progress save narrows to + ``update_fields``, and Django does not auto-add ``auto_now`` fields to + that list, so ``update_progress`` / ``report_stage_metrics`` must include + ``updated_at`` explicitly. Without it a long run looks frozen and is + reaped mid-flight even while streaming progress. This pins that both save + paths move ``updated_at`` forward. + """ + from ami.jobs.models import Job + + job = Job.objects.create( + project=self.project, + name="reaper heartbeat test", + job_type_key="post_processing", + params={ + "task": "small_size_filter", + "config": {"source_image_collection_id": self.collection.pk, "size_threshold": 0.01}, + }, + ) + job.progress.add_stage("Post Processing", key="post_processing") + job.save() + + task = SmallSizeFilterTask( + job=job, + source_image_collection_id=self.collection.pk, + size_threshold=0.01, + ) + + # Freeze a baseline older than the reaper cutoff, then confirm each + # heartbeat path drags updated_at back to "now". USE_TZ is False, so + # updated_at is naive local time — mirror check_stale_jobs' own + # naive datetime.now() comparison. + stale = datetime.datetime.now() - datetime.timedelta(minutes=Job.STALLED_JOBS_MAX_MINUTES + 5) + + Job.objects.filter(pk=job.pk).update(updated_at=stale) + task.update_progress(0.5) + job.refresh_from_db() + self.assertGreater(job.updated_at, stale, "update_progress must bump updated_at") + + Job.objects.filter(pk=job.pk).update(updated_at=stale) + task.report_stage_metrics({"classifications_checked": 1}) + job.refresh_from_db() + self.assertGreater(job.updated_at, stale, "report_stage_metrics must bump updated_at") + def test_occurrences_updated_counts_only_changed_determinations(self): """``occurrences_updated`` counts occurrences whose determination actually changed, not every occurrence the filter re-saved. From 1afe75acd82e531e5144c943dc8e4b605129f530 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 1 Jul 2026 16:17:47 -0700 Subject: [PATCH 8/8] =?UTF-8?q?fix(post-processing):=20address=20class-mas?= =?UTF-8?q?king=20review=20=E2=80=94=20logits-only=20scoring,=20reweight?= =?UTF-8?q?=20in=20algorithm=20identity,=20dry-run=20filter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three CodeRabbit findings on the class-masking task: - Derive all masked scores from `logits` alone. The reweight=False branch and the no-change short-circuit still read the stored `scores` field; recompute the model's full softmax from logits instead. The `scores` field is being retired, so masking must not depend on it. Mathematically identical while `scores == softmax(logits)` (the winner is unchanged), and it drops the last read of `scores` from the masking math. (Argmax stays on the masked logits — an excluded class still cannot win.) - Include the reweight mode in the masking algorithm's identity. reweight=True and reweight=False persist different score semantics, so keying the output Algorithm on (source algorithm, taxa list) alone made both modes share one row and lose provenance. The key/name now carry a "reweighted" / "absolute" suffix; a new test pins that the two modes resolve to distinct algorithms. - Match the management command's dry-run count to the task scope by also requiring `logits__isnull=False`, so `--dry-run` and the "no classifications" guard no longer overstate what the run will process. Co-Authored-By: Claude --- .../management/commands/run_class_masking.py | 1 + ami/ml/post_processing/class_masking.py | 81 +++++++++++-------- .../tests/test_class_masking.py | 50 +++++++++++- 3 files changed, 94 insertions(+), 38 deletions(-) diff --git a/ami/ml/management/commands/run_class_masking.py b/ami/ml/management/commands/run_class_masking.py index a99c0cb41..940956c6e 100644 --- a/ami/ml/management/commands/run_class_masking.py +++ b/ami/ml/management/commands/run_class_masking.py @@ -52,6 +52,7 @@ def handle(self, *args, **options): terminal=True, algorithm=algorithm, scores__isnull=False, + logits__isnull=False, ) .distinct() .count() diff --git a/ami/ml/post_processing/class_masking.py b/ami/ml/post_processing/class_masking.py index 2043fa728..ffc77f346 100644 --- a/ami/ml/post_processing/class_masking.py +++ b/ami/ml/post_processing/class_masking.py @@ -114,7 +114,7 @@ def make_classifications_filtered_by_taxa_list( masked_count = 0 for i, classification in enumerate(classifications.iterator(chunk_size=batch_size), start=1): - scores, logits = classification.scores, classification.logits + logits = classification.logits if not isinstance(logits, list) or not all(isinstance(x, (int, float)) for x in logits): raise ValueError(f"Logits for classification {classification.pk} are not a list of numbers: {logits}") elif len(logits) != num_categories: @@ -122,33 +122,35 @@ def make_classifications_filtered_by_taxa_list( f"Classification {classification.pk}: {len(logits)} logits != {num_categories} categories; skipping" ) else: - # Mask excluded classes with -inf on a working copy so the renormalised - # softmax assigns them exactly zero probability — an excluded class can - # never win argmax. (-inf is compute-only; it is never stored, since it - # is not valid JSON. The stored vectors stay finite: see below.) - working = np.asarray(logits, dtype=float) - working[excluded_indices] = -np.inf - working -= working.max() # max is over kept classes (finite); stabilises exp - exp = np.exp(working) # exp(-inf) == 0 for masked classes - new_scores_np = exp / exp.sum() # sum > 0: at least one class is kept - top_index = int(np.argmax(new_scores_np)) + # Everything is derived from ``logits`` alone — the stored ``scores`` field + # is being retired, so masking must not depend on it. Recompute the model's + # full softmax (over every class) from the logits, then drop the excluded + # classes to exactly zero. An excluded class can never win argmax or carry + # probability. + shifted = np.asarray(logits, dtype=float) + shifted -= shifted.max() # stabilises exp without changing the softmax + full_softmax = np.exp(shifted) + full_softmax /= full_softmax.sum() # p over all classes; matches the retired ``scores`` + + kept = full_softmax.copy() + kept[excluded_indices] = 0.0 + kept_sum = kept.sum() # > 0: at least one class is kept (guaranteed upstream) + top_index = int(np.argmax(kept)) # over kept classes only (excluded are 0) if reweight: - # Renormalise: excluded classes become exactly 0; kept classes sum to 1. - new_scores = new_scores_np.tolist() - score = float(new_scores_np[top_index]) + # Renormalise: excluded classes stay 0; kept classes sum to 1. + new_scores_np = kept / kept_sum else: - # No renormalisation: kept classes keep their original absolute scores; - # excluded classes are zeroed. The winner is the same (argmax is - # computed on the renormalised distribution before this branch). - new_scores = list(scores) - for idx in excluded_indices: - new_scores[idx] = 0.0 - score = float(new_scores[top_index]) + # No renormalisation: kept classes keep their original absolute + # probability; excluded classes are zeroed. Winner is unchanged. + new_scores_np = kept + new_scores = new_scores_np.tolist() + score = float(new_scores_np[top_index]) # No-change short-circuit: if masking shifted no probability (the classes - # this taxa list drops carried ~zero score here), leave the row untouched. - if isinstance(scores, list) and np.allclose(scores, new_scores, atol=1e-9): + # this taxa list drops carried ~zero probability here), leave the row + # untouched. Compared against the unmasked softmax, not the stored scores. + if np.allclose(full_softmax, new_scores_np, atol=1e-9): task_logger.debug(f"Classification {classification.pk} unchanged by masking; skipping") else: top_taxon = index_to_taxon.get(top_index) # guaranteed in taxa_in_list (top_index is kept) @@ -228,21 +230,28 @@ class ClassMaskingTask(BasePostProcessingTask): name = "Class masking" config_schema = ClassMaskingConfig - def _get_or_create_masking_algorithm(self, source_algorithm: Algorithm, taxa_list: TaxaList) -> Algorithm: - """Get or create the output algorithm for this (source algorithm, taxa list). - - One masking algorithm per pair keeps provenance reproducible: re-running - the same mask reuses the same Algorithm row. Its category map is the - source map (indices still align with the masked score vector) and is - persisted — earlier code set it in memory only, so masked classifications - referenced a null map. + def _get_or_create_masking_algorithm( + self, source_algorithm: Algorithm, taxa_list: TaxaList, *, reweight: bool + ) -> Algorithm: + """Get or create the output algorithm for this (source algorithm, taxa list, reweight mode). + + One masking algorithm per (source algorithm, taxa list, reweight mode) + keeps provenance reproducible: re-running the same mask reuses the same + Algorithm row. The reweight mode is part of the identity because the two + modes persist different score semantics (renormalised vs original + absolute), so a masked classification's ``applied_to.algorithm`` can tell + them apart. Its category map is the source map (indices still align with + the masked score vector) and is persisted — earlier code set it in memory + only, so masked classifications referenced a null map. """ + mode = "reweighted" if reweight else "absolute" algorithm, created = Algorithm.objects.get_or_create( - key=f"{source_algorithm.key}_filtered_by_taxa_list_{taxa_list.pk}", + key=f"{source_algorithm.key}_filtered_by_taxa_list_{taxa_list.pk}_{mode}", defaults={ - "name": f"{source_algorithm.name} (filtered by taxa list {taxa_list.name})", + "name": f"{source_algorithm.name} (filtered by taxa list {taxa_list.name}, {mode} scores)", "description": ( - f"Classifications from {source_algorithm.name} re-scored against taxa list {taxa_list.name}" + f"Classifications from {source_algorithm.name} re-scored against taxa list " + f"{taxa_list.name} ({mode} scores)" ), "task_type": AlgorithmTaskType.CLASSIFICATION.value, "category_map": source_algorithm.category_map, @@ -300,7 +309,9 @@ def run(self) -> None: if not source_algorithm.category_map: raise ValueError(f"Algorithm '{source_algorithm.name}' has no category map; cannot mask classes.") - masking_algorithm = self._get_or_create_masking_algorithm(source_algorithm, taxa_list) + masking_algorithm = self._get_or_create_masking_algorithm( + source_algorithm, taxa_list, reweight=config.reweight + ) classifications, scope_desc = self._scoped_classifications(config, source_algorithm) self.logger.info(f"Applying class masking on {scope_desc} using taxa list {taxa_list.pk}") diff --git a/ami/ml/post_processing/tests/test_class_masking.py b/ami/ml/post_processing/tests/test_class_masking.py index 887cf71c9..b0578272c 100644 --- a/ami/ml/post_processing/tests/test_class_masking.py +++ b/ami/ml/post_processing/tests/test_class_masking.py @@ -238,9 +238,12 @@ def test_task_run_collection_scope_persists_masking_algorithm(self): algorithm_id=self.algorithm.pk, ).run() - # The per-(source algorithm, taxa list) masking algorithm exists and kept - # its category map (the bug being guarded: it used to be set in memory only). - masking_algo = Algorithm.objects.get(key=f"{self.algorithm.key}_filtered_by_taxa_list_{taxa_list.pk}") + # The per-(source algorithm, taxa list, reweight mode) masking algorithm + # exists and kept its category map (the bug being guarded: it used to be + # set in memory only). Default reweight=True → the "reweighted" mode. + masking_algo = Algorithm.objects.get( + key=f"{self.algorithm.key}_filtered_by_taxa_list_{taxa_list.pk}_reweighted" + ) self.assertIsNotNone(masking_algo.category_map_id) self.assertEqual(masking_algo.category_map_id, self.algorithm.category_map_id) @@ -250,6 +253,47 @@ def test_task_run_collection_scope_persists_masking_algorithm(self): occ.refresh_from_db() self.assertEqual(occ.determination, self.species_taxa[1], "Occurrence determination follows the masked result") + def test_reweight_modes_get_distinct_masking_algorithms(self): + """The reweight mode is part of the masking algorithm's identity. + + reweight=True and reweight=False persist different score semantics + (renormalised vs original absolute), so they must resolve to different + Algorithm rows — otherwise a masked classification's + ``applied_to.algorithm`` could not tell which mode produced it. Both keys + derive from the same (source algorithm, taxa list); only the mode suffix + differs. + """ + logits = [0.5, 3.0, 3.5] + taxa_list = TaxaList.objects.create(name="Reweight identity list") + taxa_list.taxa.set(self.species_taxa[:2]) + + det_t, _ = self._detection_with_occurrence() + self._create_classification_with_logits(det_t, self.species_taxa[2], _softmax(logits), logits) + ClassMaskingTask( + source_image_collection_id=self.collection.pk, + taxa_list_id=taxa_list.pk, + algorithm_id=self.algorithm.pk, + reweight=True, + ).run() + + det_f, _ = self._detection_with_occurrence() + self._create_classification_with_logits(det_f, self.species_taxa[2], _softmax(logits), logits) + ClassMaskingTask( + source_image_collection_id=self.collection.pk, + taxa_list_id=taxa_list.pk, + algorithm_id=self.algorithm.pk, + reweight=False, + ).run() + + base = f"{self.algorithm.key}_filtered_by_taxa_list_{taxa_list.pk}" + reweighted = Algorithm.objects.get(key=f"{base}_reweighted") + absolute = Algorithm.objects.get(key=f"{base}_absolute") + self.assertNotEqual(reweighted.pk, absolute.pk, "Each reweight mode gets its own masking algorithm") + + # Each detection's masked classification points at the algorithm for its mode. + self.assertTrue(Classification.objects.filter(detection=det_t, terminal=True, algorithm=reweighted).exists()) + self.assertTrue(Classification.objects.filter(detection=det_f, terminal=True, algorithm=absolute).exists()) + def test_task_run_occurrence_scope(self): logits = [2.0, 1.0, 5.0] taxa_list = TaxaList.objects.create(name="Occ scope list")