Skip to content
20 changes: 18 additions & 2 deletions ami/main/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from ami.jobs.models import Job
from ami.ml.models.project_pipeline_config import ProjectPipelineConfig
from ami.ml.post_processing.admin.actions import make_post_processing_action
from ami.ml.post_processing.admin.class_masking_form import ClassMaskingActionForm
from ami.ml.post_processing.admin.small_size_filter_form import SmallSizeFilterActionForm
from ami.ml.post_processing.class_masking import ClassMaskingTask
from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask
from ami.ml.tasks import remove_duplicate_classifications

Expand Down Expand Up @@ -552,6 +554,12 @@ def detections_count(self, obj) -> int:
scope_resolver=lambda occurrence: {"occurrence_id": occurrence.pk},
name_resolver=lambda task_cls, occurrence: (f"Post-processing: {task_cls.name} on Occurrence {occurrence.pk}"),
)
# Admin action that runs class masking for a selected occurrence. It pairs
# ClassMaskingTask with ClassMaskingActionForm; the scope resolver supplies the
# occurrence's primary key under "occurrence_id" and the name resolver builds a
# label from the task's name and that primary key.
# NOTE(review): how make_post_processing_action consumes the two resolvers is
# defined outside this hunk — confirm against its implementation.
run_class_masking = make_post_processing_action(
ClassMaskingTask,
ClassMaskingActionForm,
scope_resolver=lambda occurrence: {"occurrence_id": occurrence.pk},
name_resolver=lambda task_cls, occurrence: (f"Post-processing: {task_cls.name} on Occurrence {occurrence.pk}"),
)

@admin.action(description="Recompute determination from current classifications and identifications")
def recompute_determination(self, request: HttpRequest, queryset: QuerySet[Any]) -> None:
Expand All @@ -568,7 +576,7 @@ def recompute_determination(self, request: HttpRequest, queryset: QuerySet[Any])
count += 1
self.message_user(request, f"Recomputed determination for {count} occurrence(s).")

actions = [run_small_size_filter, recompute_determination]
actions = [run_small_size_filter, run_class_masking, recompute_determination]

# Order by -id (the indexed primary key) rather than -created_at, which has no
# index and would force a full sort of the table to find the newest page. id
Expand Down Expand Up @@ -850,11 +858,19 @@ def populate_collection_async(self, request: HttpRequest, queryset: QuerySet[Sou
f"Post-processing: {task_cls.name} on Capture Set {collection.pk}"
),
)

# Admin action that runs class masking for a selected capture set (collection).
# It pairs ClassMaskingTask with ClassMaskingActionForm; the scope resolver
# supplies the collection's primary key under "source_image_collection_id" and
# the name resolver builds a label from the task's name and that primary key.
# NOTE(review): how make_post_processing_action consumes the two resolvers is
# defined outside this hunk — confirm against its implementation.
run_class_masking = make_post_processing_action(
ClassMaskingTask,
ClassMaskingActionForm,
scope_resolver=lambda collection: {"source_image_collection_id": collection.pk},
name_resolver=lambda task_cls, collection: (
f"Post-processing: {task_cls.name} on Capture Set {collection.pk}"
),
)
actions = [
populate_collection,
populate_collection_async,
run_small_size_filter,
run_class_masking,
]

# Hide images many-to-many field from form. This would list all source images in the database.
Expand Down
21 changes: 21 additions & 0 deletions ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,10 +941,26 @@ class ClassificationPredictionItemSerializer(serializers.Serializer):
logit = serializers.FloatField(read_only=True)


class ClassificationAppliedToSerializer(serializers.ModelSerializer):
    """Compact nested view of the classification a result was derived from.

    ``Classification.applied_to`` is how post-processing tasks (class masking,
    rank rollup) record provenance. Only the parent's id, creation time and
    algorithm are exposed, which is enough to identify the origin while
    avoiding recursion into a complete classification payload.
    """

    # Nested and read-only: the parent's algorithm is reported, never written.
    algorithm = AlgorithmSerializer(read_only=True)

    class Meta:
        model = Classification
        fields = [
            "id",
            "created_at",
            "algorithm",
        ]
Comment thread
coderabbitai[bot] marked this conversation as resolved.


class ClassificationSerializer(DefaultSerializer):
taxon = TaxonNestedSerializer(read_only=True)
algorithm = AlgorithmSerializer(read_only=True)
top_n = ClassificationPredictionItemSerializer(many=True, read_only=True)
applied_to = ClassificationAppliedToSerializer(read_only=True)

class Meta:
model = Classification
Expand All @@ -957,6 +973,7 @@ class Meta:
"scores",
"logits",
"top_n",
"applied_to",
"created_at",
"updated_at",
]
Expand All @@ -979,6 +996,8 @@ class Meta(ClassificationSerializer.Meta):


class ClassificationListSerializer(DefaultSerializer):
applied_to = ClassificationAppliedToSerializer(read_only=True)

class Meta:
model = Classification
fields = [
Expand All @@ -987,6 +1006,7 @@ class Meta:
"taxon",
"score",
"algorithm",
"applied_to",
"created_at",
"updated_at",
]
Expand All @@ -1006,6 +1026,7 @@ class Meta:
"score",
"terminal",
"algorithm",
"applied_to",
"created_at",
]

Expand Down
2 changes: 1 addition & 1 deletion ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2060,7 +2060,7 @@ class ClassificationViewSet(DefaultViewSet, ProjectMixin):
"""

require_project_for_list = True # Unfiltered list scans are too expensive on this table
queryset = Classification.objects.all().select_related("taxon", "algorithm") # , "detection")
queryset = Classification.objects.all().select_related("taxon", "algorithm", "applied_to__algorithm")
serializer_class = ClassificationSerializer
filterset_fields = [
# Docs about slow loading API browser because of large choice fields
Expand Down
5 changes: 4 additions & 1 deletion ami/main/models_future/occurrence.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def _detections_prefetch(*, ordering: tuple[str, ...], with_source_image: bool)
qs = Detection.objects.prefetch_related(
Prefetch(
"classifications",
queryset=Classification.objects.select_related("taxon", "algorithm"),
# applied_to__algorithm: post-processed classifications (class masking,
# rank rollup) serialize their provenance parent; pull it here so the
# nested applied_to render doesn't issue a query per classification.
queryset=Classification.objects.select_related("taxon", "algorithm", "applied_to__algorithm"),
)
).order_by(*ordering)
if with_source_image:
Expand Down
84 changes: 84 additions & 0 deletions ami/ml/management/commands/run_class_masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from django.core.management.base import BaseCommand, CommandError

from ami.main.models import SourceImageCollection, TaxaList
from ami.ml.models.algorithm import Algorithm
from ami.ml.post_processing.class_masking import ClassMaskingTask


class Command(BaseCommand):
    """Run class masking post-processing for one source image collection.

    The command validates the three IDs it is given, reports how many
    classifications match the selection, and then either stops (``--dry-run``)
    or runs ``ClassMaskingTask`` with those IDs.
    """

    help = (
        "Run class masking post-processing on a source image collection. "
        "Masks classifier logits for species not in the given taxa list and recalculates softmax scores."
    )

    def add_arguments(self, parser):
        """Register the three required ID options and the ``--dry-run`` flag."""
        parser.add_argument("--collection-id", type=int, required=True, help="SourceImageCollection ID to process")
        parser.add_argument("--taxa-list-id", type=int, required=True, help="TaxaList ID to use as the species mask")
        parser.add_argument(
            "--algorithm-id", type=int, required=True, help="Algorithm ID whose classifications to mask"
        )
        parser.add_argument("--dry-run", action="store_true", help="Show what would be done without making changes")

    @staticmethod
    def _get_or_error(model, pk):
        """Return the ``model`` row with primary key ``pk``.

        Raises:
            CommandError: if no such row exists. The original ``DoesNotExist``
                is attached as the explicit cause so ``--traceback`` shows it.
        """
        try:
            return model.objects.get(pk=pk)
        except model.DoesNotExist as err:
            raise CommandError(f"{model.__name__} {pk} does not exist.") from err

    @staticmethod
    def _count_candidates(collection, algorithm) -> int:
        """Count terminal classifications by ``algorithm`` in ``collection`` that have scores and logits."""
        from ami.main.models import Classification

        candidates = Classification.objects.filter(
            detection__source_image__collections=collection,
            terminal=True,
            algorithm=algorithm,
            scores__isnull=False,
            logits__isnull=False,
        )
        # The filter joins through the image<->collection relation, so
        # de-duplicate before counting.
        return candidates.distinct().count()

    def handle(self, *args, **options):
        """Validate the inputs, print a summary, then run the task unless ``--dry-run``.

        Raises:
            CommandError: if any of the three IDs does not resolve, if the
                algorithm has no category map, or if nothing matches the
                selection.
        """
        collection_id = options["collection_id"]
        taxa_list_id = options["taxa_list_id"]
        algorithm_id = options["algorithm_id"]
        dry_run = options["dry_run"]

        # Validate inputs in a fixed order so the first bad ID is the one reported.
        collection = self._get_or_error(SourceImageCollection, collection_id)
        taxa_list = self._get_or_error(TaxaList, taxa_list_id)
        algorithm = self._get_or_error(Algorithm, algorithm_id)

        if not algorithm.category_map:
            raise CommandError(f"Algorithm '{algorithm.name}' does not have a category map.")

        classification_count = self._count_candidates(collection, algorithm)
        taxa_count = taxa_list.taxa.count()

        # The summary is printed before the emptiness check so the operator
        # sees what was selected even when there is nothing to process.
        self.stdout.write(
            f"Collection: {collection.name} (id={collection.pk})\n"
            f"Taxa list: {taxa_list.name} (id={taxa_list.pk}, {taxa_count} taxa)\n"
            f"Algorithm: {algorithm.name} (id={algorithm.pk})\n"
            f"Classifications to process: {classification_count}"
        )

        if classification_count == 0:
            raise CommandError("No terminal classifications with scores found for this collection/algorithm.")

        if dry_run:
            self.stdout.write(self.style.WARNING("Dry run — no changes made."))
            return

        self.stdout.write("Running class masking...")
        task = ClassMaskingTask(
            source_image_collection_id=collection_id,
            taxa_list_id=taxa_list_id,
            algorithm_id=algorithm_id,
        )
        task.run()
        self.stdout.write(self.style.SUCCESS("Class masking completed."))
1 change: 1 addition & 0 deletions ami/ml/post_processing/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import class_masking # noqa: F401
from . import small_size_filter # noqa: F401
6 changes: 4 additions & 2 deletions ami/ml/post_processing/admin/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,12 @@ def _render(form: BasePostProcessingActionForm) -> TemplateResponse:
)
return None

# Hand the form the selected rows so it can scope its fields to the
# selection (e.g. only offer algorithms that ran on the chosen occurrence).
if not request.POST.get("confirm"):
return _render(form_class())
return _render(form_class(scope_queryset=queryset))

form = form_class(request.POST)
form = form_class(request.POST, scope_queryset=queryset)
if not form.is_valid():
return _render(form)

Expand Down
76 changes: 76 additions & 0 deletions ami/ml/post_processing/admin/class_masking_form.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from __future__ import annotations

from django import forms

from ami.main.models import Occurrence, TaxaList
from ami.ml.models import Algorithm
from ami.ml.models.algorithm import AlgorithmTaskType
from ami.ml.post_processing.admin.forms import BasePostProcessingActionForm


class ClassMaskingActionForm(BasePostProcessingActionForm):
    """Admin form collecting the parameters of a Class masking run.

    The operator chooses which classifier's predictions to re-score and which
    taxa list defines the classes to keep. What the run applies to (a
    collection or an occurrence) comes from the admin entry point that opened
    the form, not from a form field. Both choice fields clean to model
    instances, so ``to_config`` converts them to primary keys
    (``ClassMaskingConfig`` expects ``*_id`` ints).
    """

    algorithm_id = forms.ModelChoiceField(
        label="Source classifier",
        help_text="The classification algorithm whose terminal predictions will be re-scored.",
        queryset=Algorithm.objects.filter(task_type=AlgorithmTaskType.CLASSIFICATION.value).order_by("name"),
    )
    taxa_list_id = forms.ModelChoiceField(
        label="Taxa list to keep",
        help_text=(
            "Classes whose taxon is not in this list are masked out; each "
            "classification's softmax is renormalised over the classes that remain."
        ),
        queryset=TaxaList.objects.all().order_by("name"),
    )
    reweight = forms.BooleanField(
        label="Reweight (renormalise) scores",
        help_text=(
            "Renormalise the kept classes to sum to 1. "
            "Off = keep the model's raw absolute scores; the chosen species is unchanged either way."
        ),
        initial=True,
        required=False,
    )

    def __init__(self, *args, **kwargs):
        """Build the form and, for an occurrence scope, narrow the classifier choices."""
        super().__init__(*args, **kwargs)

        scope = self.scope_queryset
        # Only an occurrence selection narrows the dropdown: the lookup then
        # touches just the classifications under the picked occurrences, and it
        # keeps the operator from choosing a classifier whose masking would be
        # a no-op for those rows. A collection scope deliberately keeps the
        # full list — the same lookup there is an unbounded DISTINCT over every
        # classification in the collection (hundreds of thousands of rows on a
        # large one) and can time out while the form renders. Offering too
        # many classifiers is harmless: masking one that produced nothing in
        # scope changes nothing.
        if scope is None or scope.model is not Occurrence:
            return

        self.fields["algorithm_id"].queryset = self._algorithms_for_scope(scope)

    @staticmethod
    def _algorithms_for_scope(scope_queryset):
        """Return the classification algorithms with classifications under the given occurrences."""
        in_scope = Algorithm.objects.filter(
            task_type=AlgorithmTaskType.CLASSIFICATION.value,
            classifications__detection__occurrence__in=scope_queryset,
        )
        return in_scope.distinct().order_by("name")

    def to_config(self) -> dict:
        """Return the cleaned selections as plain values: two primary keys and the reweight flag."""
        cleaned = self.cleaned_data
        algorithm = cleaned["algorithm_id"]
        taxa_list = cleaned["taxa_list_id"]
        return {
            "algorithm_id": algorithm.pk,
            "taxa_list_id": taxa_list.pk,
            "reweight": cleaned["reweight"],
        }
10 changes: 10 additions & 0 deletions ami/ml/post_processing/admin/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ class BasePostProcessingActionForm(forms.Form):
optional fields, derive computed values, rename keys).
"""

def __init__(self, *args, scope_queryset=None, **kwargs):
    """Record the admin selection this action targets, then build the form.

    ``scope_queryset`` holds the rows the operator selected (for example the
    chosen occurrences or collections). A subclass can read it to restrict
    its fields to that selection; a form with no use for it just leaves it
    alone. It is keyword-only, so positional form arguments pass straight
    through to the parent initialiser.
    """
    # Set before delegating, so the attribute already exists while the parent
    # initialiser runs.
    self.scope_queryset = scope_queryset
    super().__init__(*args, **kwargs)

def to_config(self) -> dict:
    """Return a shallow copy of ``cleaned_data`` shaped for ``Job.params['config']``."""
    return {**self.cleaned_data}
10 changes: 8 additions & 2 deletions ami/ml/post_processing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ def update_progress(self, progress: float):

if self.job:
self.job.progress.update_stage(self.job.job_type_key, progress=progress)
self.job.save(update_fields=["progress"])
# Bump updated_at alongside progress: the stale-job reaper
# (check_stale_jobs) revokes running jobs whose updated_at is older
# than STALLED_JOBS_MAX_MINUTES. A long post-processing run that only
# touched "progress" would look frozen and be reaped mid-flight.
self.job.save(update_fields=["progress", "updated_at"])

else:
# No job object — fallback to plain logging
Expand All @@ -99,7 +103,9 @@ def report_stage_metrics(self, metrics: dict[str, Any]):
stage_key = self.job.job_type_key
for label, value in metrics.items():
self.job.progress.add_or_update_stage_param(stage_key, label, value)
self.job.save(update_fields=["progress"])
# Bump updated_at so the stale-job reaper sees an actively-progressing
# run; see update_progress for the full reasoning.
self.job.save(update_fields=["progress", "updated_at"])

@abc.abstractmethod
def run(self) -> None:
Expand Down
Loading
Loading