Skip to content
Open
23 changes: 23 additions & 0 deletions ami/main/admin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
from typing import Any

from django.contrib import admin
Expand Down Expand Up @@ -668,10 +669,32 @@ def run_small_size_filter(self, request: HttpRequest, queryset: QuerySet[SourceI

self.message_user(request, f"Queued Small Size Filter for {queryset.count()} capture set(s). Jobs: {jobs}")

@admin.action(description="Run Occurrence Tracking post-processing task (async)")
def run_tracking(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
from ami.ml.post_processing.tracking_task import DEFAULT_TRACKING_PARAMS

jobs = []
for collection in queryset:
job = Job.objects.create(
name=f"Post-processing: Tracking on Capture Set {collection.pk}",
project=collection.project,
source_image_collection=collection,
job_type_key="post_processing",
params={
"task": "tracking",
"config": dataclasses.asdict(DEFAULT_TRACKING_PARAMS),
},
)
job.enqueue()
jobs.append(job.pk)

self.message_user(request, f"Queued Tracking for {queryset.count()} capture set(s). Jobs: {jobs}")

actions = [
populate_collection,
populate_collection_async,
run_small_size_filter,
run_tracking,
]

# Hide images many-to-many field from form. This would list all source images in the database.
Expand Down
14 changes: 14 additions & 0 deletions ami/main/migrations/0084_add_pgvector_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from django.db import migrations


class Migration(migrations.Migration):
dependencies = [
("main", "0083_dedupe_taxalist_names"),
]

operations = [
migrations.RunSQL(
sql="CREATE EXTENSION IF NOT EXISTS vector;",
reverse_sql="DROP EXTENSION IF EXISTS vector;",
Comment thread
mihow marked this conversation as resolved.
Outdated
),
]
20 changes: 20 additions & 0 deletions ami/main/migrations/0085_classification_features_2048.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from django.db import migrations
import pgvector.django.vector


class Migration(migrations.Migration):
dependencies = [
("main", "0084_add_pgvector_extension"),
]

operations = [
migrations.AddField(
model_name="classification",
name="features_2048",
field=pgvector.django.vector.VectorField(
dimensions=2048,
null=True,
help_text="Feature embedding from the model backbone",
),
),
]
23 changes: 23 additions & 0 deletions ami/main/migrations/0086_detection_next_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("main", "0085_classification_features_2048"),
]

operations = [
migrations.AddField(
model_name="detection",
name="next_detection",
field=models.OneToOneField(
blank=True,
help_text="The detection that follows this one in the tracking sequence.",
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="previous_detection",
to="main.detection",
),
),
]
15 changes: 15 additions & 0 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from io import BytesIO
from typing import Final, final # noqa: F401

import pgvector.django
import PIL.Image
import pydantic
from django.apps import apps
Expand Down Expand Up @@ -2590,6 +2591,11 @@ class Classification(BaseModel):
null=True,
help_text="The probabilities the model, calibrated by the model maker, likely the softmax output",
)
features_2048 = pgvector.django.VectorField(
dimensions=2048,
null=True,
help_text="Feature embedding from the model backbone",
)
category_map = models.ForeignKey("ml.AlgorithmCategoryMap", on_delete=models.PROTECT, null=True)

algorithm = models.ForeignKey(
Expand Down Expand Up @@ -2784,6 +2790,15 @@ class Detection(BaseModel):

similarity_vector = models.JSONField(null=True, blank=True)

next_detection = models.OneToOneField(
"self",
on_delete=models.SET_NULL,
null=True,
blank=True,
related_name="previous_detection",
help_text="The detection that follows this one in the tracking sequence.",
)

# For type hints
classifications: models.QuerySet["Classification"]
source_image_id: int
Expand Down
6 changes: 5 additions & 1 deletion ami/ml/models/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def create_classification(

if existing_classification:
# @TODO remove this after all existing classifications have been updated (added 2024-12-20)
NEW_FIELDS = ["logits", "scores", "terminal", "category_map"]
NEW_FIELDS = ["logits", "scores", "terminal", "category_map", "features_2048"]
logger.debug(
"Duplicate classification found: "
f"{existing_classification.taxon} from {existing_classification.algorithm}, "
Expand All @@ -705,6 +705,9 @@ def create_classification(
if field == "category_map":
# Use the foreign key from the classification algorithm
setattr(existing_classification, field, classification_algo.category_map)
elif field == "features_2048":
# The pipeline response carries this as `features`; the DB column is `features_2048`.
setattr(existing_classification, field, classification_resp.features)
else:
# Get the value from the classification response
setattr(existing_classification, field, getattr(classification_resp, field))
Expand All @@ -722,6 +725,7 @@ def create_classification(
timestamp=classification_resp.timestamp or now(),
logits=classification_resp.logits,
scores=classification_resp.scores,
features_2048=classification_resp.features,
terminal=classification_resp.terminal,
Comment thread
mihow marked this conversation as resolved.
category_map=classification_algo.category_map,
)
Expand Down
2 changes: 2 additions & 0 deletions ami/ml/post_processing/registry.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Registry of available post-processing tasks
from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask
from ami.ml.post_processing.tracking_task import TrackingTask

POSTPROCESSING_TASKS = {
SmallSizeFilterTask.key: SmallSizeFilterTask,
TrackingTask.key: TrackingTask,
}


Expand Down
Empty file.
99 changes: 99 additions & 0 deletions ami/ml/post_processing/tests/test_tracking_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import logging
from collections import defaultdict

import numpy as np
from django.test import TestCase
from django.utils import timezone

from ami.main.models import Classification, Detection, Occurrence
from ami.ml.models import Algorithm
from ami.ml.post_processing.tracking_task import DEFAULT_TRACKING_PARAMS, assign_occurrences_by_tracking_images
from ami.tests.fixtures.main import create_captures, create_occurrences, create_taxa, setup_test_project

logger = logging.getLogger(__name__)


class TestTracking(TestCase):
def setUp(self) -> None:
self.project, self.deployment = setup_test_project(reuse=False)
# 1 night, 5 captures spaced 1 minute apart so they group into one event.
create_captures(deployment=self.deployment, num_nights=1, images_per_night=5, interval_minutes=1)
create_taxa(self.project)
create_occurrences(deployment=self.deployment, num=6)

self.event = self.project.events.first()
assert self.event is not None
self.source_images = list(self.event.captures.order_by("timestamp"))

# Source images need dimensions for the cost function.
for img in self.source_images:
if not img.width or not img.height:
img.width = 4096
img.height = 2160
img.save(update_fields=["width", "height"])

self.algorithm = self._assign_mock_features_to_occurrence_detections(self.event)

# Capture ground-truth groupings so we can compare after re-tracking.
self.ground_truth_groups = defaultdict(set)
for occ in Occurrence.objects.filter(event=self.event):
for det_id in Detection.objects.filter(occurrence=occ).values_list("id", flat=True):
self.ground_truth_groups[occ.pk].add(det_id)

Detection.objects.filter(source_image__event=self.event).update(next_detection=None)

def _assign_mock_features_to_occurrence_detections(
self, event, algorithm_name: str = "MockTrackingAlgorithm"
) -> Algorithm:
algorithm, _ = Algorithm.objects.get_or_create(name=algorithm_name, key="mock-tracking-algo")
rng = np.random.default_rng(seed=42)

for occurrence in event.occurrences.all():
base_vector = rng.random(2048)
for det in occurrence.detections.all():
noisy = base_vector + rng.normal(0, 0.001, size=2048)
Classification.objects.update_or_create(
detection=det,
algorithm=algorithm,
defaults={
"timestamp": timezone.now(),
"features_2048": noisy.tolist(),
"terminal": True,
"score": 1.0,
},
)
return algorithm

def test_tracking_reproduces_occurrence_groups(self):
# Wipe existing chain links and occurrences so tracking has to rebuild them.
for det in Detection.objects.filter(source_image__event=self.event):
det.occurrence = None
det.next_detection = None
det.save()
Occurrence.objects.filter(event=self.event).delete()

assign_occurrences_by_tracking_images(
event=self.event,
logger=logger,
algorithm=self.algorithm,
params=DEFAULT_TRACKING_PARAMS,
)

new_groups = {
occ.pk: set(Detection.objects.filter(occurrence=occ).values_list("id", flat=True))
for occ in Occurrence.objects.filter(event=self.event)
}

self.assertEqual(
len(new_groups),
len(self.ground_truth_groups),
f"Expected {len(self.ground_truth_groups)} groups, got {len(new_groups)}",
)

gt_values = list(self.ground_truth_groups.values())
for new_set in new_groups.values():
self.assertIn(
new_set,
gt_values,
f"Reconstructed group {new_set} does not match any ground-truth group",
)
Loading
Loading