diff --git a/ami/main/admin.py b/ami/main/admin.py index b25860507..8abe81f93 100644 --- a/ami/main/admin.py +++ b/ami/main/admin.py @@ -1,7 +1,7 @@ import datetime from typing import Any -from django.contrib import admin +from django.contrib import admin, messages from django.db import models from django.db.models.query import QuerySet from django.http.request import HttpRequest @@ -12,6 +12,7 @@ import ami.utils from ami import tasks from ami.jobs.models import Job +from ami.main.tasks import generate_regional_taxa_list_task from ami.ml.models.project_pipeline_config import ProjectPipelineConfig from ami.ml.post_processing.admin.actions import make_post_processing_action from ami.ml.post_processing.admin.class_masking_form import ClassMaskingActionForm @@ -108,6 +109,7 @@ def save_related(self, request, form, formsets, change): inlines = [ProjectPipelineConfigInline] autocomplete_fields = ("owner", "default_filters_include_taxa", "default_filters_exclude_taxa") + raw_id_fields = ("default_taxa_list",) def get_queryset(self, request: HttpRequest) -> QuerySet[Any]: return super().get_queryset(request).select_related("owner") @@ -138,6 +140,20 @@ def get_queryset(self, request: HttpRequest) -> QuerySet[Any]: ), }, ), + ( + "Region & Taxa List", + { + "fields": ( + "region_source", + "region_code", + "default_taxa_list", + ), + "description": ( + "The region a taxa list can be generated from, and the list used as this " + "project's fallback when a site has none configured." 
+ ), + }, + ), ( "Ownership & Access", { @@ -155,7 +171,33 @@ def _remove_duplicate_classifications(self, request: HttpRequest, queryset: Quer task_ids.append(task.id) self.message_user(request, f"Started {len(task_ids)} tasks to delete classification: {task_ids}") - actions = [_remove_duplicate_classifications] + @admin.action(description="Generate a regional taxa list from the configured region") + def generate_regional_taxa_list_action(self, request: HttpRequest, queryset: QuerySet[Project]) -> None: + """Enqueue regional taxa-list generation for each selected project that has a + region configured; the list is attached to the project's default_taxa_list. + Runs in the background because the external fetch is slow.""" + enqueued = 0 + skipped = [] + for project in queryset: + if not project.region_source or not project.region_code: + skipped.append(project.name) + continue + generate_regional_taxa_list_task.delay( + project_id=project.pk, + region_source=project.region_source, + region_code=project.region_code, + ) + enqueued += 1 + if enqueued: + self.message_user(request, f"Queued regional taxa-list generation for {enqueued} project(s).") + if skipped: + self.message_user( + request, + f"Skipped (no region configured): {', '.join(skipped)}", + level=messages.WARNING, + ) + + actions = [_remove_duplicate_classifications, generate_regional_taxa_list_action] @admin.register(Deployment) @@ -763,7 +805,40 @@ class DeviceAdmin(admin.ModelAdmin[Device]): @admin.register(Site) class SiteAdmin(admin.ModelAdmin[Site]): - """Admin panel example for ``Site`` model.""" + """Admin panel for ``Site`` (Research Site) model.""" + + list_display = ("name", "project", "region_source", "region_code", "taxa_list") + list_filter = ("region_source",) + search_fields = ("name",) + raw_id_fields = ("project", "taxa_list") + + @admin.action(description="Generate a regional taxa list from the configured region") + def generate_regional_taxa_list_action(self, request: HttpRequest, 
queryset: QuerySet[Site]) -> None: + """Enqueue regional taxa-list generation for each selected site that has a + project and a region configured; the list is attached to the site's taxa_list.""" + enqueued = 0 + skipped = [] + for site in queryset: + if not site.project_id or not site.region_source or not site.region_code: + skipped.append(site.name) + continue + generate_regional_taxa_list_task.delay( + project_id=site.project_id, + region_source=site.region_source, + region_code=site.region_code, + site_id=site.pk, + ) + enqueued += 1 + if enqueued: + self.message_user(request, f"Queued regional taxa-list generation for {enqueued} site(s).") + if skipped: + self.message_user( + request, + f"Skipped (missing project or region): {', '.join(skipped)}", + level=messages.WARNING, + ) + + actions = [generate_regional_taxa_list_action] @admin.register(S3StorageSource) diff --git a/ami/main/api/views.py b/ami/main/api/views.py index 591a4a000..30d3d3391 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -35,6 +35,8 @@ from ami.main.api.schemas import limit_doc_param, project_id_doc_param from ami.main.api.serializers import TagSerializer from ami.main.models_future.occurrence import model_agreement_for_project, top_identifiers_for_project +from ami.main.services import regional_taxa +from ami.main.tasks import generate_regional_taxa_list_task from ami.utils.requests import get_default_classification_threshold from ami.utils.storages import ConnectionTestResult @@ -49,6 +51,7 @@ Page, Project, ProjectQuerySet, + RegionSource, S3StorageSource, Site, SourceImage, @@ -155,6 +158,15 @@ def get_count(self, queryset): return super().get_count(queryset.order_by().values("pk")) +class RegionalTaxaListRequestSerializer(serializers.Serializer): + """Body for POST /projects/{id}/generate-regional-taxa-list/ (see #1364).""" + + region_source = serializers.ChoiceField(choices=RegionSource.choices, default=RegionSource.GBIF_GADM.value) + region_code = 
serializers.CharField(required=False, allow_blank=True) + classifier_id = serializers.IntegerField(required=False) + include_uncovered = serializers.BooleanField(required=False, default=False) + + class ProjectViewSet(DefaultViewSet, ProjectMixin): """ API endpoint that allows projects to be viewed or edited. @@ -256,6 +268,59 @@ def perform_create(self, serializer): # Add current user as project owner serializer.save(owner=self.request.user) + @action( + detail=True, + methods=["post"], + name="generate-regional-taxa-list", + url_path="generate-regional-taxa-list", + ) + def generate_regional_taxa_list(self, request, pk=None) -> Response: + """Queue generation of a taxa list for this project from a geographic region. + + The external biodiversity-database fetch is slow, so this enqueues a background + task and returns 202; on success the generated list becomes the project's + default_taxa_list. When region_code is omitted it is derived from the project's + deployments. Requires update permission on the project. See issue #1364. + """ + project = get_object_or_404(self.get_queryset(), pk=pk) + if not request.user.has_perm("update_project", project): + raise PermissionDenied("You do not have permission to modify this project.") + + params = RegionalTaxaListRequestSerializer(data=request.data) + params.is_valid(raise_exception=True) + data = params.validated_data + + region_source = data["region_source"] + region_code = (data.get("region_code") or "").strip() + if not region_code: + derived = regional_taxa.derive_region_for_project(project, region_source=region_source) + if derived is None: + raise api_exceptions.ValidationError( + { + "region_code": ( + "No region_code was provided and none could be derived from the " "project's deployments." 
+ ) + } + ) + _source, region_code = derived + + generate_regional_taxa_list_task.delay( + project_id=project.pk, + region_source=region_source, + region_code=region_code, + classifier_id=data.get("classifier_id"), + include_uncovered=data.get("include_uncovered", False), + ) + return Response( + { + "project_id": project.pk, + "region_source": region_source, + "region_code": region_code, + "status": "queued", + }, + status=status.HTTP_202_ACCEPTED, + ) + @action(detail=True, methods=["get"], name="charts") def charts(self, request, pk=None): """ diff --git a/ami/main/management/commands/generate_regional_taxa_list.py b/ami/main/management/commands/generate_regional_taxa_list.py new file mode 100644 index 000000000..e9871c277 --- /dev/null +++ b/ami/main/management/commands/generate_regional_taxa_list.py @@ -0,0 +1,119 @@ +"""Generate a project taxa list from a geographic region (issue #1364). + +A thin wrapper over `ami.main.services.regional_taxa.generate_regional_taxa_list`. Run +it for one project with an explicit region, or use `--all-projects` to backfill every +project, deriving each one's region from a representative deployment's coordinates +(GBIF reverse-geocode). The heavy lifting lives in the service so the same behavior is +reachable from the admin, the API, and tests; this command is the operator/backfill +entry point. +""" + +from __future__ import annotations + +import logging + +from django.core.management.base import BaseCommand, CommandError + +from ami.main.models import Project, RegionSource +from ami.main.services import regional_taxa + +logger = logging.getLogger(__name__) + + +class Command(BaseCommand): + help = "Generate a project taxa list from a geographic region (GBIF/GADM)." 
+ + def add_arguments(self, parser): + parser.add_argument("--project", type=int, help="Project id to attach the list to.") + parser.add_argument( + "--all-projects", + action="store_true", + help="Backfill every project, deriving each region from its deployments.", + ) + parser.add_argument( + "--region-source", + default=RegionSource.GBIF_GADM.value, + choices=[choice.value for choice in RegionSource], + ) + parser.add_argument( + "--region-code", + help="Region id (a GADM gid such as USA.46_1). Omit with --all-projects; it is derived.", + ) + parser.add_argument("--classifier", type=int, help="Algorithm id for a reporting-only coverage overlay.") + parser.add_argument("--name", help="TaxaList name (defaults to the region code).") + parser.add_argument( + "--include-uncovered", + action="store_true", + help="Also keep regional species no model can predict (each flagged as uncovered).", + ) + parser.add_argument( + "--no-create-missing", + action="store_true", + help="Do not create Taxon rows for regional species absent from the database.", + ) + parser.add_argument("--dry-run", action="store_true", help="Report the counts without writing anything.") + + def handle(self, *args, **options): + classifier = self._resolve_classifier(options.get("classifier")) + common = dict( + region_source=options["region_source"], + classifier=classifier, + include_uncovered=options["include_uncovered"], + create_missing=not options["no_create_missing"], + name=options["name"], + dry_run=options["dry_run"], + ) + + if options["all_projects"]: + if options["region_code"]: + raise CommandError("--region-code is derived per project with --all-projects; do not pass it.") + self._run_all_projects(common) + return + + if not options["region_code"]: + raise CommandError("--region-code is required unless --all-projects is used.") + project = self._resolve_project(options.get("project")) + result = regional_taxa.generate_regional_taxa_list( + project=project, 
region_code=options["region_code"], **common + ) + self._report(project, result) + + def _run_all_projects(self, common: dict) -> None: + for project in Project.objects.all().order_by("pk"): + derived = regional_taxa.derive_region_for_project(project, region_source=common["region_source"]) + if derived is None: + self.stdout.write(f"[skip] project {project.pk} {project.name!r}: no region could be derived") + continue + _source, region_code = derived + result = regional_taxa.generate_regional_taxa_list(project=project, region_code=region_code, **common) + self._report(project, result) + + def _resolve_project(self, project_id: int | None) -> Project | None: + if not project_id: + return None + try: + return Project.objects.get(pk=project_id) + except Project.DoesNotExist: + raise CommandError(f"Project {project_id} does not exist.") + + def _resolve_classifier(self, classifier_id: int | None): + if not classifier_id: + return None + from ami.ml.models.algorithm import Algorithm + + try: + return Algorithm.objects.get(pk=classifier_id) + except Algorithm.DoesNotExist: + raise CommandError(f"Algorithm {classifier_id} does not exist.") + + def _report(self, project: Project | None, result) -> None: + scope = f"project {project.pk} {project.name!r}" if project else "global" + suffix = " [dry-run]" if result.dry_run else "" + self.stdout.write( + self.style.SUCCESS( + f"[{scope}] region={result.region_code} saved={result.saved_list_size} " + f"(covered={result.model_covered}, uncovered={result.regional_no_model_coverage}, " + f"created={result.created_taxa}, in_db={result.already_in_db}, " + f"regional_total={result.regional_total}){suffix}" + ) + ) diff --git a/ami/main/management/commands/import_taxa.py b/ami/main/management/commands/import_taxa.py index a1f9cf49b..2fb3250cb 100644 --- a/ami/main/management/commands/import_taxa.py +++ b/ami/main/management/commands/import_taxa.py @@ -13,9 +13,8 @@ # import progress bar from tqdm import tqdm -from ...models import 
TaxaList, Taxon, TaxonRank - -RANK_CHOICES = [rank for rank in TaxonRank] +from ...models import TaxaList, Taxon +from ...services.taxonomy import create_taxon, get_or_create_root_taxon logger = logging.getLogger(__name__) # Set level @@ -140,25 +139,6 @@ def fix_values(taxon_data: dict) -> dict: return taxon_data -def get_or_create_root_taxon() -> Taxon: - """ - Important! This is where the root taxon is configured. - """ - root_taxon_parent, created = Taxon.objects.get_or_create( - name="Arthropoda", rank=TaxonRank.PHYLUM.name, defaults={"ordering": 0} - ) - if created: - logger.info(f"Created root taxon {root_taxon_parent}") - else: - logger.info(f"Found existing root taxon {root_taxon_parent}") - if root_taxon_parent.parent: - # If the root taxon has a parent, remove it - # Otherwise, the root taxon will not be the root and there will be recursion issues - root_taxon_parent.parent = None - root_taxon_parent.save() - return root_taxon_parent - - class Command(BaseCommand): r""" Import taxa from a JSON file. Assign their rank, parent taxa, gbif_taxon_key, and accepted_name. 
@@ -268,7 +248,7 @@ def handle(self, *args, **options): taxon_data = fix_values(taxon_data) logger.debug(f"Parsed taxon data: {taxon_data}") if taxon_data: - created_taxa, updated_taxa, specific_taxon = self.create_taxon(taxon_data, root_taxon_parent) + created_taxa, updated_taxa, specific_taxon = create_taxon(taxon_data, root_taxon_parent) taxa_to_refresh.update(created_taxa) taxa_to_refresh.update(updated_taxa) taxalist.taxa.add(specific_taxon) @@ -297,145 +277,3 @@ def handle(self, *args, **options): logger.info("Updating cached values for all new or updated taxa") for taxon in tqdm(taxa_to_refresh): taxon.save(update_calculated_fields=True) - - def create_taxon(self, taxon_data: dict, root_taxon_parent: Taxon) -> tuple[set[Taxon], set[Taxon], Taxon]: - taxa_in_row = [] - created_taxa = set() - updated_taxa = set() - - # parent_must_match = ["SPECIES"]#], "SUBSPECIES", "VARIETY", "FORM"] - global parent_taxon - parent_taxon = root_taxon_parent - - for i, rank in enumerate(sorted(RANK_CHOICES)): - logger.debug(f"Checking rank {rank} {i} of {len(RANK_CHOICES)}") - logger.debug(f"Current parent taxon: {parent_taxon}") - # Create all parents and parents of parents - # Assume ranks are in order of rank - if rank.name.lower() in taxon_data.keys() and taxon_data[rank.name.lower()]: - name = taxon_data[rank.name.lower()] - gbif_taxon_key = taxon_data.get("gbif_taxon_key", None) - rank = rank.name.upper() - logger.debug(f"Taxon found in incoming row {i}: {rank} {name} (GBIF: {gbif_taxon_key})") - - # Look up existing taxon by name only, since names must be unique. 
- # If the taxon already exists, use it and maybe update it - taxon, created = Taxon.objects.get_or_create( - name=name, - defaults=dict( - rank=rank, - gbif_taxon_key=gbif_taxon_key, - parent=parent_taxon, - ), - ) - taxa_in_row.append(taxon) - - if created: - logger.debug(f"Created new taxon #{taxon.id} {taxon} ({taxon.rank})") - created_taxa.add(taxon) - else: - logger.debug(f"Using existing taxon #{taxon.id} {taxon} ({taxon.rank})") - - # Add or update the rank of the taxon based on incoming data - if not taxon.rank or taxon.rank != rank: - if not created: - logger.warning(f"Rank of existing {taxon} is changing from {taxon.rank} to {rank}") - taxon.rank = rank - taxon.save(update_calculated_fields=False) - if not created: - updated_taxa.add(taxon) - - # Add or update the parent of the taxon based on incoming data - # if the incoming parent is more specific than the existing parent - # (e.g. if the existing parent is Lepidoptera and the existing parent is a family) - if not taxon.parent or parent_taxon.get_rank() > taxon.parent.get_rank(): - parent = parent_taxon or root_taxon_parent - if parent == taxon: - logger.debug(f"Parent of {taxon} is itself, changing to (or keeping as) None") - parent = None - if taxon.parent != parent: - if not created: - logger.warn(f"Changing parent of {taxon} from {taxon.parent} to more specific {parent}") - taxon.parent = parent - taxon.save(update_calculated_fields=False) - if not created: - updated_taxa.add(taxon) - - parent_taxon = taxon - logger.debug(f"Next parent taxon: {parent_taxon.rank} {parent_taxon}") - else: - logger.debug(f"Did not find {rank} in incoming row, checking next rank") - - accepted_name = taxon_data.get("synonym_of", None) - - if not taxa_in_row: - raise ValueError(f"Could not find any ranks in {taxon_data}") - - # Make sure incoming taxa are sorted by rank - taxa_in_row = sorted(taxa_in_row, key=lambda taxon: taxon.get_rank()) - - logger.debug(f"Found {len(taxa_in_row)} taxa in row: {taxa_in_row}") - - 
specific_taxon = taxa_in_row[-1] - expected_specific_taxon_ranks = TaxonRank.SPECIES, TaxonRank.GENUS - if specific_taxon.get_rank() not in expected_specific_taxon_ranks: - logger.warn(f"Assumming the most specific taxon of this row is: {specific_taxon} {specific_taxon.rank}") - - specific_taxon_columns = [ - "author", - "authorship_date", - "gbif_taxon_key", - "bold_taxon_bin", - "inat_taxon_id", - "common_name_en", - "notes", - "sort_phylogeny", - "fieldguide_id", - "cover_image_url", - "cover_image_credit", - ] - - is_new = specific_taxon in created_taxa - needs_update = False - for column in specific_taxon_columns: - if column in taxon_data: - existing_value = getattr(specific_taxon, column) - incoming_value = taxon_data[column] - if existing_value != incoming_value: - if incoming_value is None: - # Don't overwrite existing values with None. - # This could potentially be a command line option to allow users to clear values. - logger.debug(f"Not changing {column} of {specific_taxon} from {existing_value} to None") - continue - if not is_new: - logger.info( - f"Changing {column} of {specific_taxon} to from {existing_value} to {incoming_value}" - ) - setattr(specific_taxon, column, taxon_data[column]) - needs_update = True - if needs_update: - specific_taxon.save(update_calculated_fields=False) - if not is_new: - # raise ValueError(f"TAXON DATA CHANGED for {specific_taxon}") - logger.warning(f"TAXON DATA CHANGED for existing {specific_taxon} ({specific_taxon.id})") - updated_taxa.add(specific_taxon) - - if accepted_name: - accepted_taxon, created = Taxon.objects.get_or_create( - name=accepted_name, - rank=specific_taxon.rank, - defaults={"parent": parent_taxon}, - ) - if created: - logger.info(f"Created accepted taxon {accepted_taxon}") - created_taxa.add(accepted_taxon) - - if specific_taxon.synonym_of != accepted_taxon: - logger.info(f"Setting synonym_of of {specific_taxon} to {accepted_taxon}") - specific_taxon.synonym_of = accepted_taxon - 
specific_taxon.save() - updated_taxa.add(specific_taxon) - - # - - return created_taxa, updated_taxa, specific_taxon diff --git a/ami/main/management/commands/refresh_taxon_model_coverage.py b/ami/main/management/commands/refresh_taxon_model_coverage.py new file mode 100644 index 000000000..a0ae481c9 --- /dev/null +++ b/ami/main/management/commands/refresh_taxon_model_coverage.py @@ -0,0 +1,33 @@ +import logging + +from django.core.management.base import BaseCommand, CommandError + +from ami.main.services.taxon_coverage import refresh_all_algorithm_coverage + +logger = logging.getLogger(__name__) + + +class Command(BaseCommand): + """ + Rebuild Taxon.covered_by_algorithms / Taxon.has_model_coverage for every + algorithm with a category map. + + This is a full repair tool: the persisted model-coverage relationship is + normally kept fresh automatically (Algorithm.save() refreshes it whenever an + algorithm's category map changes), so this command is for the initial backfill + after the fields were added, or to repair drift from a write path that bypassed + the hook (e.g. a bulk_update on Algorithm.category_map). + + **Usage:** + python manage.py refresh_taxon_model_coverage + """ + + help = "Rebuild the Taxon <-> Algorithm model-coverage relationship for every algorithm with a category map." 
+ + def handle(self, *args, **options): + try: + algorithms_processed = refresh_all_algorithm_coverage() + except Exception as e: + raise CommandError(f"Failed to refresh taxon model coverage: {e}") from e + + self.stdout.write(self.style.SUCCESS(f"Refreshed model coverage for {algorithms_processed} algorithm(s).")) diff --git a/ami/main/migrations/0095_regional_taxa_lists.py b/ami/main/migrations/0095_regional_taxa_lists.py new file mode 100644 index 000000000..0810ede67 --- /dev/null +++ b/ami/main/migrations/0095_regional_taxa_lists.py @@ -0,0 +1,93 @@ +# Generated by Django 4.2.10 on 2026-07-02 17:57 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0028_normalize_empty_endpoint_url_to_null"), + ("main", "0094_enable_async_pipeline_workers"), + ] + + operations = [ + migrations.AddField( + model_name="project", + name="default_taxa_list", + field=models.ForeignKey( + blank=True, + help_text="Fall-back taxa list for occurrences whose research site has none of its own.", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="+", + to="main.taxalist", + ), + ), + migrations.AddField( + model_name="project", + name="region_code", + field=models.CharField( + blank=True, + help_text="A region identifier for region_source (a GADM gid or an iNaturalist place id).", + max_length=64, + ), + ), + migrations.AddField( + model_name="project", + name="region_source", + field=models.CharField( + blank=True, + choices=[("gbif_gadm", "GBIF (GADM region)"), ("inat_place", "iNaturalist (place)")], + max_length=32, + ), + ), + migrations.AddField( + model_name="site", + name="region_code", + field=models.CharField( + blank=True, + help_text="A region identifier for region_source (a GADM gid or an iNaturalist place id).", + max_length=64, + ), + ), + migrations.AddField( + model_name="site", + name="region_source", + field=models.CharField( + blank=True, + 
choices=[("gbif_gadm", "GBIF (GADM region)"), ("inat_place", "iNaturalist (place)")], + max_length=32, + ), + ), + migrations.AddField( + model_name="site", + name="taxa_list", + field=models.ForeignKey( + blank=True, + help_text="Taxa list to use for occurrences at this research site (e.g. for class masking).", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="+", + to="main.taxalist", + ), + ), + migrations.AddField( + model_name="taxon", + name="covered_by_algorithms", + field=models.ManyToManyField( + blank=True, + help_text="Algorithm(s) whose category map includes this taxon as a label.", + related_name="covered_taxa", + to="ml.algorithm", + ), + ), + migrations.AddField( + model_name="taxon", + name="has_model_coverage", + field=models.BooleanField( + db_index=True, + default=False, + help_text="True iff covered_by_algorithms is non-empty. Denormalized for cheap filtering.", + ), + ), + ] diff --git a/ami/main/models.py b/ami/main/models.py index 3662b4107..7e0041c01 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -98,6 +98,16 @@ class TaxonRank(OrderedEnum): ) +class RegionSource(models.TextChoices): + """Where a Site's or Project's region code comes from, so the regional taxa-list + service (`ami.main.services.regional_taxa`) knows which external API a stored + `region_code` refers to (a GBIF/GADM region id vs. an iNaturalist place id). + """ + + GBIF_GADM = "gbif_gadm", "GBIF (GADM region)" + INAT_PLACE = "inat_place", "iNaturalist (place)" + + def bbox_is_null(bbox) -> bool: """In-memory equivalent of null_detections_q() for an already-fetched bbox value.""" return bbox is None @@ -313,6 +323,23 @@ class Project(ProjectSettingsMixin, BaseModel): active = models.BooleanField(default=True) priority = models.IntegerField(default=1) + # Fall-back region used to generate a taxa list when a deployment's research site + # has no region of its own. See ami.main.services.regional_taxa. 
+ region_source = models.CharField(max_length=32, choices=RegionSource.choices, blank=True) + region_code = models.CharField( + max_length=64, + blank=True, + help_text="A region identifier for region_source (a GADM gid or an iNaturalist place id).", + ) + default_taxa_list = models.ForeignKey( + "TaxaList", + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="+", + help_text="Fall-back taxa list for occurrences whose research site has none of its own.", + ) + # Backreferences for type hinting captures: models.QuerySet["SourceImage"] deployments: models.QuerySet["Deployment"] @@ -658,6 +685,23 @@ class Site(BaseModel): description = models.TextField(blank=True) project = models.ForeignKey(Project, on_delete=models.SET_NULL, null=True, related_name="sites") + # The region this site sits in, used to auto-generate a taxa list. See + # ami.main.services.regional_taxa. + region_source = models.CharField(max_length=32, choices=RegionSource.choices, blank=True) + region_code = models.CharField( + max_length=64, + blank=True, + help_text="A region identifier for region_source (a GADM gid or an iNaturalist place id).", + ) + taxa_list = models.ForeignKey( + "TaxaList", + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="+", + help_text="Taxa list to use for occurrences at this research site (e.g. for class masking).", + ) + deployments: models.QuerySet["Deployment"] def deployments_count(self) -> int: @@ -4378,6 +4422,22 @@ class Taxon(BaseModel): classifications: models.QuerySet["Classification"] lists: models.QuerySet["TaxaList"] + # Which classifier(s), if any, can predict this taxon — i.e. whose category map + # lists the taxon's name as a label (the same Taxon.name == label join class + # masking uses via AlgorithmCategoryMap.with_taxa()). Derived, not user-editable: + # kept in sync by ami.main.services.taxon_coverage, not recomputed on every read. 
+ covered_by_algorithms = models.ManyToManyField( + "ml.Algorithm", + related_name="covered_taxa", + blank=True, + help_text="Algorithm(s) whose category map includes this taxon as a label.", + ) + has_model_coverage = models.BooleanField( + default=False, + db_index=True, + help_text="True iff covered_by_algorithms is non-empty. Denormalized for cheap filtering.", + ) + author = models.CharField(max_length=255, blank=True) authorship_date = models.DateField(null=True, blank=True, help_text="The date the taxon was described.") ordering = models.IntegerField(null=True, blank=True) diff --git a/ami/main/services/__init__.py b/ami/main/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ami/main/services/gbif.py b/ami/main/services/gbif.py new file mode 100644 index 000000000..a8f6aee85 --- /dev/null +++ b/ami/main/services/gbif.py @@ -0,0 +1,139 @@ +"""GBIF-backed RegionalSpeciesSource: species recorded in a GADM region, via GBIF's +occurrence search faceted by species. + +Endpoints and parameters were exercised live against the GBIF API in the #1364 Phase 0 +spike (region: Vermont, scope: Lepidoptera) — see docs/claude/analysis in the planning +branch for the findings. Unit tests for this module stub `create_session()`; nothing +here is exercised against the network in CI. +""" + +from __future__ import annotations + +from ...utils.requests import create_session +from ..models import RegionSource +from .regional_taxa import SourceSpecies, TaxonScope + +GBIF_API_BASE = "https://api.gbif.org/v1" + +# Species-key facets are paginated; a hard cap keeps a pathological region (or a +# scope too broad for faceting) from looping forever. Not a product requirement — +# just a safety net. 
+DEFAULT_MAX_SPECIES = 5000 +DEFAULT_FACET_PAGE_SIZE = 1000 +DEFAULT_TIMEOUT_SECONDS = 60 + + +class GBIFRegionalSource: + source_key = RegionSource.GBIF_GADM.value + + def __init__( + self, + facet_page_size: int = DEFAULT_FACET_PAGE_SIZE, + max_species: int = DEFAULT_MAX_SPECIES, + timeout: int = DEFAULT_TIMEOUT_SECONDS, + ): + self.facet_page_size = facet_page_size + self.max_species = max_species + self.timeout = timeout + + def fetch_species(self, region_code: str, taxon_scope: TaxonScope) -> list[SourceSpecies]: + if taxon_scope.gbif_taxon_key is None: + raise ValueError(f"GBIFRegionalSource requires a gbif_taxon_key on the taxon scope {taxon_scope.label!r}") + + session = create_session() + counts_by_key = self._fetch_species_counts(session, region_code, taxon_scope.gbif_taxon_key) + return self._resolve_species(session, counts_by_key) + + def _fetch_species_counts(self, session, region_code: str, gbif_taxon_key: int) -> dict[int, int]: + """Page through the speciesKey facet for the region, returning + {speciesKey: occurrence count}. 
Terminates when a page comes back shorter + than requested or the species cap is hit.""" + counts_by_key: dict[int, int] = {} + offset = 0 + while True: + response = session.get( + f"{GBIF_API_BASE}/occurrence/search", + params={ + "taxonKey": gbif_taxon_key, + "gadmGid": region_code, + "hasCoordinate": "true", + "facet": "speciesKey", + "facetLimit": self.facet_page_size, + "facetOffset": offset, + "limit": 0, + }, + timeout=self.timeout, + ) + response.raise_for_status() + facets = response.json().get("facets", []) + counts = facets[0]["counts"] if facets else [] + if not counts: + break + for entry in counts: + counts_by_key[int(entry["name"])] = entry.get("count", 0) + offset += self.facet_page_size + if len(counts) < self.facet_page_size or len(counts_by_key) >= self.max_species: + break + return counts_by_key + + def _resolve_species(self, session, counts_by_key: dict[int, int]) -> list[SourceSpecies]: + """Resolve each speciesKey to a scientific name (and, when available, its + rank and rank-hierarchy fields) via GBIF's species-by-key endpoint.""" + species: list[SourceSpecies] = [] + for key, count in counts_by_key.items(): + response = session.get(f"{GBIF_API_BASE}/species/{key}", timeout=self.timeout) + if response.status_code == 404: + continue + response.raise_for_status() + data = response.json() + name = (data.get("canonicalName") or data.get("species") or data.get("scientificName") or "").strip() + if not name: + continue + species.append( + SourceSpecies( + source=self.source_key, + scientific_name=name, + rank=data.get("rank"), + gbif_taxon_key=key, + observation_count=count, + raw=data, + ) + ) + return species + + +DEFAULT_GADM_LEVEL = 1 + + +def reverse_geocode_gadm( + latitude: float, + longitude: float, + level: int = DEFAULT_GADM_LEVEL, + timeout: int = DEFAULT_TIMEOUT_SECONDS, + session=None, +) -> str | None: + """Resolve a point to the GADM region id that contains it, at the requested level. 
+ + Level 1 is state/province ("USA.46_1"), level 2 is county/district + ("USA.46.14_1"). Returns the gid, or None when no GADM polygon of that level + contains the point. This derives a `region_code` for a project or site from a + deployment's stored latitude/longitude, so a regional taxa list can be built + without anyone typing a region code by hand (issue #1364, path A3). + + Matching is on the gid shape rather than the response's source string: a level-N + GADM gid has N dot-separated segments after the country and ends with the version + suffix "_1" (level 0, the bare country code, has no suffix and is never returned). + """ + session = session or create_session() + response = session.get( + f"{GBIF_API_BASE}/geocode/reverse", + params={"lat": latitude, "lng": longitude}, + timeout=timeout, + ) + response.raise_for_status() + for item in response.json(): + gid = item.get("id", "") + segments = gid.split(".") + if gid.endswith("_1") and len(segments) == level + 1: + return gid + return None diff --git a/ami/main/services/regional_taxa.py b/ami/main/services/regional_taxa.py new file mode 100644 index 000000000..8d7b5fac4 --- /dev/null +++ b/ami/main/services/regional_taxa.py @@ -0,0 +1,481 @@ +"""Build a project taxa list from a geographic region. + +Fetches the species recorded in a region from one or more external biodiversity +databases (currently GBIF; iNaturalist can be added behind the same protocol), maps +them onto Antenna `Taxon` rows, restricts the result (by default) to species some +classifier can actually predict, and saves a project-scoped `TaxaList`. This is the +one place the logic lives; management commands, admin actions, and API endpoints are +thin wrappers around `generate_regional_taxa_list()`. See issue #1364 and the +accompanying design/implementation-plan docs for the full rationale. 
+ +The one design rule everything here is built around: when more than one source is +queried, a species present in ANY source is a candidate for the regional list. +Sources are combined with a wide union, never an intersection — querying a second +source can only grow the candidate set. Model coverage (does some classifier know +this species) is a separate, later axis applied after mapping to `Taxon`, not part of +how sources combine. +""" + +from __future__ import annotations + +import dataclasses +import logging +import typing + +from ..models import RegionSource, Taxon, TaxonRank +from . import taxon_coverage +from .taxonomy import create_taxon, get_or_create_root_taxon + +if typing.TYPE_CHECKING: + from ami.ml.models.algorithm import Algorithm + + from ..models import Project + +logger = logging.getLogger(__name__) + + +# --- Source abstraction ------------------------------------------------------------- + + +@dataclasses.dataclass(frozen=True) +class SourceSpecies: + """One species as reported by ONE source for a region. + + The merge step (`merge_source_species`) concatenates these across sources and + deduplicates on a canonical key, unioning provenance. Fields are deliberately + source-agnostic so a new source only has to populate what it knows; `raw` carries + the untouched source payload for fields not yet promoted to the dataclass (e.g. + the genus/family hierarchy GBIF's species endpoint returns, used to build a + richer `Taxon` parent chain when creating a missing taxon). 
def _normalize_name(name: str) -> str:
    """Return `name` case-folded with every run of whitespace collapsed to one space.

    Leading and trailing whitespace is removed as a side effect of splitting.
    """
    tokens = name.casefold().split()
    return " ".join(tokens)
def merge_source_species(per_source: list[list[SourceSpecies]]) -> list[MergedSpecies]:
    """Union the species reported by every source into one deduplicated list.

    Rows are visited source by source. A row joins an existing MergedSpecies when
    any of its dedup keys (from `_dedup_keys`, tried in the order that helper
    returns them) is already indexed; otherwise it starts a new MergedSpecies.
    Nothing is ever removed, so adding a source can only grow the result.

    When a row joins an existing entry:
      * its source and (if present) observation count are recorded,
      * a missing GBIF key / iNat id on the entry is filled in from the row,
      * a *different* GBIF key / iNat id is logged as a conflict; the entry keeps
        its own value and the row stays available through `contributing`.

    Returns:
        The merged species ordered by descending highest per-source observation
        count (entries with no counts sort as 0), then by scientific name.
    """
    lookup: dict[tuple[str, object], MergedSpecies] = {}
    merged_rows: list[MergedSpecies] = []

    def _sort_key(entry: MergedSpecies):
        peak = max(entry.observation_counts.values()) if entry.observation_counts else 0
        return (-peak, entry.scientific_name)

    for row in (item for source_rows in per_source for item in source_rows):
        row_keys = _dedup_keys(row)
        if not row_keys:
            continue

        # First key (in precedence order) that is already indexed decides the target.
        target = None
        for candidate in row_keys:
            if candidate in lookup:
                target = lookup[candidate]
                break

        if target is None:
            counts = {}
            if row.observation_count is not None:
                counts[row.source] = row.observation_count
            target = MergedSpecies(
                scientific_name=row.scientific_name,
                rank=row.rank,
                gbif_taxon_key=row.gbif_taxon_key,
                inat_taxon_id=row.inat_taxon_id,
                sources={row.source},
                observation_counts=counts,
                contributing=[row],
            )
            merged_rows.append(target)
            for candidate in row_keys:
                lookup[candidate] = target
            continue

        # Fold the row's provenance into the entry it matched.
        target.sources.add(row.source)
        if row.observation_count is not None:
            target.observation_counts[row.source] = row.observation_count
        target.contributing.append(row)

        if row.gbif_taxon_key:
            if not target.gbif_taxon_key:
                target.gbif_taxon_key = row.gbif_taxon_key
                lookup[("gbif", row.gbif_taxon_key)] = target
            elif target.gbif_taxon_key != row.gbif_taxon_key:
                logger.warning(
                    "Regional species merge: %r has conflicting GBIF keys (%s vs %s); keeping both as provenance",
                    row.scientific_name,
                    target.gbif_taxon_key,
                    row.gbif_taxon_key,
                )

        if row.inat_taxon_id:
            if not target.inat_taxon_id:
                target.inat_taxon_id = row.inat_taxon_id
                lookup[("inat", row.inat_taxon_id)] = target
            elif target.inat_taxon_id != row.inat_taxon_id:
                logger.warning(
                    "Regional species merge: %r has conflicting iNat ids (%s vs %s); keeping both as provenance",
                    row.scientific_name,
                    target.inat_taxon_id,
                    row.inat_taxon_id,
                )

        # Index any key this row introduced, without stealing keys already owned.
        for candidate in row_keys:
            lookup.setdefault(candidate, target)

    merged_rows.sort(key=_sort_key)
    return merged_rows
def _taxon_data_from_merged(species: MergedSpecies) -> dict:
    """Assemble the row dict handed to `create_taxon()` for one merged species.

    Hierarchy values (the names in `_HIERARCHY_FIELDS`) are copied from the raw
    payloads of the contributing source rows; for each field the first
    contributing row with a non-empty value wins. The species' own name is stored
    under its lower-cased rank (defaulting to the SPECIES rank when the merged
    species has no rank), and the GBIF / iNat identifiers are added when set.

    NOTE(review): `_HIERARCHY_FIELDS` includes "kingdom", which ranks above the
    Arthropoda phylum that `get_or_create_root_taxon()` uses as the root —
    confirm `create_taxon()` copes with a rank above the root.
    """
    row_data: dict = {}

    payloads = [source_row.raw for source_row in species.contributing if source_row.raw]
    for payload in payloads:
        for rank_field in _HIERARCHY_FIELDS:
            if rank_field in row_data:
                continue  # an earlier contributing row already supplied this rank
            candidate = payload.get(rank_field)
            if candidate:
                row_data[rank_field] = candidate

    specific_rank = species.rank if species.rank else TaxonRank.SPECIES.name
    row_data[specific_rank.lower()] = species.scientific_name

    for id_field in ("gbif_taxon_key", "inat_taxon_id"):
        identifier = getattr(species, id_field)
        if identifier:
            row_data[id_field] = identifier
    return row_data
def apply_model_coverage(mapping: MappingOutcome, *, dry_run: bool) -> CoverageOutcome:
    """Split the mapped taxa into model-covered and uncovered groups.

    Taxa that were matched to existing rows are judged by their stored
    `has_model_coverage` flag. Taxa created by this run are handled differently
    depending on `dry_run`:

    * dry_run: nothing is written; created (unsaved) taxa are judged by whether
      their name is returned by `taxon_coverage.names_covered_by_any_algorithm`.
    * real run: `taxon_coverage.refresh_coverage_for_taxa` is called for the
      created taxa, which are then re-read from the database so the refreshed
      flag is visible; created taxa that can no longer be found are dropped.

    In both modes each output list holds matched taxa first, then created taxa.
    """
    existing_taxa = [taxon for _species, taxon in mapping.matched]
    new_taxa = mapping.created

    if dry_run:
        simulated = taxon_coverage.names_covered_by_any_algorithm({t.name for t in new_taxa})
        covered: list[Taxon] = []
        uncovered: list[Taxon] = []
        for taxon in existing_taxa:
            (covered if taxon.has_model_coverage else uncovered).append(taxon)
        for taxon in new_taxa:
            (covered if taxon.name in simulated else uncovered).append(taxon)
        return CoverageOutcome(covered=covered, uncovered=uncovered)

    if new_taxa:
        # Targeted refresh: only the just-created taxa get their coverage computed,
        # so a run does not pay for a full per-algorithm rebuild.
        new_ids = [t.pk for t in new_taxa]
        taxon_coverage.refresh_coverage_for_taxa(new_ids)
        reloaded = {t.pk: t for t in Taxon.objects.filter(pk__in=new_ids)}
        new_taxa = [reloaded[pk] for pk in new_ids if pk in reloaded]

    covered = []
    uncovered = []
    for taxon in existing_taxa + new_taxa:
        (covered if taxon.has_model_coverage else uncovered).append(taxon)
    return CoverageOutcome(covered=covered, uncovered=uncovered)
def generate_regional_taxa_list(
    *,
    region_source: str,
    region_code: str,
    project: Project | None = None,
    classifier: Algorithm | None = None,
    taxon_scope: TaxonScope | None = None,
    sources: list[RegionalSpeciesSource] | None = None,
    include_uncovered: bool = False,
    create_missing: bool = True,
    name: str | None = None,
    dry_run: bool = False,
) -> RegionalTaxaResult:
    """Build (or update) a TaxaList from the species recorded in a region.

    Pipeline: fetch from every source -> `merge_source_species` -> `map_to_taxa`
    -> `apply_model_coverage` -> keep covered taxa (plus uncovered ones when
    `include_uncovered`) -> save them on the list.

    Args:
        region_source: Selects the default source when `sources` is not given.
        region_code: Region identifier passed to each source's `fetch_species`.
        project: Project the list is looked up / created for.
        classifier: Optional algorithm; when it has a category map, the result
            reports how many kept taxa appear in its labels. Reporting only —
            it does not change which taxa are kept.
        taxon_scope: Scope passed to each source; defaults to LEPIDOPTERA_SCOPE.
        sources: Injected sources (e.g. stubs in tests); defaults to
            `_default_sources(region_source)`.
        include_uncovered: Also keep taxa without model coverage.
        create_missing: Forwarded to `map_to_taxa`.
        name: List name; defaults to ``"<region_code> (<region_source>)"``.
        dry_run: When true, no TaxaList is fetched, created or modified and
            `taxa_list_id` stays None.

    Returns:
        A RegionalTaxaResult with the counts from each stage.
    """
    from ..models import TaxaList

    if not taxon_scope:
        taxon_scope = LEPIDOPTERA_SCOPE
    if sources is not None:
        resolved_sources = sources
    else:
        resolved_sources = _default_sources(region_source)

    # Fetch everything first, then tally per source.
    per_source_species = []
    for source in resolved_sources:
        per_source_species.append(source.fetch_species(region_code, taxon_scope))
    per_source_counts = {}
    for source, rows in zip(resolved_sources, per_source_species):
        per_source_counts[source.source_key] = len(rows)

    merged = merge_source_species(per_source_species)
    mapping = map_to_taxa(merged, create_missing=create_missing, dry_run=dry_run)
    coverage = apply_model_coverage(mapping, dry_run=dry_run)

    kept_taxa = list(coverage.covered)
    if include_uncovered:
        kept_taxa.extend(coverage.uncovered)

    # Optional overlay: how many kept taxa one specific classifier knows by name.
    in_classifier_labels: int | None = None
    not_in_classifier: int | None = None
    if classifier is not None and classifier.category_map is not None:
        known_labels = set(classifier.category_map.labels)
        in_classifier_labels = sum(1 for taxon in kept_taxa if taxon.name in known_labels)
        not_in_classifier = len(kept_taxa) - in_classifier_labels

    if name:
        list_name = name
    else:
        list_name = f"{region_code} ({region_source})"

    taxa_list_id: int | None = None
    list_created = False
    if not dry_run:
        taxa_list, list_created = TaxaList.objects.get_or_create_for_project(name=list_name, project=project)
        taxa_list.taxa.set(kept_taxa)
        taxa_list_id = taxa_list.pk

    return RegionalTaxaResult(
        region_source=region_source,
        region_code=region_code,
        taxa_list_id=taxa_list_id,
        list_created=list_created,
        regional_total=len(merged),
        per_source_counts=per_source_counts,
        already_in_db=len(mapping.matched),
        created_taxa=len(mapping.created),
        model_covered=len(coverage.covered),
        regional_no_model_coverage=len(coverage.uncovered),
        saved_list_size=len(kept_taxa),
        in_classifier_labels=in_classifier_labels,
        not_in_classifier=not_in_classifier,
        unmatched_names=mapping.unmatched_names,
        dry_run=dry_run,
    )
+ """ + if region_source != RegionSource.GBIF_GADM: + raise ValueError(f"derive_region_for_project supports GBIF/GADM only, got {region_source!r}") + + deployment = project.deployments.filter(latitude__isnull=False, longitude__isnull=False).order_by("pk").first() + if deployment is None: + return None + + if geocoder is None: + from .gbif import reverse_geocode_gadm + + geocoder = reverse_geocode_gadm + + region_code = geocoder(deployment.latitude, deployment.longitude, level=level) + if not region_code: + return None + return (region_source, region_code) diff --git a/ami/main/services/taxon_coverage.py b/ami/main/services/taxon_coverage.py new file mode 100644 index 000000000..e48cd64a0 --- /dev/null +++ b/ami/main/services/taxon_coverage.py @@ -0,0 +1,130 @@ +"""Keep `Taxon.covered_by_algorithms` / `Taxon.has_model_coverage` in sync with what +classifiers can actually predict. + +Coverage is derived data: a taxon is "covered" by an algorithm when the taxon's name +appears in that algorithm's category map labels — the same `Taxon.name == label` join +`AlgorithmCategoryMap.with_taxa()` uses at classification time (`ami/ml/models/algorithm.py`). +It is computed here and persisted so callers (the regional taxa-list service, the admin, +future masking auto-resolution) can filter/read it without re-deriving the join on every +request. It is never recomputed on a per-occurrence or per-classification basis — only +when a category map's label set changes (see the `Algorithm.save()` hook) or on demand via +the `refresh_taxon_model_coverage` management command, so a brief lag between a label-set +change and this flag updating is expected and harmless (masking itself still resolves names +live via `with_taxa()`; this flag is a filtering/reporting convenience, not the masking path). 
def refresh_algorithm_coverage(algorithm: "Algorithm") -> None:
    """Rebuild `algorithm.covered_taxa` from its category map and resync the flags.

    The covered set becomes every Taxon whose name is one of the category map's
    labels; with no category map, or an empty label list, it becomes empty.
    `has_model_coverage` is then resynced for the union of the previously and
    newly covered taxa, so taxa that lost this algorithm are updated too.
    """
    category_map = algorithm.category_map
    labels = category_map.labels if category_map is not None else None

    if labels:
        current_ids = set(Taxon.objects.filter(name__in=labels).values_list("pk", flat=True))
    else:
        current_ids = set()

    # Snapshot the old membership before overwriting it, so losers get resynced.
    previous_ids = set(algorithm.covered_taxa.values_list("pk", flat=True))
    algorithm.covered_taxa.set(Taxon.objects.filter(pk__in=current_ids))

    _resync_has_model_coverage(previous_ids | current_ids)
def refresh_coverage_for_taxa(taxon_ids: typing.Iterable[int]) -> None:
    """Link the given taxa to the algorithms that can predict them, then resync flags.

    Only category maps whose labels overlap the names of these taxa are loaded.
    For each such map, every algorithm using it gets the matching taxa *added* to
    `covered_taxa` (existing links are left alone). Finally `has_model_coverage`
    is resynced for exactly the given taxa. Does nothing for an empty input.
    """
    from ami.ml.models.algorithm import Algorithm, AlgorithmCategoryMap

    taxon_ids = list(taxon_ids)
    if not taxon_ids:
        return

    # name -> primary keys of the requested taxa carrying that name
    pks_by_name: dict[str, list[int]] = {}
    for pk, taxon_name in Taxon.objects.filter(pk__in=taxon_ids).values_list("pk", "name"):
        if taxon_name not in pks_by_name:
            pks_by_name[taxon_name] = []
        pks_by_name[taxon_name].append(pk)
    wanted_names = set(pks_by_name)

    overlapping_maps = AlgorithmCategoryMap.objects.filter(labels__overlap=list(wanted_names))
    for category_map in overlapping_maps:
        shared_labels = wanted_names.intersection(category_map.labels)
        pks_to_link = [pk for label in shared_labels for pk in pks_by_name[label]]
        if pks_to_link:
            for algorithm in Algorithm.objects.filter(category_map=category_map):
                algorithm.covered_taxa.add(*pks_to_link)

    _resync_has_model_coverage(taxon_ids)
def _resync_has_model_coverage(taxon_ids: typing.Iterable[int]) -> None:
    """Make `has_model_coverage` agree with `covered_by_algorithms` for these taxa.

    A taxon is flagged True when it has at least one related algorithm and False
    otherwise. One query finds the covered subset; each non-empty subset is then
    written with a single bulk update. Does nothing for an empty input.
    """
    ids = list(taxon_ids)
    if not ids:
        return

    with_coverage = set(
        Taxon.objects.filter(pk__in=ids, covered_by_algorithms__isnull=False)
        .values_list("pk", flat=True)
        .distinct()
    )
    without_coverage = set(ids) - with_coverage

    # Covered taxa are written first, then uncovered ones; empty groups are skipped.
    for flag, group in ((True, with_coverage), (False, without_coverage)):
        if group:
            Taxon.objects.filter(pk__in=group).update(has_model_coverage=flag)
def create_taxon(taxon_data: dict, root_taxon_parent: Taxon) -> tuple[set[Taxon], set[Taxon], Taxon]:
    """Get or create the taxon described by one row of rank data, plus its ancestors.

    `taxon_data` maps lower-cased rank names ("family", "genus", "species", ...)
    to taxon names, and may also carry the columns listed in
    `specific_taxon_columns` below and a "synonym_of" name. Ranks are walked in
    sorted `TaxonRank` order; each rank present in the row is looked up by name
    (created if missing) and parented under the previous rank found, starting
    from `root_taxon_parent`.

    The row's `gbif_taxon_key` identifies the row's most specific taxon only, so
    it is attached to that taxon alone. Ancestors created along the way get no
    GBIF key; stamping them with the species' key would make later lookups by
    `gbif_taxon_key` ambiguous between a species and its genus/family.

    Args:
        taxon_data: One row of taxonomic data, as described above.
        root_taxon_parent: Parent for the highest rank found in the row.

    Returns:
        ``(created_taxa, updated_taxa, specific_taxon)`` — the taxa created by
        this call, the pre-existing taxa it modified, and the row's most
        specific taxon.

    Raises:
        ValueError: If the row contains no recognised rank.
    """
    taxa_in_row = []
    created_taxa = set()
    updated_taxa = set()

    parent_taxon = root_taxon_parent

    ordered_ranks = sorted(RANK_CHOICES)
    # The last rank (in sorted order) with a value in the row is the row's own
    # taxon; it is the only one the row's external identifiers belong to.
    ranks_in_row = [rank for rank in ordered_ranks if taxon_data.get(rank.name.lower())]
    most_specific_rank = ranks_in_row[-1] if ranks_in_row else None

    for i, rank in enumerate(ordered_ranks):
        logger.debug(f"Checking rank {rank} {i} of {len(RANK_CHOICES)}")
        logger.debug(f"Current parent taxon: {parent_taxon}")
        # Create all parents and parents of parents
        # Assume ranks are in order of rank
        if rank.name.lower() in taxon_data.keys() and taxon_data[rank.name.lower()]:
            name = taxon_data[rank.name.lower()]
            # Ancestors must not inherit the key of the row's specific taxon.
            gbif_taxon_key = taxon_data.get("gbif_taxon_key", None) if rank == most_specific_rank else None
            rank_name = rank.name.upper()
            logger.debug(f"Taxon found in incoming row {i}: {rank_name} {name} (GBIF: {gbif_taxon_key})")

            # Look up existing taxon by name only, since names must be unique.
            # If the taxon already exists, use it and maybe update it
            taxon, created = Taxon.objects.get_or_create(
                name=name,
                defaults=dict(
                    rank=rank_name,
                    gbif_taxon_key=gbif_taxon_key,
                    parent=parent_taxon,
                ),
            )
            taxa_in_row.append(taxon)

            if created:
                logger.debug(f"Created new taxon #{taxon.id} {taxon} ({taxon.rank})")
                created_taxa.add(taxon)
            else:
                logger.debug(f"Using existing taxon #{taxon.id} {taxon} ({taxon.rank})")

            # Add or update the rank of the taxon based on incoming data
            if not taxon.rank or taxon.rank != rank_name:
                if not created:
                    logger.warning(f"Rank of existing {taxon} is changing from {taxon.rank} to {rank_name}")
                taxon.rank = rank_name
                taxon.save(update_calculated_fields=False)
                if not created:
                    updated_taxa.add(taxon)

            # Add or update the parent of the taxon based on incoming data
            # if the incoming parent is more specific than the existing parent
            # (e.g. if the existing parent is Lepidoptera and the incoming parent is a family)
            if not taxon.parent or parent_taxon.get_rank() > taxon.parent.get_rank():
                parent = parent_taxon or root_taxon_parent
                if parent == taxon:
                    logger.debug(f"Parent of {taxon} is itself, changing to (or keeping as) None")
                    parent = None
                if taxon.parent != parent:
                    if not created:
                        logger.warning(f"Changing parent of {taxon} from {taxon.parent} to more specific {parent}")
                    taxon.parent = parent
                    taxon.save(update_calculated_fields=False)
                    if not created:
                        updated_taxa.add(taxon)

            parent_taxon = taxon
            logger.debug(f"Next parent taxon: {parent_taxon.rank} {parent_taxon}")
        else:
            logger.debug(f"Did not find {rank} in incoming row, checking next rank")

    accepted_name = taxon_data.get("synonym_of", None)

    if not taxa_in_row:
        raise ValueError(f"Could not find any ranks in {taxon_data}")

    # Make sure incoming taxa are sorted by rank
    taxa_in_row = sorted(taxa_in_row, key=lambda taxon: taxon.get_rank())

    logger.debug(f"Found {len(taxa_in_row)} taxa in row: {taxa_in_row}")

    specific_taxon = taxa_in_row[-1]
    expected_specific_taxon_ranks = TaxonRank.SPECIES, TaxonRank.GENUS
    if specific_taxon.get_rank() not in expected_specific_taxon_ranks:
        logger.warning(f"Assuming the most specific taxon of this row is: {specific_taxon} {specific_taxon.rank}")

    # Columns that describe the row's most specific taxon and are copied onto it.
    specific_taxon_columns = [
        "author",
        "authorship_date",
        "gbif_taxon_key",
        "bold_taxon_bin",
        "inat_taxon_id",
        "common_name_en",
        "notes",
        "sort_phylogeny",
        "fieldguide_id",
        "cover_image_url",
        "cover_image_credit",
    ]

    is_new = specific_taxon in created_taxa
    needs_update = False
    for column in specific_taxon_columns:
        if column in taxon_data:
            existing_value = getattr(specific_taxon, column)
            incoming_value = taxon_data[column]
            if existing_value != incoming_value:
                if incoming_value is None:
                    # Don't overwrite existing values with None.
                    # This could potentially be a command line option to allow users to clear values.
                    logger.debug(f"Not changing {column} of {specific_taxon} from {existing_value} to None")
                    continue
                if not is_new:
                    logger.info(f"Changing {column} of {specific_taxon} to from {existing_value} to {incoming_value}")
                setattr(specific_taxon, column, taxon_data[column])
                needs_update = True
    if needs_update:
        specific_taxon.save(update_calculated_fields=False)
        if not is_new:
            logger.warning(f"TAXON DATA CHANGED for existing {specific_taxon} ({specific_taxon.id})")
            updated_taxa.add(specific_taxon)

    if accepted_name:
        # The row is a synonym: make sure the accepted taxon exists and point to it.
        accepted_taxon, created = Taxon.objects.get_or_create(
            name=accepted_name,
            rank=specific_taxon.rank,
            defaults={"parent": parent_taxon},
        )
        if created:
            logger.info(f"Created accepted taxon {accepted_taxon}")
            created_taxa.add(accepted_taxon)

        if specific_taxon.synonym_of != accepted_taxon:
            logger.info(f"Setting synonym_of of {specific_taxon} to {accepted_taxon}")
            specific_taxon.synonym_of = accepted_taxon
            specific_taxon.save()
            updated_taxa.add(specific_taxon)

    return created_taxa, updated_taxa, specific_taxon
+ """ + from ami.main.models import Project, Site, TaxaList + from ami.main.services import regional_taxa + + try: + project = Project.objects.get(pk=project_id) + except Project.DoesNotExist: + logger.warning(f"Project {project_id} not found; skipping regional taxa-list generation") + return + + classifier = None + if classifier_id: + from ami.ml.models.algorithm import Algorithm + + classifier = Algorithm.objects.filter(pk=classifier_id).first() + + result = regional_taxa.generate_regional_taxa_list( + project=project, + region_source=region_source, + region_code=region_code, + classifier=classifier, + include_uncovered=include_uncovered, + ) + if result.taxa_list_id is None: + return + + taxa_list = TaxaList.objects.get(pk=result.taxa_list_id) + if site_id is not None: + site = Site.objects.filter(pk=site_id).first() + if site is not None: + site.taxa_list = taxa_list + site.save(update_fields=["taxa_list"]) + else: + project.default_taxa_list = taxa_list + project.save(update_fields=["default_taxa_list"]) + + logger.info( + "Regional taxa list %s (%s taxa) linked to %s", + taxa_list.pk, + result.saved_list_size, + f"site {site_id}" if site_id else f"project {project_id}", + ) diff --git a/ami/main/tests_regional_taxa.py b/ami/main/tests_regional_taxa.py new file mode 100644 index 000000000..f11c9b708 --- /dev/null +++ b/ami/main/tests_regional_taxa.py @@ -0,0 +1,877 @@ +"""Tests for the regional taxa-list service (issue #1364, Phase 1). + +Covers: GBIF source-client parsing, the wide-union merge, mapping merged species to +Taxon rows, the model-coverage relationship and its refresh hook/command, and the +generate_regional_taxa_list() orchestration (idempotency, the report-only classifier +overlay, and the default-covered-only vs. include_uncovered behaviour). + +Every test here uses a stubbed RegionalSpeciesSource or a monkeypatched HTTP session — +nothing exercises the network. 
+""" + +from unittest import mock + +from django.contrib.messages.storage.fallback import FallbackStorage +from django.core.management import call_command +from django.core.management.base import CommandError +from django.test import RequestFactory, TestCase +from guardian.shortcuts import assign_perm +from rest_framework.test import APITestCase + +from ami.main.models import Deployment, Project, RegionSource, Site, TaxaList, Taxon, TaxonRank +from ami.main.services.gbif import GBIFRegionalSource, reverse_geocode_gadm +from ami.main.services.regional_taxa import ( + LEPIDOPTERA_SCOPE, + MergedSpecies, + SourceSpecies, + apply_model_coverage, + derive_region_for_project, + generate_regional_taxa_list, + map_to_taxa, + merge_source_species, +) +from ami.main.services.taxon_coverage import refresh_coverage_for_taxa +from ami.main.tasks import generate_regional_taxa_list_task +from ami.ml.models.algorithm import Algorithm, AlgorithmCategoryMap +from ami.users.models import User + + +class _FakeResponse: + def __init__(self, json_data, status_code=200): + self._json_data = json_data + self.status_code = status_code + + def raise_for_status(self): + pass + + def json(self): + return self._json_data + + +class _FakeGBIFSession: + """Stub for requests.Session covering GBIFRegionalSource's two endpoints: a + speciesKey-faceted occurrence search (paginated) and per-key species lookups. 
+ Records every URL requested so tests can assert on pagination termination.""" + + def __init__(self): + self.calls: list[str] = [] + + def get(self, url, params=None, timeout=None): + self.calls.append(url) + if url.endswith("/occurrence/search"): + offset = params["facetOffset"] + if offset == 0: + counts = [{"name": "1001", "count": 42}, {"name": "1002", "count": 7}] + elif offset == 2: + counts = [{"name": "1003", "count": 3}] + else: + counts = [] + facets = [{"field": "SPECIES_KEY", "counts": counts}] if counts else [] + return _FakeResponse({"facets": facets}) + if url.endswith("/species/1001"): + return _FakeResponse( + {"canonicalName": "Vanessa atalanta", "rank": "SPECIES", "family": "Nymphalidae", "genus": "Vanessa"} + ) + if url.endswith("/species/1002"): + return _FakeResponse({"canonicalName": "Danaus plexippus", "rank": "SPECIES"}) + if url.endswith("/species/1003"): + # Simulates a speciesKey with no resolvable species record. + return _FakeResponse({}, status_code=404) + raise AssertionError(f"Unexpected GBIF URL requested in test: {url}") + + +class GBIFRegionalSourceParsingTest(TestCase): + def test_fetch_species_parses_facets_and_resolves_names(self): + """fetch_species pages the speciesKey facet until a partial page ends it, + resolves each key to a name (skipping keys GBIF can't resolve), and carries + the per-species observation count and raw hierarchy fields through.""" + source = GBIFRegionalSource(facet_page_size=2) + fake_session = _FakeGBIFSession() + + with mock.patch("ami.main.services.gbif.create_session", return_value=fake_session): + species = source.fetch_species("USA.46_1", LEPIDOPTERA_SCOPE) + + names = {s.scientific_name for s in species} + self.assertEqual(names, {"Vanessa atalanta", "Danaus plexippus"}) + + # Two facet pages (offset 0 full, offset 2 partial) — pagination stopped at + # the partial page rather than issuing a third, empty-page request. 
+ occurrence_calls = [c for c in fake_session.calls if c.endswith("/occurrence/search")] + self.assertEqual(len(occurrence_calls), 2) + + atalanta = next(s for s in species if s.scientific_name == "Vanessa atalanta") + self.assertEqual(atalanta.gbif_taxon_key, 1001) + self.assertEqual(atalanta.observation_count, 42) + self.assertEqual(atalanta.raw["family"], "Nymphalidae") + + +class MergeSourceSpeciesTest(TestCase): + def test_union_keeps_species_present_in_only_one_source(self): + """A species reported by only one source still survives the merge — a + second source can only grow the candidate set, never narrow it.""" + gbif_only = SourceSpecies(source="gbif_gadm", scientific_name="Danaus plexippus", gbif_taxon_key=1) + inat_only = SourceSpecies(source="inat_place", scientific_name="Vanessa cardui", inat_taxon_id=2) + + merged = merge_source_species([[gbif_only], [inat_only]]) + + self.assertEqual({m.scientific_name for m in merged}, {"Danaus plexippus", "Vanessa cardui"}) + + def test_shared_gbif_key_collapses_to_one_row_with_unioned_provenance(self): + """Two sources reporting the same species (matched by gbif_taxon_key) merge + into one MergedSpecies carrying both sources and both observation counts.""" + gbif_row = SourceSpecies( + source="gbif_gadm", scientific_name="Vanessa atalanta", gbif_taxon_key=100, observation_count=50 + ) + inat_row = SourceSpecies( + source="inat_place", scientific_name="Vanessa atalanta", gbif_taxon_key=100, observation_count=30 + ) + + merged = merge_source_species([[gbif_row], [inat_row]]) + + self.assertEqual(len(merged), 1) + self.assertEqual(merged[0].sources, {"gbif_gadm", "inat_place"}) + self.assertEqual(merged[0].observation_counts, {"gbif_gadm": 50, "inat_place": 30}) + + def test_row_merges_via_whichever_of_its_keys_matches_first(self): + """Dedup checks a row's gbif key, then its inat key, then its name against + the index — a row whose gbif key is new but whose inat key matches an + existing group still collapses into 
that group instead of starting a new one.""" + first = SourceSpecies(source="gbif_gadm", scientific_name="Vanessa cardui", gbif_taxon_key=1, inat_taxon_id=9) + second = SourceSpecies( + source="inat_place", scientific_name="Vanessa cardui", gbif_taxon_key=2, inat_taxon_id=9 + ) + + merged = merge_source_species([[first], [second]]) + + self.assertEqual(len(merged), 1) + # The original gbif key is preserved rather than overwritten by the conflicting one. + self.assertEqual(merged[0].gbif_taxon_key, 1) + + def test_name_only_collision_with_conflicting_keys_keeps_both_and_logs(self): + """Two rows sharing only a normalized name but carrying different GBIF keys + merge into one row (name-only collision) rather than becoming two rows, and + the conflict is logged instead of silently dropping a key.""" + first = SourceSpecies(source="gbif_gadm", scientific_name="Vanessa cardui", gbif_taxon_key=1) + second = SourceSpecies(source="inat_place", scientific_name="Vanessa cardui", gbif_taxon_key=2) + + with self.assertLogs("ami.main.services.regional_taxa", level="WARNING"): + merged = merge_source_species([[first], [second]]) + + self.assertEqual(len(merged), 1) + self.assertEqual(merged[0].sources, {"gbif_gadm", "inat_place"}) + + +class MapToTaxaTest(TestCase): + def test_matches_existing_taxon_by_gbif_key(self): + existing = Taxon.objects.create(name="Papilio machaon", rank=TaxonRank.SPECIES.name, gbif_taxon_key=100) + species = MergedSpecies( + scientific_name="Papilio machaon", + rank="SPECIES", + gbif_taxon_key=100, + inat_taxon_id=None, + sources={"gbif_gadm"}, + observation_counts={}, + contributing=[], + ) + + outcome = map_to_taxa([species], create_missing=True, dry_run=False) + + self.assertEqual(outcome.matched, [(species, existing)]) + self.assertEqual(outcome.created, []) + + def test_matches_existing_taxon_by_name(self): + existing = Taxon.objects.create(name="Danaus plexippus", rank=TaxonRank.SPECIES.name) + species = MergedSpecies( + scientific_name="Danaus 
plexippus", + rank="SPECIES", + gbif_taxon_key=None, + inat_taxon_id=None, + sources={"gbif_gadm"}, + observation_counts={}, + contributing=[], + ) + + outcome = map_to_taxa([species], create_missing=True, dry_run=False) + + self.assertEqual(outcome.matched, [(species, existing)]) + + def test_create_missing_creates_taxon_via_hierarchy_builder(self): + """When no existing Taxon matches, create_missing builds one (and its + ancestors) via the same rank-hierarchy builder import_taxa uses, from + whatever hierarchy fields a contributing source's raw payload carries.""" + contributing = SourceSpecies( + source="gbif_gadm", + scientific_name="Papilio glaucus", + gbif_taxon_key=555, + raw={"family": "Papilionidae", "genus": "Papilio"}, + ) + species = MergedSpecies( + scientific_name="Papilio glaucus", + rank="SPECIES", + gbif_taxon_key=555, + inat_taxon_id=None, + sources={"gbif_gadm"}, + observation_counts={}, + contributing=[contributing], + ) + + outcome = map_to_taxa([species], create_missing=True, dry_run=False) + + self.assertEqual(len(outcome.created), 1) + taxon = outcome.created[0] + self.assertEqual(taxon.name, "Papilio glaucus") + self.assertEqual(taxon.gbif_taxon_key, 555) + self.assertEqual(taxon.parent.name, "Papilio") + self.assertEqual(taxon.parent.parent.name, "Papilionidae") + + def test_create_missing_false_records_unmatched_name_without_creating(self): + species = MergedSpecies( + scientific_name="Unknown species", + rank="SPECIES", + gbif_taxon_key=None, + inat_taxon_id=None, + sources={"gbif_gadm"}, + observation_counts={}, + contributing=[], + ) + + outcome = map_to_taxa([species], create_missing=False, dry_run=False) + + self.assertEqual(outcome.unmatched_names, ["Unknown species"]) + self.assertEqual(outcome.created, []) + self.assertFalse(Taxon.objects.filter(name="Unknown species").exists()) + + def test_rerun_does_not_create_duplicate_taxon(self): + """Running map_to_taxa twice for the same species must not create a second + Taxon row — the 
second call matches the row the first call created, via the + existing-Taxon name lookup (Taxon.name is unique).""" + species = MergedSpecies( + scientific_name="Papilio glaucus", + rank="SPECIES", + gbif_taxon_key=None, + inat_taxon_id=None, + sources={"gbif_gadm"}, + observation_counts={}, + contributing=[], + ) + + map_to_taxa([species], create_missing=True, dry_run=False) + map_to_taxa([species], create_missing=True, dry_run=False) + + self.assertEqual(Taxon.objects.filter(name="Papilio glaucus").count(), 1) + + def test_dry_run_never_creates_a_taxon_row(self): + """dry_run must not mutate the DB — a species that would be created is + represented as an unsaved Taxon stand-in instead.""" + species = MergedSpecies( + scientific_name="Ghost Species", + rank="SPECIES", + gbif_taxon_key=None, + inat_taxon_id=None, + sources={"gbif_gadm"}, + observation_counts={}, + contributing=[], + ) + + outcome = map_to_taxa([species], create_missing=True, dry_run=True) + + self.assertEqual(len(outcome.created), 1) + self.assertIsNone(outcome.created[0].pk) + self.assertFalse(Taxon.objects.filter(name="Ghost Species").exists()) + + def test_matching_existing_taxa_is_batched_not_per_row(self): + """The existing-Taxon lookup is three bulk `__in` queries regardless of how + many merged species are being matched — never one query per species.""" + taxa = [ + Taxon.objects.create(name=f"Species {i}", rank=TaxonRank.SPECIES.name, gbif_taxon_key=1000 + i) + for i in range(5) + ] + merged = [ + MergedSpecies( + scientific_name=taxon.name, + rank="SPECIES", + gbif_taxon_key=taxon.gbif_taxon_key, + inat_taxon_id=None, + sources={"gbif_gadm"}, + observation_counts={}, + contributing=[], + ) + for taxon in taxa + ] + + # gbif_taxon_key __in lookup + name __in lookup; no inat ids given, so that + # lookup is skipped. Neither query scales with len(merged). 
+ with self.assertNumQueries(2): + outcome = map_to_taxa(merged, create_missing=False, dry_run=False) + + self.assertEqual(len(outcome.matched), 5) + + +class GenerateRegionalTaxaListTest(TestCase): + class StubSource: + def __init__(self, source_key: str, species: list[SourceSpecies]): + self.source_key = source_key + self._species = species + + def fetch_species(self, region_code, taxon_scope): + return list(self._species) + + def test_rerun_updates_same_taxa_list_not_a_duplicate(self): + project = Project.objects.create(name="Regional Test Project", create_defaults=False) + source = self.StubSource( + "gbif_gadm", [SourceSpecies(source="gbif_gadm", scientific_name="Colias eurytheme", gbif_taxon_key=1)] + ) + + result1 = generate_regional_taxa_list( + region_source=RegionSource.GBIF_GADM, + region_code="USA.46_1", + project=project, + sources=[source], + name="Vermont Moths", + include_uncovered=True, + ) + result2 = generate_regional_taxa_list( + region_source=RegionSource.GBIF_GADM, + region_code="USA.46_1", + project=project, + sources=[source], + name="Vermont Moths", + include_uncovered=True, + ) + + self.assertTrue(result1.list_created) + self.assertFalse(result2.list_created) + self.assertEqual(result1.taxa_list_id, result2.taxa_list_id) + self.assertEqual(TaxaList.objects.filter(name="Vermont Moths", projects=project).count(), 1) + self.assertEqual(Taxon.objects.filter(name="Colias eurytheme").count(), 1) + taxa_list = TaxaList.objects.get(pk=result2.taxa_list_id) + self.assertEqual(list(taxa_list.taxa.values_list("name", flat=True)), ["Colias eurytheme"]) + + def test_classifier_report_is_report_only_does_not_filter_list(self): + """Passing classifier= only populates in_classifier_labels/not_in_classifier; + the saved list is governed solely by the ordinary model-coverage rule.""" + project = Project.objects.create(name="Regional Test Project", create_defaults=False) + covered_taxon = Taxon.objects.create(name="Covered Species", 
rank=TaxonRank.SPECIES.name) + coverage_map = AlgorithmCategoryMap.objects.create( + labels=["Covered Species"], data=[{"index": 0, "label": "Covered Species"}] + ) + Algorithm.objects.create(name="Coverage Classifier", version=1, category_map=coverage_map) + covered_taxon.refresh_from_db() + self.assertTrue(covered_taxon.has_model_coverage) + + # A different classifier whose own labels do NOT include "Covered Species" — + # exercises that the report is scoped to the one classifier passed in, not + # to whichever classifier happened to give the taxon model coverage. + reporting_map = AlgorithmCategoryMap.objects.create( + labels=["Some Other Label"], data=[{"index": 0, "label": "Some Other Label"}] + ) + reporting_classifier = Algorithm.objects.create( + name="Reporting Classifier", version=1, category_map=reporting_map + ) + + source = self.StubSource( + "gbif_gadm", [SourceSpecies(source="gbif_gadm", scientific_name="Covered Species", gbif_taxon_key=1)] + ) + + result = generate_regional_taxa_list( + region_source=RegionSource.GBIF_GADM, + region_code="USA.46_1", + project=project, + sources=[source], + classifier=reporting_classifier, + ) + + self.assertEqual(result.saved_list_size, 1) + self.assertEqual(result.in_classifier_labels, 0) + self.assertEqual(result.not_in_classifier, 1) + + def test_default_saves_only_model_covered_species(self): + """A region with a mix of covered and uncovered species: the default run + saves only the covered ones. 
The uncovered species' Taxon row still gets + created (create_missing's default), it's just excluded from this list.""" + project = Project.objects.create(name="Regional Test Project", create_defaults=False) + covered_taxon = Taxon.objects.create(name="Covered Species", rank=TaxonRank.SPECIES.name) + coverage_map = AlgorithmCategoryMap.objects.create( + labels=["Covered Species"], data=[{"index": 0, "label": "Covered Species"}] + ) + Algorithm.objects.create(name="Coverage Classifier", version=1, category_map=coverage_map) + covered_taxon.refresh_from_db() + self.assertTrue(covered_taxon.has_model_coverage) + + source = self.StubSource( + "gbif_gadm", + [ + SourceSpecies(source="gbif_gadm", scientific_name="Covered Species", gbif_taxon_key=1), + SourceSpecies(source="gbif_gadm", scientific_name="Uncovered Species", gbif_taxon_key=2), + ], + ) + + result = generate_regional_taxa_list( + region_source=RegionSource.GBIF_GADM, region_code="USA.46_1", project=project, sources=[source] + ) + + self.assertEqual(result.regional_total, 2) + self.assertEqual(result.model_covered, 1) + self.assertEqual(result.regional_no_model_coverage, 1) + self.assertEqual(result.saved_list_size, 1) + taxa_list = TaxaList.objects.get(pk=result.taxa_list_id) + self.assertEqual(set(taxa_list.taxa.values_list("name", flat=True)), {"Covered Species"}) + self.assertTrue(Taxon.objects.filter(name="Uncovered Species").exists()) + + def test_include_uncovered_creates_and_flags_uncovered_species(self): + """Opting in keeps both buckets: covered species stay flagged True, and the + newly created uncovered species are flagged has_model_coverage=False with an + empty covered_by_algorithms — an honest "in the region, no model knows it yet".""" + project = Project.objects.create(name="Regional Test Project", create_defaults=False) + covered_taxon = Taxon.objects.create(name="Covered Species", rank=TaxonRank.SPECIES.name) + coverage_map = AlgorithmCategoryMap.objects.create( + labels=["Covered Species"], 
data=[{"index": 0, "label": "Covered Species"}] + ) + Algorithm.objects.create(name="Coverage Classifier", version=1, category_map=coverage_map) + covered_taxon.refresh_from_db() + + source = self.StubSource( + "gbif_gadm", + [ + SourceSpecies(source="gbif_gadm", scientific_name="Covered Species", gbif_taxon_key=1), + SourceSpecies(source="gbif_gadm", scientific_name="Uncovered Species", gbif_taxon_key=2), + ], + ) + + result = generate_regional_taxa_list( + region_source=RegionSource.GBIF_GADM, + region_code="USA.46_1", + project=project, + sources=[source], + include_uncovered=True, + ) + + self.assertEqual(result.saved_list_size, 2) + taxa_list = TaxaList.objects.get(pk=result.taxa_list_id) + self.assertEqual(set(taxa_list.taxa.values_list("name", flat=True)), {"Covered Species", "Uncovered Species"}) + + uncovered = Taxon.objects.get(name="Uncovered Species") + self.assertFalse(uncovered.has_model_coverage) + self.assertEqual(uncovered.covered_by_algorithms.count(), 0) + + covered_taxon.refresh_from_db() + self.assertTrue(covered_taxon.has_model_coverage) + + def test_dry_run_never_mutates_the_database(self): + project = Project.objects.create(name="Regional Test Project", create_defaults=False) + source = self.StubSource( + "gbif_gadm", [SourceSpecies(source="gbif_gadm", scientific_name="Ghost Species", gbif_taxon_key=1)] + ) + + result = generate_regional_taxa_list( + region_source=RegionSource.GBIF_GADM, + region_code="USA.46_1", + project=project, + sources=[source], + include_uncovered=True, + dry_run=True, + ) + + self.assertTrue(result.dry_run) + self.assertIsNone(result.taxa_list_id) + self.assertEqual(result.created_taxa, 1) + self.assertFalse(Taxon.objects.filter(name="Ghost Species").exists()) + self.assertFalse(TaxaList.objects.filter(projects=project).exists()) + + +class TaxonModelCoverageRefreshTest(TestCase): + def test_linking_a_category_map_hook_sets_coverage(self): + """Linking an algorithm to a category map that lists a taxon's name as a + 
label marks that taxon has_model_coverage=True and adds the algorithm to + its covered_by_algorithms, via the Algorithm.save() hook — no explicit + refresh call needed.""" + taxon = Taxon.objects.create(name="Covered Species", rank=TaxonRank.SPECIES.name) + category_map = AlgorithmCategoryMap.objects.create( + labels=["Covered Species"], data=[{"index": 0, "label": "Covered Species"}] + ) + + algorithm = Algorithm.objects.create(name="Test Classifier", version=1, category_map=category_map) + + taxon.refresh_from_db() + self.assertTrue(taxon.has_model_coverage) + self.assertIn(algorithm, taxon.covered_by_algorithms.all()) + self.assertIn(taxon, algorithm.covered_taxa.all()) + + def test_targeted_refresh_covers_only_the_given_taxa(self): + """A taxon created after a classifier exists is not seen by the save hook. + refresh_coverage_for_taxa (used by the regional service for just-created + rows) links only the named taxa to matching algorithms, leaving others + untouched — this is the cheap path that avoids a full per-algorithm rebuild.""" + category_map = AlgorithmCategoryMap.objects.create( + labels=["Late Species"], data=[{"index": 0, "label": "Late Species"}] + ) + algorithm = Algorithm.objects.create(name="Late Classifier", version=1, category_map=category_map) + # Both created AFTER the algorithm, so the save hook never linked them. 
+ late = Taxon.objects.create(name="Late Species", rank=TaxonRank.SPECIES.name) + unlisted = Taxon.objects.create(name="Unlisted Species", rank=TaxonRank.SPECIES.name) + self.assertFalse(late.has_model_coverage) + + refresh_coverage_for_taxa([late.pk, unlisted.pk]) + + late.refresh_from_db() + unlisted.refresh_from_db() + self.assertTrue(late.has_model_coverage) + self.assertIn(algorithm, late.covered_by_algorithms.all()) + self.assertFalse(unlisted.has_model_coverage) + self.assertEqual(unlisted.covered_by_algorithms.count(), 0) + + def test_reassigning_the_category_map_drops_stale_coverage(self): + """When an algorithm's category map is swapped for one that no longer lists + a taxon, the hook-triggered refresh removes that taxon's coverage (assuming + no other algorithm still covers it).""" + taxon = Taxon.objects.create(name="Covered Species", rank=TaxonRank.SPECIES.name) + map_v1 = AlgorithmCategoryMap.objects.create( + labels=["Covered Species"], data=[{"index": 0, "label": "Covered Species"}] + ) + algorithm = Algorithm.objects.create(name="Test Classifier", version=1, category_map=map_v1) + taxon.refresh_from_db() + self.assertTrue(taxon.has_model_coverage) + + map_v2 = AlgorithmCategoryMap.objects.create( + labels=["Some Other Species"], data=[{"index": 0, "label": "Some Other Species"}] + ) + algorithm.category_map = map_v2 + algorithm.save() + + taxon.refresh_from_db() + self.assertFalse(taxon.has_model_coverage) + self.assertEqual(taxon.covered_by_algorithms.count(), 0) + + def test_refresh_command_repairs_coverage_the_hook_never_saw(self): + """The full-rebuild management command recomputes coverage for every + algorithm's category map, correcting drift from a write path that bypasses + the per-save hook (e.g. 
a bulk_update on Algorithm.category_map).""" + taxon = Taxon.objects.create(name="Covered Species", rank=TaxonRank.SPECIES.name) + category_map = AlgorithmCategoryMap.objects.create( + labels=["Covered Species"], data=[{"index": 0, "label": "Covered Species"}] + ) + algorithm = Algorithm.objects.create(name="Test Classifier", version=1) + # QuerySet.update() never calls Algorithm.save(), so the hook never fires. + Algorithm.objects.filter(pk=algorithm.pk).update(category_map=category_map) + taxon.refresh_from_db() + self.assertFalse(taxon.has_model_coverage) + + call_command("refresh_taxon_model_coverage") + + taxon.refresh_from_db() + self.assertTrue(taxon.has_model_coverage) + algorithm.refresh_from_db() + self.assertIn(taxon, algorithm.covered_taxa.all()) + + def test_covering_algorithms_are_reachable_from_the_taxon(self): + """taxon.covered_by_algorithms.all() names the algorithm(s) that cover a + taxon, so callers can show which model knows a species, not just whether + one does.""" + taxon = Taxon.objects.create(name="Covered Species", rank=TaxonRank.SPECIES.name) + category_map = AlgorithmCategoryMap.objects.create( + labels=["Covered Species"], data=[{"index": 0, "label": "Covered Species"}] + ) + + algorithm = Algorithm.objects.create(name="Test Classifier", version=1, category_map=category_map) + + self.assertEqual(list(taxon.covered_by_algorithms.all()), [algorithm]) + + +class ApplyModelCoverageDryRunTest(TestCase): + def test_dry_run_partition_is_read_only(self): + """apply_model_coverage's dry_run path checks unsaved stand-in taxa against + current category map labels without writing anything, and still correctly + partitions covered vs.
uncovered.""" + AlgorithmCategoryMap.objects.create( + labels=["Covered Species"], data=[{"index": 0, "label": "Covered Species"}] + ) + merged = [ + MergedSpecies( + scientific_name="Covered Species", + rank="SPECIES", + gbif_taxon_key=None, + inat_taxon_id=None, + sources={"gbif_gadm"}, + observation_counts={}, + contributing=[], + ), + MergedSpecies( + scientific_name="Uncovered Species", + rank="SPECIES", + gbif_taxon_key=None, + inat_taxon_id=None, + sources={"gbif_gadm"}, + observation_counts={}, + contributing=[], + ), + ] + + mapping = map_to_taxa(merged, create_missing=True, dry_run=True) + coverage = apply_model_coverage(mapping, dry_run=True) + + self.assertEqual({t.name for t in coverage.covered}, {"Covered Species"}) + self.assertEqual({t.name for t in coverage.uncovered}, {"Uncovered Species"}) + # dry_run must not persist either stand-in Taxon. + self.assertFalse(Taxon.objects.filter(name__in=["Covered Species", "Uncovered Species"]).exists()) + + +class _FakeReverseSession: + """Stub session for GBIF's /geocode/reverse endpoint — returns a fixed list of + GADM entries at levels 0/1/2, the shape the real endpoint returns for a point.""" + + def __init__(self, items): + self.items = items + + def get(self, url, params=None, timeout=None): + return _FakeResponse(self.items) + + +class ReverseGeocodeGADMTest(TestCase): + ITEMS = [ + {"id": "USA", "type": "Political"}, + {"id": "USA.46_1", "type": "Political"}, + {"id": "USA.46.14_1", "type": "Political"}, + ] + + def test_picks_level1_gid(self): + session = _FakeReverseSession(self.ITEMS) + self.assertEqual(reverse_geocode_gadm(44.26, -72.58, level=1, session=session), "USA.46_1") + + def test_picks_level2_gid(self): + session = _FakeReverseSession(self.ITEMS) + self.assertEqual(reverse_geocode_gadm(44.26, -72.58, level=2, session=session), "USA.46.14_1") + + def test_none_when_no_polygon_at_level(self): + # Only the bare country (level 0) is returned — no level-1 polygon contains the point. 
+ session = _FakeReverseSession([{"id": "USA", "type": "Political"}]) + self.assertIsNone(reverse_geocode_gadm(0.0, 0.0, level=1, session=session)) + + +class DeriveRegionForProjectTest(TestCase): + def test_uses_first_located_deployment(self): + project = Project.objects.create(name="Derive P", create_defaults=False) + Deployment.objects.create(name="no coords", project=project) + Deployment.objects.create(name="located", project=project, latitude=44.26, longitude=-72.58) + seen = {} + + def geocoder(latitude, longitude, level=1): + seen["args"] = (latitude, longitude, level) + return "USA.46_1" + + result = derive_region_for_project(project, geocoder=geocoder) + self.assertEqual(result, (RegionSource.GBIF_GADM.value, "USA.46_1")) + self.assertEqual(seen["args"], (44.26, -72.58, 1)) + + def test_none_without_located_deployment(self): + project = Project.objects.create(name="Derive P", create_defaults=False) + Deployment.objects.create(name="no coords", project=project) + self.assertIsNone(derive_region_for_project(project, geocoder=lambda *a, **k: "USA.46_1")) + + def test_none_when_geocoder_finds_no_region(self): + project = Project.objects.create(name="Derive P", create_defaults=False) + Deployment.objects.create(name="located", project=project, latitude=1.0, longitude=2.0) + self.assertIsNone(derive_region_for_project(project, geocoder=lambda *a, **k: None)) + + +def _fake_result(**overrides): + fields = dict( + region_code="USA.46_1", + saved_list_size=1, + model_covered=1, + regional_no_model_coverage=0, + created_taxa=0, + already_in_db=1, + regional_total=1, + dry_run=True, + ) + fields.update(overrides) + return mock.Mock(**fields) + + +class GenerateRegionalTaxaListCommandTest(TestCase): + """The command is a thin wrapper — these pin the arg wiring and the two guards, + with the service itself mocked (its behaviour is covered elsewhere).""" + + @mock.patch("ami.main.services.regional_taxa.generate_regional_taxa_list") + def 
test_single_project_passes_parsed_args(self, mock_generate): + mock_generate.return_value = _fake_result() + project = Project.objects.create(name="Cmd P", create_defaults=False) + call_command( + "generate_regional_taxa_list", + "--project", + str(project.pk), + "--region-code", + "USA.46_1", + "--include-uncovered", + "--dry-run", + ) + mock_generate.assert_called_once() + kwargs = mock_generate.call_args.kwargs + self.assertEqual(kwargs["region_code"], "USA.46_1") + self.assertEqual(kwargs["project"], project) + self.assertTrue(kwargs["include_uncovered"]) + self.assertTrue(kwargs["dry_run"]) + + def test_requires_region_code_without_all_projects(self): + with self.assertRaises(CommandError): + call_command("generate_regional_taxa_list", "--dry-run") + + def test_all_projects_rejects_explicit_region_code(self): + with self.assertRaises(CommandError): + call_command("generate_regional_taxa_list", "--all-projects", "--region-code", "USA.46_1") + + @mock.patch("ami.main.services.regional_taxa.derive_region_for_project") + @mock.patch("ami.main.services.regional_taxa.generate_regional_taxa_list") + def test_all_projects_generates_for_each_derived_region(self, mock_generate, mock_derive): + mock_generate.return_value = _fake_result() + mock_derive.return_value = (RegionSource.GBIF_GADM.value, "USA.46_1") + Project.objects.create(name="Cmd P1", create_defaults=False) + Project.objects.create(name="Cmd P2", create_defaults=False) + call_command("generate_regional_taxa_list", "--all-projects", "--dry-run") + self.assertEqual(mock_generate.call_count, Project.objects.count()) + + @mock.patch("ami.main.services.regional_taxa.derive_region_for_project") + @mock.patch("ami.main.services.regional_taxa.generate_regional_taxa_list") + def test_all_projects_skips_projects_without_region(self, mock_generate, mock_derive): + mock_derive.return_value = None + Project.objects.create(name="Cmd P1", create_defaults=False) + call_command("generate_regional_taxa_list", 
"--all-projects", "--dry-run") + mock_generate.assert_not_called() + + +class GenerateRegionalTaxaListTaskTest(TestCase): + """The task links the generated list to the right owner (project vs. site) and is + the slow-fetch-off-the-request-path surface the admin actions enqueue.""" + + @mock.patch("ami.main.services.regional_taxa.generate_regional_taxa_list") + def test_links_list_to_project_by_default(self, mock_generate): + project = Project.objects.create(name="Task P", create_defaults=False) + taxa_list = TaxaList.objects.create(name="Region list") + mock_generate.return_value = _fake_result(taxa_list_id=taxa_list.pk, saved_list_size=3) + generate_regional_taxa_list_task( + project_id=project.pk, region_source=RegionSource.GBIF_GADM.value, region_code="USA.46_1" + ) + project.refresh_from_db() + self.assertEqual(project.default_taxa_list_id, taxa_list.pk) + + @mock.patch("ami.main.services.regional_taxa.generate_regional_taxa_list") + def test_links_list_to_site_when_site_scoped(self, mock_generate): + project = Project.objects.create(name="Task P", create_defaults=False) + site = Site.objects.create(name="Site A", project=project) + taxa_list = TaxaList.objects.create(name="Region list") + mock_generate.return_value = _fake_result(taxa_list_id=taxa_list.pk) + generate_regional_taxa_list_task( + project_id=project.pk, + region_source=RegionSource.GBIF_GADM.value, + region_code="USA.46_1", + site_id=site.pk, + ) + site.refresh_from_db() + project.refresh_from_db() + self.assertEqual(site.taxa_list_id, taxa_list.pk) + self.assertIsNone(project.default_taxa_list_id) + + +def _admin_request(): + request = RequestFactory().post("/admin/") + setattr(request, "session", {}) + request._messages = FallbackStorage(request) + return request + + +class RegionalTaxaAdminActionTest(TestCase): + """The admin actions are thin: enqueue the task only for rows with a region set, + passing the right scope. 
These pin that skip logic and the site-scope kwarg.""" + + @mock.patch("ami.main.admin.generate_regional_taxa_list_task") + def test_project_action_enqueues_only_configured_rows(self, mock_task): + from django.contrib.admin.sites import site as admin_site + + from ami.main.admin import ProjectAdmin + + configured = Project.objects.create( + name="configured", + region_source=RegionSource.GBIF_GADM.value, + region_code="USA.46_1", + create_defaults=False, + ) + Project.objects.create(name="no-region", create_defaults=False) + ProjectAdmin(Project, admin_site).generate_regional_taxa_list_action(_admin_request(), Project.objects.all()) + mock_task.delay.assert_called_once() + self.assertEqual(mock_task.delay.call_args.kwargs["project_id"], configured.pk) + + @mock.patch("ami.main.admin.generate_regional_taxa_list_task") + def test_site_action_passes_site_scope(self, mock_task): + from django.contrib.admin.sites import site as admin_site + + from ami.main.admin import SiteAdmin + + project = Project.objects.create(name="p", create_defaults=False) + site = Site.objects.create( + name="s", project=project, region_source=RegionSource.GBIF_GADM.value, region_code="USA.46_1" + ) + SiteAdmin(Site, admin_site).generate_regional_taxa_list_action(_admin_request(), Site.objects.all()) + mock_task.delay.assert_called_once() + kwargs = mock_task.delay.call_args.kwargs + self.assertEqual(kwargs["site_id"], site.pk) + self.assertEqual(kwargs["project_id"], project.pk) + + +class GenerateRegionalTaxaListEndpointTest(APITestCase): + """POST /projects/{id}/generate-regional-taxa-list/ — permission matrix, body + validation, region derivation, and that a valid call enqueues the task.""" + + def setUp(self): + self.project = Project.objects.create(name="API Project", draft=False, create_defaults=False) + self.editor = User.objects.create_user(email="editor@example.com", password="pw") + assign_perm(Project.Permissions.UPDATE_PROJECT, self.editor, self.project) + self.viewer = 
User.objects.create_user(email="viewer@example.com", password="pw") + self.url = f"/api/v2/projects/{self.project.pk}/generate-regional-taxa-list/" + + @mock.patch("ami.main.api.views.generate_regional_taxa_list_task") + def test_editor_queues_with_explicit_region(self, mock_task): + self.client.force_authenticate(self.editor) + response = self.client.post( + self.url, {"region_source": RegionSource.GBIF_GADM.value, "region_code": "USA.46_1"}, format="json" + ) + self.assertEqual(response.status_code, 202) + mock_task.delay.assert_called_once() + self.assertEqual(mock_task.delay.call_args.kwargs["region_code"], "USA.46_1") + self.assertEqual(mock_task.delay.call_args.kwargs["project_id"], self.project.pk) + + @mock.patch("ami.main.api.views.generate_regional_taxa_list_task") + def test_non_editor_is_forbidden(self, mock_task): + self.client.force_authenticate(self.viewer) + response = self.client.post(self.url, {"region_code": "USA.46_1"}, format="json") + self.assertEqual(response.status_code, 403) + mock_task.delay.assert_not_called() + + @mock.patch("ami.main.api.views.generate_regional_taxa_list_task") + def test_anonymous_is_denied(self, mock_task): + response = self.client.post(self.url, {"region_code": "USA.46_1"}, format="json") + self.assertIn(response.status_code, (401, 403)) + mock_task.delay.assert_not_called() + + @mock.patch("ami.main.api.views.generate_regional_taxa_list_task") + def test_invalid_region_source_is_400(self, mock_task): + self.client.force_authenticate(self.editor) + response = self.client.post(self.url, {"region_source": "bogus", "region_code": "X"}, format="json") + self.assertEqual(response.status_code, 400) + mock_task.delay.assert_not_called() + + @mock.patch("ami.main.api.views.generate_regional_taxa_list_task") + def test_missing_and_underivable_region_is_400(self, mock_task): + # No region_code in the body and no located deployment to derive one from. 
+ self.client.force_authenticate(self.editor) + response = self.client.post(self.url, {}, format="json") + self.assertEqual(response.status_code, 400) + mock_task.delay.assert_not_called() + + @mock.patch("ami.main.services.regional_taxa.derive_region_for_project") + @mock.patch("ami.main.api.views.generate_regional_taxa_list_task") + def test_region_is_derived_when_omitted(self, mock_task, mock_derive): + mock_derive.return_value = (RegionSource.GBIF_GADM.value, "USA.46_1") + self.client.force_authenticate(self.editor) + response = self.client.post(self.url, {}, format="json") + self.assertEqual(response.status_code, 202) + self.assertEqual(mock_task.delay.call_args.kwargs["region_code"], "USA.46_1") diff --git a/ami/ml/models/algorithm.py b/ami/ml/models/algorithm.py index 0e9df4609..fe5ba6eb1 100644 --- a/ami/ml/models/algorithm.py +++ b/ami/ml/models/algorithm.py @@ -247,11 +247,24 @@ class Meta: ] def save(self, *args, **kwargs): + previous_category_map_id = None + if self.pk: + previous_category_map_id = ( + type(self).objects.filter(pk=self.pk).values_list("category_map_id", flat=True).first() + ) if not self.version_name: self.version_name = f"{self.version}" if not self.key: self.key = f"{slugify(self.name)}-{self.version}" super().save(*args, **kwargs) + if self.category_map_id and self.category_map_id != previous_category_map_id: + # A newly-linked category map means this algorithm's predictable taxa + # just changed (or became known for the first time); keep the persisted + # Taxon.covered_by_algorithms / has_model_coverage relationship in sync. + # See ami.main.services.taxon_coverage. 
+ from ami.main.services.taxon_coverage import refresh_algorithm_coverage + + refresh_algorithm_coverage(self) def category_count(self) -> int | None: """ diff --git a/ami/ml/post_processing/admin/class_masking_form.py b/ami/ml/post_processing/admin/class_masking_form.py index 6e2f5f884..f7babc3da 100644 --- a/ami/ml/post_processing/admin/class_masking_form.py +++ b/ami/ml/post_processing/admin/class_masking_form.py @@ -22,12 +22,27 @@ class ClassMaskingActionForm(BasePostProcessingActionForm): label="Source classifier", help_text="The classification algorithm whose terminal predictions will be re-scored.", ) + taxa_list_mode = forms.ChoiceField( + choices=( + ("explicit", "Use the selected taxa list"), + ("auto", "Resolve automatically from the occurrence's site (then its project)"), + ), + required=False, + initial="explicit", + label="Taxa list source", + help_text=( + "Automatic resolution uses the region-configured list on the scope's site, " + "falling back to the project's default; it is a no-op if neither is set." + ), + ) taxa_list_id = forms.ModelChoiceField( queryset=TaxaList.objects.all().order_by("name"), + required=False, label="Taxa list to keep", help_text=( "Classes whose taxon is not in this list are masked out; each " - "classification's softmax is renormalised over the classes that remain." + "classification's softmax is renormalised over the classes that remain. " + "Leave blank when using automatic resolution." 
), ) reweight = forms.BooleanField( @@ -68,9 +83,21 @@ def _algorithms_for_scope(scope_queryset): .order_by("name") ) + def clean(self): + cleaned = super().clean() + mode = cleaned.get("taxa_list_mode") or "explicit" + cleaned["taxa_list_mode"] = mode + if mode == "explicit" and not cleaned.get("taxa_list_id"): + self.add_error("taxa_list_id", "Select a taxa list, or switch the source to automatic resolution.") + return cleaned + def to_config(self) -> dict: - return { + mode = self.cleaned_data["taxa_list_mode"] + config = { "algorithm_id": self.cleaned_data["algorithm_id"].pk, - "taxa_list_id": self.cleaned_data["taxa_list_id"].pk, "reweight": self.cleaned_data["reweight"], + "taxa_list_mode": mode, } + if mode == "explicit": + config["taxa_list_id"] = self.cleaned_data["taxa_list_id"].pk + return config diff --git a/ami/ml/post_processing/class_masking.py b/ami/ml/post_processing/class_masking.py index ffc77f346..f8c982514 100644 --- a/ami/ml/post_processing/class_masking.py +++ b/ami/ml/post_processing/class_masking.py @@ -22,7 +22,14 @@ class ClassMaskingConfig(pydantic.BaseModel): source_image_collection_id: int | None = None occurrence_id: int | None = None # The taxa list to keep: classes whose taxon is not in this list are masked out. - taxa_list_id: int + # Required in "explicit" mode; omitted in "auto" mode, where the list is resolved + # from the scope's site (then its project's default) so a pipeline can apply + # masking without an operator choosing a list each run. + taxa_list_id: int | None = None + # "explicit": use taxa_list_id directly. "auto": resolve the list per the scope's + # site/project (see ClassMaskingTask._resolve_auto_taxa_list); when nothing is + # configured the task is a safe no-op, so auto masking can be enabled by default. + taxa_list_mode: str = "explicit" # The source classifier whose terminal classifications are re-scored. 
algorithm_id: int # When True (default), renormalise the kept classes' scores to sum to 1 after @@ -31,10 +38,18 @@ class ClassMaskingConfig(pydantic.BaseModel): reweight: bool = True @pydantic.root_validator(skip_on_failure=True) - def _exactly_one_scope(cls, values: dict) -> dict: + def _validate_scope_and_mode(cls, values: dict) -> dict: scopes = [values.get("source_image_collection_id"), values.get("occurrence_id")] if sum(s is not None for s in scopes) != 1: raise ValueError("Provide exactly one of source_image_collection_id or occurrence_id") + + mode = values.get("taxa_list_mode") + if mode not in ("explicit", "auto"): + raise ValueError("taxa_list_mode must be 'explicit' or 'auto'") + if mode == "explicit" and values.get("taxa_list_id") is None: + raise ValueError("taxa_list_id is required when taxa_list_mode is 'explicit'") + if mode == "auto" and values.get("taxa_list_id") is not None: + raise ValueError("taxa_list_id must be omitted when taxa_list_mode is 'auto'") return values class Config: @@ -262,6 +277,42 @@ def _get_or_create_masking_algorithm( algorithm.save(update_fields=["category_map"]) return algorithm + def _resolve_auto_taxa_list(self, config: ClassMaskingConfig) -> TaxaList | None: + """Resolve the taxa list for auto mode from the scope's configured region. + + The ladder mirrors the design in issue #1364: an occurrence prefers its + site's list, then falls back to its project's default; a collection resolves + at the project level (a collection is not tied to a single site). Returns None + when nothing is configured, which the caller treats as a no-op rather than an + error — that is what makes auto masking safe to enable before a project has + set up a region. The region_code → generate rungs of the ladder are handled by + the generation surfaces (command/admin/API), not inline here, so a masking run + never triggers a slow external fetch. 
+ """ + if config.occurrence_id is not None: + try: + occurrence = Occurrence.objects.select_related( + "deployment__research_site__taxa_list", "project__default_taxa_list" + ).get(pk=config.occurrence_id) + except Occurrence.DoesNotExist: + raise ValueError(f"Occurrence {config.occurrence_id} not found") + site = occurrence.deployment.research_site if occurrence.deployment else None + if site is not None and site.taxa_list_id: + return site.taxa_list + if occurrence.project is not None and occurrence.project.default_taxa_list_id: + return occurrence.project.default_taxa_list + return None + + try: + collection = SourceImageCollection.objects.select_related("project__default_taxa_list").get( + pk=config.source_image_collection_id + ) + except SourceImageCollection.DoesNotExist: + raise ValueError(f"SourceImageCollection {config.source_image_collection_id} not found") + if collection.project is not None and collection.project.default_taxa_list_id: + return collection.project.default_taxa_list + return None + def _scoped_classifications( self, config: ClassMaskingConfig, source_algorithm: Algorithm ) -> tuple[QuerySet[Classification], str]: @@ -302,10 +353,23 @@ def run(self) -> None: source_algorithm = Algorithm.objects.get(pk=config.algorithm_id) except Algorithm.DoesNotExist: raise ValueError(f"Algorithm {config.algorithm_id} not found") - try: - taxa_list = TaxaList.objects.get(pk=config.taxa_list_id) - except TaxaList.DoesNotExist: - raise ValueError(f"TaxaList {config.taxa_list_id} not found") + if config.taxa_list_mode == "auto": + taxa_list = self._resolve_auto_taxa_list(config) + if taxa_list is None: + self.logger.info( + "Class masking (auto): no taxa list is configured for this scope's site or project; " + "there is nothing to mask, so the task is a no-op." 
+ ) + self.report_stage_metrics( + {"classifications_checked": 0, "classifications_masked": 0, "occurrences_updated": 0} + ) + return + self.logger.info(f"Class masking (auto): resolved taxa list {taxa_list.pk} ({taxa_list.name})") + else: + try: + taxa_list = TaxaList.objects.get(pk=config.taxa_list_id) + except TaxaList.DoesNotExist: + raise ValueError(f"TaxaList {config.taxa_list_id} not found") if not source_algorithm.category_map: raise ValueError(f"Algorithm '{source_algorithm.name}' has no category map; cannot mask classes.") diff --git a/ami/ml/post_processing/tests/test_class_masking_auto.py b/ami/ml/post_processing/tests/test_class_masking_auto.py new file mode 100644 index 000000000..5f7d7f8c1 --- /dev/null +++ b/ami/ml/post_processing/tests/test_class_masking_auto.py @@ -0,0 +1,109 @@ +"""Auto-mode class masking (#1364, Phase 3). + +Class masking can resolve its taxa list automatically from the scope's configured +region instead of an operator picking one each run: an occurrence prefers its site's +list, then its project's default; a collection resolves at the project level. When +nothing is configured the run is a safe no-op, so a pipeline can enable masking before +a project has set up a region. These tests pin the config validation, the resolution +ladder, and the no-op path — the masking maths itself is covered in test_class_masking. 
+""" + +import pydantic +from django.test import TestCase + +from ami.main.models import Deployment, Occurrence, Project, Site, SourceImageCollection, TaxaList +from ami.ml.models.algorithm import Algorithm, AlgorithmTaskType +from ami.ml.post_processing.class_masking import ClassMaskingConfig, ClassMaskingTask + + +class ClassMaskingConfigValidationTest(TestCase): + def _config(self, **overrides): + values = dict(occurrence_id=1, algorithm_id=1) + values.update(overrides) + return ClassMaskingConfig(**values) + + def test_default_mode_is_explicit_and_requires_a_list(self): + with self.assertRaises(pydantic.ValidationError): + self._config() # explicit by default, no taxa_list_id + self.assertEqual(self._config(taxa_list_id=5).taxa_list_mode, "explicit") + + def test_explicit_with_list_is_valid(self): + config = self._config(taxa_list_mode="explicit", taxa_list_id=5) + self.assertEqual(config.taxa_list_id, 5) + + def test_auto_must_omit_the_list(self): + with self.assertRaises(pydantic.ValidationError): + self._config(taxa_list_mode="auto", taxa_list_id=5) + + def test_auto_without_a_list_is_valid(self): + config = self._config(taxa_list_mode="auto") + self.assertEqual(config.taxa_list_mode, "auto") + self.assertIsNone(config.taxa_list_id) + + def test_unknown_mode_is_rejected(self): + with self.assertRaises(pydantic.ValidationError): + self._config(taxa_list_mode="regional", taxa_list_id=5) + + +class ClassMaskingAutoResolutionTest(TestCase): + def setUp(self): + self.project = Project.objects.create(name="Auto Project", create_defaults=False) + self.algorithm = Algorithm.objects.create( + name="Source classifier", key="auto_src_clf", task_type=AlgorithmTaskType.CLASSIFICATION.value + ) + + def _task(self, **config): + values = dict(taxa_list_mode="auto", algorithm_id=self.algorithm.pk) + values.update(config) + return ClassMaskingTask(**values) + + def _occurrence(self, *, site=None): + deployment = Deployment.objects.create(name="D", project=self.project, 
research_site=site) + return Occurrence.objects.create(project=self.project, deployment=deployment) + + def test_occurrence_prefers_its_sites_list(self): + site_list = TaxaList.objects.create(name="Site list") + self.project.default_taxa_list = TaxaList.objects.create(name="Project list") + self.project.save() + site = Site.objects.create(name="S", project=self.project, taxa_list=site_list) + occurrence = self._occurrence(site=site) + + task = self._task(occurrence_id=occurrence.pk) + self.assertEqual(task._resolve_auto_taxa_list(task.config), site_list) + + def test_occurrence_falls_back_to_project_default(self): + project_list = TaxaList.objects.create(name="Project list") + self.project.default_taxa_list = project_list + self.project.save() + site = Site.objects.create(name="S", project=self.project) # no taxa_list on the site + occurrence = self._occurrence(site=site) + + task = self._task(occurrence_id=occurrence.pk) + self.assertEqual(task._resolve_auto_taxa_list(task.config), project_list) + + def test_occurrence_resolves_to_none_when_unconfigured(self): + occurrence = self._occurrence(site=None) + task = self._task(occurrence_id=occurrence.pk) + self.assertIsNone(task._resolve_auto_taxa_list(task.config)) + + def test_collection_uses_project_default(self): + project_list = TaxaList.objects.create(name="Project list") + self.project.default_taxa_list = project_list + self.project.save() + collection = SourceImageCollection.objects.create(name="C", project=self.project) + + task = self._task(source_image_collection_id=collection.pk) + self.assertEqual(task._resolve_auto_taxa_list(task.config), project_list) + + def test_collection_resolves_to_none_without_a_default(self): + collection = SourceImageCollection.objects.create(name="C", project=self.project) + task = self._task(source_image_collection_id=collection.pk) + self.assertIsNone(task._resolve_auto_taxa_list(task.config)) + + def test_run_is_a_noop_when_auto_resolves_to_nothing(self): + occurrence = 
self._occurrence(site=None) + self._task(occurrence_id=occurrence.pk).run() + # No masking algorithm is created when there is no list to apply. + self.assertFalse( + Algorithm.objects.filter(key__startswith=f"{self.algorithm.key}_filtered_by_taxa_list_").exists() + )