-
Notifications
You must be signed in to change notification settings - Fork 14
Let admins mask classifier predictions to a species list #999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
b26a313
feat(post-processing): implement class masking + rank rollup on the a…
mihow 2dbffab
docs(post-processing): note rank_rollup is not lineage-constrained
mihow ba796c1
feat(post-processing): scope the class-mask algorithm list to the sel…
mihow 7f97cd5
fix(post-processing): stop the class-mask form timing out on large co…
mihow c5aa3b6
refactor(post-processing): split rank roll-up into a follow-up PR
mihow 194efca
feat(post-processing): make class masking survive the Job path at sca…
mihow 3072dae
fix(post-processing): bump updated_at on progress saves so the reaper…
mihow 1afe75a
fix(post-processing): address class-masking review — logits-only scor…
mihow File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| from django.core.management.base import BaseCommand, CommandError | ||
|
|
||
| from ami.main.models import SourceImageCollection, TaxaList | ||
| from ami.ml.models.algorithm import Algorithm | ||
| from ami.ml.post_processing.class_masking import ClassMaskingTask | ||
|
|
||
|
|
||
| class Command(BaseCommand): | ||
| help = ( | ||
| "Run class masking post-processing on a source image collection. " | ||
| "Masks classifier logits for species not in the given taxa list and recalculates softmax scores." | ||
| ) | ||
|
|
||
| def add_arguments(self, parser): | ||
| parser.add_argument("--collection-id", type=int, required=True, help="SourceImageCollection ID to process") | ||
| parser.add_argument("--taxa-list-id", type=int, required=True, help="TaxaList ID to use as the species mask") | ||
| parser.add_argument( | ||
| "--algorithm-id", type=int, required=True, help="Algorithm ID whose classifications to mask" | ||
| ) | ||
| parser.add_argument("--dry-run", action="store_true", help="Show what would be done without making changes") | ||
|
|
||
| def handle(self, *args, **options): | ||
| collection_id = options["collection_id"] | ||
| taxa_list_id = options["taxa_list_id"] | ||
| algorithm_id = options["algorithm_id"] | ||
| dry_run = options["dry_run"] | ||
|
|
||
| # Validate inputs | ||
| try: | ||
| collection = SourceImageCollection.objects.get(pk=collection_id) | ||
| except SourceImageCollection.DoesNotExist: | ||
| raise CommandError(f"SourceImageCollection {collection_id} does not exist.") | ||
|
|
||
| try: | ||
| taxa_list = TaxaList.objects.get(pk=taxa_list_id) | ||
| except TaxaList.DoesNotExist: | ||
| raise CommandError(f"TaxaList {taxa_list_id} does not exist.") | ||
|
|
||
| try: | ||
| algorithm = Algorithm.objects.get(pk=algorithm_id) | ||
| except Algorithm.DoesNotExist: | ||
| raise CommandError(f"Algorithm {algorithm_id} does not exist.") | ||
|
|
||
| if not algorithm.category_map: | ||
| raise CommandError(f"Algorithm '{algorithm.name}' does not have a category map.") | ||
|
|
||
| from ami.main.models import Classification | ||
|
|
||
| classification_count = ( | ||
| Classification.objects.filter( | ||
| detection__source_image__collections=collection, | ||
| terminal=True, | ||
| algorithm=algorithm, | ||
| scores__isnull=False, | ||
| logits__isnull=False, | ||
| ) | ||
| .distinct() | ||
| .count() | ||
| ) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| taxa_count = taxa_list.taxa.count() | ||
|
|
||
| self.stdout.write( | ||
| f"Collection: {collection.name} (id={collection.pk})\n" | ||
| f"Taxa list: {taxa_list.name} (id={taxa_list.pk}, {taxa_count} taxa)\n" | ||
| f"Algorithm: {algorithm.name} (id={algorithm.pk})\n" | ||
| f"Classifications to process: {classification_count}" | ||
| ) | ||
|
|
||
| if classification_count == 0: | ||
| raise CommandError("No terminal classifications with scores found for this collection/algorithm.") | ||
|
|
||
| if dry_run: | ||
| self.stdout.write(self.style.WARNING("Dry run — no changes made.")) | ||
| return | ||
|
|
||
| self.stdout.write("Running class masking...") | ||
| task = ClassMaskingTask( | ||
| source_image_collection_id=collection_id, | ||
| taxa_list_id=taxa_list_id, | ||
| algorithm_id=algorithm_id, | ||
| ) | ||
| task.run() | ||
| self.stdout.write(self.style.SUCCESS("Class masking completed.")) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,2 @@ | ||
| from . import class_masking # noqa: F401 | ||
| from . import small_size_filter # noqa: F401 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from django import forms | ||
|
|
||
| from ami.main.models import Occurrence, TaxaList | ||
| from ami.ml.models import Algorithm | ||
| from ami.ml.models.algorithm import AlgorithmTaskType | ||
| from ami.ml.post_processing.admin.forms import BasePostProcessingActionForm | ||
|
|
||
|
|
||
| class ClassMaskingActionForm(BasePostProcessingActionForm): | ||
| """Knobs surfaced when an admin triggers Class masking. | ||
|
|
||
| The operator picks the source classifier and the taxa list to keep; the | ||
| scope (which collection or occurrence) is supplied by the admin entry point, | ||
| not the form. Selections are model instances, so ``to_config`` hands the | ||
| schema their primary keys (``ClassMaskingConfig`` expects ``*_id`` ints). | ||
| """ | ||
|
|
||
| algorithm_id = forms.ModelChoiceField( | ||
| queryset=Algorithm.objects.filter(task_type=AlgorithmTaskType.CLASSIFICATION.value).order_by("name"), | ||
| label="Source classifier", | ||
| help_text="The classification algorithm whose terminal predictions will be re-scored.", | ||
| ) | ||
| taxa_list_id = forms.ModelChoiceField( | ||
| queryset=TaxaList.objects.all().order_by("name"), | ||
| label="Taxa list to keep", | ||
| help_text=( | ||
| "Classes whose taxon is not in this list are masked out; each " | ||
| "classification's softmax is renormalised over the classes that remain." | ||
| ), | ||
| ) | ||
| reweight = forms.BooleanField( | ||
| required=False, | ||
| initial=True, | ||
| label="Reweight (renormalise) scores", | ||
| help_text=( | ||
| "Renormalise the kept classes to sum to 1. " | ||
| "Off = keep the model's raw absolute scores; the chosen species is unchanged either way." | ||
| ), | ||
| ) | ||
|
|
||
| def __init__(self, *args, **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
| # Narrow the classifier dropdown to algorithms that actually produced | ||
| # classifications in the selected scope, so the operator cannot pick a | ||
| # classifier whose masking would be a no-op for the chosen rows. This is | ||
| # only done for an occurrence scope, where the lookup touches the handful | ||
| # of classifications under the picked occurrences. A collection scope | ||
| # keeps the full classifier list on purpose: the equivalent lookup is an | ||
| # unbounded DISTINCT over every classification in the collection (hundreds | ||
| # of thousands of rows on a large collection) and can time out while the | ||
| # form renders. An over-broad option is harmless — masking a classifier | ||
| # that produced nothing in scope changes nothing. | ||
| if self.scope_queryset is not None and self.scope_queryset.model is Occurrence: | ||
| self.fields["algorithm_id"].queryset = self._algorithms_for_scope(self.scope_queryset) | ||
|
|
||
| @staticmethod | ||
| def _algorithms_for_scope(scope_queryset): | ||
| """Classification algorithms that produced classifications within the | ||
| selected occurrences.""" | ||
| return ( | ||
| Algorithm.objects.filter( | ||
| task_type=AlgorithmTaskType.CLASSIFICATION.value, | ||
| classifications__detection__occurrence__in=scope_queryset, | ||
| ) | ||
| .distinct() | ||
| .order_by("name") | ||
| ) | ||
|
|
||
| def to_config(self) -> dict: | ||
| return { | ||
| "algorithm_id": self.cleaned_data["algorithm_id"].pk, | ||
| "taxa_list_id": self.cleaned_data["taxa_list_id"].pk, | ||
| "reweight": self.cleaned_data["reweight"], | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.