diff --git a/ami/jobs/migrations/0021_joblog.py b/ami/jobs/migrations/0021_joblog.py new file mode 100644 index 000000000..f88fcd4b3 --- /dev/null +++ b/ami/jobs/migrations/0021_joblog.py @@ -0,0 +1,32 @@ +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("jobs", "0020_schedule_job_monitoring_beat_tasks"), + ] + + operations = [ + migrations.CreateModel( + name="JobLog", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("level", models.CharField(max_length=20)), + ("message", models.TextField()), + ("context", models.JSONField(blank=True, default=dict)), + ( + "job", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, related_name="log_entries", to="jobs.job" + ), + ), + ], + options={ + "ordering": ["-created_at", "-pk"], + "indexes": [models.Index(fields=["job", "-created_at"], name="jobs_joblog_job_id_e4aa59_idx")], + }, + ), + ] diff --git a/ami/jobs/migrations/0022_alter_job_logs_help_text.py b/ami/jobs/migrations/0022_alter_job_logs_help_text.py new file mode 100644 index 000000000..7cb45405e --- /dev/null +++ b/ami/jobs/migrations/0022_alter_job_logs_help_text.py @@ -0,0 +1,22 @@ +import ami.jobs.models +import django_pydantic_field.fields +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("jobs", "0021_joblog"), + ] + + operations = [ + migrations.AlterField( + model_name="job", + name="logs", + field=django_pydantic_field.fields.PydanticSchemaField( + config=None, + default=ami.jobs.models.JobLogs, + help_text="DEPRECATED: read-only fallback for pre-#1259 jobs. Use the JobLog table for new writes.", + schema=ami.jobs.models.JobLogs, + ), + ), + ] diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 4111034f8..3669b65c4 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -322,13 +322,80 @@ class JobLogs(pydantic.BaseModel): stderr: list[str] = pydantic.Field(default_factory=list, alias="stderr", title="Error messages") +class JobLog(BaseModel): + """Append-only per-job log row. + + Replaces the ``jobs_job.logs`` JSON-field UPDATE path that caused row-lock + contention under concurrent async_api load (issue #1256). Each log emit + becomes a cheap INSERT on this child table instead of a refresh+UPDATE of + the shared parent row. Legacy JSON-field logs are still served by the + serializer for jobs created before this table existed. + """ + + project_accessor = "job__project" + + job = models.ForeignKey("Job", on_delete=models.CASCADE, related_name="log_entries") + level = models.CharField(max_length=20) + message = models.TextField() + # Freeform bag for future per-line metadata (stage, worker id, counters, ...) + # without requiring a schema migration. Kept nullable/empty-default so it + # costs nothing on existing rows. + context = models.JSONField(blank=True, default=dict) + + class Meta: + ordering = ["-created_at", "-pk"] + indexes = [models.Index(fields=["job", "-created_at"])] + + +JOB_LOG_LEVELS_STDERR = {"ERROR", "CRITICAL"} +JOB_LOG_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S" +JOB_LOGS_DEFAULT_LIMIT = 1000 +# Hard ceiling on a single read response. Keeps payload size bounded even when +# a caller passes ``?logs_limit=...``. Real pagination ships separately with a +# dedicated ``/jobs/logs/`` endpoint. +JOB_LOGS_MAX_LIMIT = 5000 + + +def _legacy_logs_shape(job: "Job") -> dict[str, list[str]]: + legacy = getattr(job, "logs", None) + return { + "stdout": list(getattr(legacy, "stdout", []) or []), + "stderr": list(getattr(legacy, "stderr", []) or []), + } + + +def serialize_job_logs(job: "Job", *, limit: int = JOB_LOGS_DEFAULT_LIMIT) -> dict[str, list[str]]: + """Return ``{stdout, stderr}`` in the shape the UI already parses. + + Reads joined ``JobLog`` rows first (newest-first, capped at ``limit`` per + request — there is no per-job storage cap; the data integrity check + framework handles retention). Jobs created before the table existed and + jobs written while ``JOB_LOG_PERSIST_ENABLED=False`` have no rows and fall + back to the legacy ``jobs_job.logs`` JSON column so their UI log panel + stays populated. + """ + entries = list( + JobLog.objects.filter(job_id=job.pk) + .only("created_at", "level", "message") + .order_by("-created_at", "-pk")[:limit] + ) + if entries: + return { + "stdout": [ + f"[{entry.created_at.strftime(JOB_LOG_TIMESTAMP_FORMAT)}] {entry.level} {entry.message}" + for entry in entries + ], + "stderr": [entry.message for entry in entries if entry.level in JOB_LOG_LEVELS_STDERR], + } + + return _legacy_logs_shape(job) + + class JobLogHandler(logging.Handler): """ Class for handling logs from a job and writing them to the job instance. """ - max_log_length = 1000 - def __init__(self, job: "Job", *args, **kwargs): self.job = job super().__init__(*args, **kwargs) @@ -337,41 +404,24 @@ def emit(self, record: logging.LogRecord): # Log to the current app logger (container stdout). logger.log(record.levelno, self.format(record)) - # Gated by ``JOB_LOG_PERSIST_ENABLED`` (default True). Persisting every - # log line to ``jobs_job.logs`` becomes a row-lock contention point - # under concurrent async_api load — each call triggers - # ``UPDATE jobs_job SET logs = ...`` on the shared job row, and inside - # ``ATOMIC_REQUESTS`` a single batched ``/result`` POST stacks N such - # UPDATEs in one tx, blocking every ML worker on the same row for the - # duration of the request. Deployments hitting that pattern can set the - # flag to False to short-circuit here until PR #1259 lands an - # append-only ``JobLog`` child table. See issue #1256. + # Escape hatch: when False, skip the per-job DB write entirely. Container + # stdout still captures every line above, so ops observability is + # unchanged; only the per-job UI log view loses new entries for the + # duration the flag is off. Default is True. See issue #1256. if not getattr(settings, "JOB_LOG_PERSIST_ENABLED", True): return - # Write to the logs field on the job instance. - # Refresh from DB first to reduce the window for concurrent overwrites — each - # worker holds its own stale in-memory copy of `logs`, so without a refresh the - # last writer always wins and earlier entries are silently dropped. - # @TODO consider saving logs to the database periodically rather than on every log + # Append-only insert on the JobLog child table. Unlike the legacy + # jobs_job.logs JSONB update path, this does not contend with + # _update_job_progress on the parent row. try: - self.job.refresh_from_db(fields=["logs"]) - timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - msg = f"[{timestamp}] {record.levelname} {self.format(record)}" - if msg not in self.job.logs.stdout: - self.job.logs.stdout.insert(0, msg) - - # Write a simpler copy of any errors to the errors field - if record.levelno >= logging.ERROR: - if record.message not in self.job.logs.stderr: - self.job.logs.stderr.insert(0, record.message) - - if len(self.job.logs.stdout) > self.max_log_length: - self.job.logs.stdout = self.job.logs.stdout[: self.max_log_length] - - self.job.save(update_fields=["logs"], update_progress=False) + JobLog.objects.create( + job_id=self.job.pk, + level=record.levelname, + message=self.format(record), + ) except Exception as e: - logger.error(f"Failed to save logs for job #{self.job.pk}: {e}") + logger.error(f"Failed to save log for job #{self.job.pk}: {e}") @dataclass @@ -853,7 +903,15 @@ class Job(BaseModel): # @TODO can we use an Enum or Pydantic model for status? status = models.CharField(max_length=255, default=JobState.CREATED.name, choices=JobState.choices()) progress: JobProgress = SchemaField(JobProgress, default=default_job_progress) - logs: JobLogs = SchemaField(JobLogs, default=JobLogs) + # DEPRECATED: per-line writes moved to the JobLog child table (issue #1256, PR #1259). + # Retained as a read-only fallback so jobs created before the migration still + # surface their stored logs in the UI. Will be dropped in a follow-up after + # the legacy rows are backfilled into JobLog. Do not write to this field. + logs: JobLogs = SchemaField( + JobLogs, + default=JobLogs, + help_text="DEPRECATED: read-only fallback for pre-#1259 jobs. Use the JobLog table for new writes.", + ) params = models.JSONField(null=True, blank=True) result = models.JSONField(null=True, blank=True) task_id = models.CharField(max_length=255, null=True, blank=True) diff --git a/ami/jobs/schemas.py b/ami/jobs/schemas.py index 74af39ce9..c4b37b92a 100644 --- a/ami/jobs/schemas.py +++ b/ami/jobs/schemas.py @@ -23,3 +23,14 @@ class QueuedTaskAcknowledgment(pydantic.BaseModel): required=False, type=bool, ) + +logs_limit_param = OpenApiParameter( + name="logs_limit", + description=( + "Max number of JobLog rows to include in the ``logs`` field on the detail response. " + "Newest-first. Defaults to 1000, capped at 5000. Pagination over older entries will " + "ship with a dedicated ``/jobs/logs/`` endpoint." + ), + required=False, + type=int, +) diff --git a/ami/jobs/serializers.py b/ami/jobs/serializers.py index fed4d58fa..f53199e73 100644 --- a/ami/jobs/serializers.py +++ b/ami/jobs/serializers.py @@ -1,4 +1,5 @@ from django_pydantic_field.rest_framework import SchemaField +from drf_spectacular.utils import extend_schema_field from rest_framework import serializers from ami.exports.models import DataExport @@ -13,7 +14,7 @@ from ami.ml.schemas import PipelineProcessingTask, PipelineTaskResult, ProcessingServiceClientInfo from ami.ml.serializers import PipelineNestedSerializer -from .models import Job, JobLogs, JobProgress, MLJob +from .models import JOB_LOGS_DEFAULT_LIMIT, Job, JobProgress, MLJob, _legacy_logs_shape, serialize_job_logs from .schemas import QueuedTaskAcknowledgment @@ -49,7 +50,7 @@ class JobListSerializer(DefaultSerializer): source_image_single = SourceImageNestedSerializer(read_only=True) data_export = DataExportNestedSerializer(read_only=True) progress = SchemaField(schema=JobProgress, read_only=True) - logs = SchemaField(schema=JobLogs, read_only=True) + logs = serializers.SerializerMethodField() job_type = JobTypeSerializer(read_only=True) # All jobs created from the Jobs UI are ML jobs (datasync, etc. are created for the user) # @TODO Remove this when the UI is updated pass a job type. This should be a required field. @@ -147,6 +148,30 @@ class Meta: "dispatch_mode", ] + @extend_schema_field( + { + "type": "object", + "properties": { + "stdout": {"type": "array", "items": {"type": "string"}, "title": "All messages"}, + "stderr": {"type": "array", "items": {"type": "string"}, "title": "Error messages"}, + }, + "required": ["stdout", "stderr"], + } + ) + def get_logs(self, obj: Job) -> dict[str, list[str]]: + # List responses skip the JobLog query to avoid N+1 — the UI only renders + # logs on the detail page, so returning the (typically empty for new jobs) + # legacy JSON shape is acceptable. Detail responses go to the joined table + # and fall back to the legacy shape for pre-migration jobs. + view = self.context.get("view") + if getattr(view, "action", None) == "list": + return _legacy_logs_shape(obj) + # ``JobViewSet.get_serializer_context`` validates ``?logs_limit=`` and + # puts the cleaned int (or ``None`` when unset) on context, so a bad + # value already 400'd before we got here. + limit = self.context.get("logs_limit") or JOB_LOGS_DEFAULT_LIMIT + return serialize_job_logs(obj, limit=limit) + class JobSerializer(JobListSerializer): # progress = serializers.JSONField(initial=Job.default_progress(), allow_null=False, required=False) diff --git a/ami/jobs/tests/test_jobs.py b/ami/jobs/tests/test_jobs.py index ee8d833cd..783180ed7 100644 --- a/ami/jobs/tests/test_jobs.py +++ b/ami/jobs/tests/test_jobs.py @@ -8,7 +8,15 @@ from rest_framework.test import APIRequestFactory, APITestCase from ami.base.serializers import reverse_with_params -from ami.jobs.models import Job, JobDispatchMode, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob +from ami.jobs.models import ( + Job, + JobDispatchMode, + JobLog, + JobProgress, + JobState, + MLJob, + SourceImageCollectionPopulateJob, +) from ami.main.models import Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline from ami.ml.models.processing_service import ProcessingService @@ -18,6 +26,10 @@ logger = logging.getLogger(__name__) +def joined_job_log_messages(job: Job) -> str: + return "\n".join(JobLog.objects.filter(job=job).order_by("-created_at", "-pk").values_list("message", flat=True)) + + class TestJobProgress(TestCase): def setUp(self): self.project = Project.objects.create(name="Test project") @@ -392,6 +404,17 @@ def test_list_jobs_with_ids_only(self): # Verify we don't get the full results structure self.assertNotIn("details", data["results"][0]) + def test_list_jobs_with_invalid_cutoff_hours_returns_400(self): + """``?cutoff_hours=abc`` must 400, not 500. Locks in the + ``SingleParamSerializer`` validation pattern in ``get_queryset``.""" + self.client.force_authenticate(user=self.user) + url = reverse_with_params( + "api:job-list", + params={"project_id": self.project.pk, "cutoff_hours": "abc"}, + ) + resp = self.client.get(url) + self.assertEqual(resp.status_code, 400) + def test_list_jobs_ids_only_pops_one(self): """`?ids_only=1` without an explicit limit returns one job (pop()-style handoff).""" self._create_job("Test job 2", start_now=False) @@ -645,7 +668,7 @@ def test_tasks_endpoint_logs_fetch_to_job_logger(self): self.assertEqual(resp.status_code, 200) job.refresh_from_db() - joined = "\n".join(job.logs.stdout) + joined = joined_job_log_messages(job) self.assertIn("Tasks fetched", joined) self.assertIn("requested=2", joined) self.assertIn("delivered=", joined) @@ -666,7 +689,7 @@ def test_tasks_endpoint_logs_early_exit_for_terminal_job(self): self.assertEqual(resp.json(), {"tasks": []}) job.refresh_from_db() - joined = "\n".join(job.logs.stdout) + joined = joined_job_log_messages(job) self.assertIn("non-active job", joined) self.assertIn(f"status={JobState.SUCCESS}", joined) @@ -705,7 +728,7 @@ def test_result_endpoint_mirrors_queued_log_to_job_logger(self): self.assertEqual(resp.status_code, 200) job.refresh_from_db() - joined = "\n".join(job.logs.stdout) + joined = joined_job_log_messages(job) self.assertIn("Queued pipeline result", joined) self.assertIn("mirrored-task-id", joined) self.assertIn("test.reply.logged", joined) @@ -742,7 +765,7 @@ def test_tasks_fetch_log_uses_token_fingerprint_not_full_token(self): self.assertEqual(resp.status_code, 200) job.refresh_from_db() - joined = "\n".join(job.logs.stdout) + joined = joined_job_log_messages(job) # Full token key must NOT appear anywhere in logs self.assertNotIn(token.key, joined) # Fingerprint (first 8 chars + ellipsis) MUST appear @@ -769,7 +792,7 @@ def test_tasks_fetch_zero_delivered_does_not_log_to_stdout(self): job.refresh_from_db() # No Tasks fetched line should appear in stdout for a zero-delivery poll - joined = "\n".join(job.logs.stdout) + joined = joined_job_log_messages(job) self.assertNotIn("Tasks fetched", joined) def test_tasks_fetch_nonzero_delivered_logs_to_stdout(self): @@ -799,7 +822,7 @@ def test_tasks_fetch_nonzero_delivered_logs_to_stdout(self): self.assertEqual(len(resp.json()["tasks"]), 3) job.refresh_from_db() - joined = "\n".join(job.logs.stdout) + joined = joined_job_log_messages(job) self.assertIn("Tasks fetched", joined) self.assertIn("delivered=3", joined) @@ -842,7 +865,7 @@ def test_throughput_line_is_well_formed(self): _log_job_throughput(self.job, "process") self.job.refresh_from_db() - joined = "\n".join(self.job.logs.stdout) + joined = joined_job_log_messages(self.job) self.assertIn("throughput", joined) self.assertIn("processed=10/100", joined) self.assertIn("rate=2.0 imgs/min", joined) @@ -858,7 +881,7 @@ def test_throughput_skipped_when_started_at_is_none(self): _log_job_throughput(self.job, "process") self.job.refresh_from_db() - joined = "\n".join(self.job.logs.stdout) + joined = joined_job_log_messages(self.job) self.assertNotIn("throughput", joined) def test_throughput_skipped_for_non_processing_stage(self): @@ -873,7 +896,7 @@ def test_throughput_skipped_for_non_processing_stage(self): _log_job_throughput(self.job, "delay") self.job.refresh_from_db() - joined = "\n".join(self.job.logs.stdout) + joined = joined_job_log_messages(self.job) self.assertNotIn("throughput", joined) def test_throughput_with_zero_processed_reports_unknown_eta(self): @@ -888,12 +911,214 @@ def test_throughput_with_zero_processed_reports_unknown_eta(self): _log_job_throughput(self.job, "process") self.job.refresh_from_db() - joined = "\n".join(self.job.logs.stdout) + joined = joined_job_log_messages(self.job) self.assertIn("processed=0/50", joined) self.assertIn("rate=0.0", joined) self.assertIn("ETA=unknown", joined) +class TestJobLogPersistence(TestCase): + """Exercise the JobLog table / legacy-JSON fallback paths on JobLogHandler.emit.""" + + def setUp(self): + self.project = Project.objects.create(name="JobLog Test Project") + self.pipeline = Pipeline.objects.create(name="JobLog Pipeline", slug="joblog-pipeline") + self.pipeline.projects.add(self.project) + self.job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="JobLog emit test job", + pipeline=self.pipeline, + ) + + def test_emit_inserts_one_joblog_row_per_call(self): + self.job.logger.info("first") + self.job.logger.error("boom") + + rows = list(JobLog.objects.filter(job=self.job).order_by("pk").values("level", "message")) + self.assertEqual(len(rows), 2) + self.assertEqual(rows[0]["level"], "INFO") + self.assertIn("first", rows[0]["message"]) + self.assertEqual(rows[1]["level"], "ERROR") + self.assertIn("boom", rows[1]["message"]) + + # emit must not repopulate the legacy JSON column. + self.job.refresh_from_db(fields=["logs"]) + self.assertEqual(self.job.logs.stdout, []) + self.assertEqual(self.job.logs.stderr, []) + + def test_flag_disabled_short_circuits_emit(self): + from django.test import override_settings + + with override_settings(JOB_LOG_PERSIST_ENABLED=False): + self.job.logger.info("suppressed") + self.job.logger.error("also suppressed") + + self.assertFalse(JobLog.objects.filter(job=self.job).exists()) + self.job.refresh_from_db(fields=["logs"]) + self.assertEqual(self.job.logs.stdout, []) + self.assertEqual(self.job.logs.stderr, []) + + def test_serialize_job_logs_reads_from_joblog_table(self): + from ami.jobs.serializers import serialize_job_logs + + self.job.logger.info("hello world") + self.job.logger.error("something failed") + + logs = serialize_job_logs(self.job) + + self.assertEqual(len(logs["stdout"]), 2) + # Newest-first ordering. + self.assertIn("ERROR", logs["stdout"][0]) + self.assertIn("something failed", logs["stdout"][0]) + self.assertIn("INFO", logs["stdout"][1]) + self.assertIn("hello world", logs["stdout"][1]) + self.assertEqual(logs["stderr"], ["something failed"]) + + def test_serialize_job_logs_falls_back_to_legacy_json(self): + """A job with no JobLog rows but a populated ``logs`` JSON column (a + pre-migration job, or a job written under ``JOB_LOG_PERSIST_ENABLED=False`` + after legacy data had been seeded) still renders through the serializer.""" + from ami.jobs.models import JobLogs as JobLogsSchema + from ami.jobs.serializers import serialize_job_logs + + self.job.logs = JobLogsSchema(stdout=["[2025-01-01 00:00:00] INFO legacy line"], stderr=["old error"]) + self.job.save(update_fields=["logs"]) + self.assertFalse(JobLog.objects.filter(job=self.job).exists()) + + logs = serialize_job_logs(self.job) + + self.assertEqual(logs["stdout"], ["[2025-01-01 00:00:00] INFO legacy line"]) + self.assertEqual(logs["stderr"], ["old error"]) + + def test_get_logs_list_action_skips_joblog_query(self): + """The ``get_logs`` method on JobListSerializer returns the legacy JSON + shape when the viewset action is ``list``. This avoids N+1 on joined + log rows and matches UI expectations (the list view does not render logs).""" + from unittest.mock import MagicMock + + from ami.jobs.models import JobLogs as JobLogsSchema + from ami.jobs.serializers import JobListSerializer + + self.job.logger.info("ignored in list view") + self.assertEqual(JobLog.objects.filter(job=self.job).count(), 1) + + self.job.logs = JobLogsSchema(stdout=["legacy-only"], stderr=[]) + self.job.save(update_fields=["logs"]) + + # Directly instantiate the serializer with a fake view context claiming + # the list action; confirms list responses do not hit JobLog rows. + fake_view = MagicMock() + fake_view.action = "list" + serializer = JobListSerializer(instance=self.job, context={"view": fake_view}) + logs = serializer.get_logs(self.job) + + self.assertEqual(logs["stdout"], ["legacy-only"]) + self.assertEqual(logs["stderr"], []) + + def test_get_logs_detail_action_reads_joblog_table(self): + from unittest.mock import MagicMock + + from ami.jobs.serializers import JobListSerializer + + self.job.logger.info("detail view reads me") + + fake_view = MagicMock() + fake_view.action = "retrieve" + serializer = JobListSerializer(instance=self.job, context={"view": fake_view}) + logs = serializer.get_logs(self.job) + + self.assertEqual(len(logs["stdout"]), 1) + self.assertIn("detail view reads me", logs["stdout"][0]) + + def _make_detail_serializer(self, logs_limit: int | None = None): + # Mirror what JobViewSet.get_serializer_context produces for a + # detail (retrieve) action: ``logs_limit`` is the validated int (or + # None when the param was not passed). + from unittest.mock import MagicMock + + from ami.jobs.serializers import JobListSerializer + + fake_view = MagicMock() + fake_view.action = "retrieve" + return JobListSerializer( + instance=self.job, + context={"view": fake_view, "logs_limit": logs_limit}, + ) + + def test_logs_limit_caps_response_size(self): + for i in range(5): + self.job.logger.info(f"line {i}") + self.assertEqual(JobLog.objects.filter(job=self.job).count(), 5) + + serializer = self._make_detail_serializer(logs_limit=2) + logs = serializer.get_logs(self.job) + + self.assertEqual(len(logs["stdout"]), 2) + # Newest-first. + self.assertIn("line 4", logs["stdout"][0]) + self.assertIn("line 3", logs["stdout"][1]) + + def test_logs_limit_default_when_unset(self): + from ami.jobs.models import JOB_LOGS_DEFAULT_LIMIT + + self.job.logger.info("only one") + + serializer = self._make_detail_serializer(logs_limit=None) + logs = serializer.get_logs(self.job) + + # Default kicks in (no truncation; 1 < 1000). + self.assertEqual(len(logs["stdout"]), 1) + self.assertGreaterEqual(JOB_LOGS_DEFAULT_LIMIT, 1) + + +class TestJobLogsLimitHTTPValidation(APITestCase): + """``?logs_limit=`` validation runs at the view boundary, so a bad value + must produce HTTP 400 (not 500). Validated via the actual API path rather + than calling the serializer directly.""" + + def setUp(self): + self.project = Project.objects.create(name="logs_limit HTTP test") + self.job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="logs_limit HTTP test job", + ) + self.user = User.objects.create_user( # type: ignore + email="logs-limit-validator@insectai.org", + is_staff=True, + is_active=True, + is_superuser=True, + ) + self.client.force_authenticate(user=self.user) + + def _detail_url(self, **params) -> str: + return reverse_with_params("api:job-detail", args=[self.job.pk], params=params) + + def test_valid_integer_returns_200(self): + # Sanity: a well-formed ``?logs_limit=`` does not 400 on its own. + resp = self.client.get(self._detail_url(project_id=self.project.pk, logs_limit=5)) + self.assertEqual(resp.status_code, 200) + + def test_non_integer_returns_400(self): + resp = self.client.get(self._detail_url(project_id=self.project.pk, logs_limit="abc")) + self.assertEqual(resp.status_code, 400) + + def test_zero_returns_400(self): + resp = self.client.get(self._detail_url(project_id=self.project.pk, logs_limit=0)) + self.assertEqual(resp.status_code, 400) + + def test_negative_returns_400(self): + resp = self.client.get(self._detail_url(project_id=self.project.pk, logs_limit=-5)) + self.assertEqual(resp.status_code, 400) + + def test_above_max_returns_400(self): + from ami.jobs.models import JOB_LOGS_MAX_LIMIT + + resp = self.client.get(self._detail_url(project_id=self.project.pk, logs_limit=JOB_LOGS_MAX_LIMIT + 1)) + self.assertEqual(resp.status_code, 400) + + class TestJobDispatchModeFiltering(APITestCase): """Test job filtering by dispatch_mode.""" diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 44516ca14..35f91ef96 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -7,10 +7,10 @@ from django.core.cache import cache from django.db.models import Q from django.db.models.query import QuerySet -from django.forms import IntegerField from django.utils import timezone from django_filters import rest_framework as filters -from drf_spectacular.utils import extend_schema +from drf_spectacular.utils import extend_schema, extend_schema_view +from rest_framework import serializers from rest_framework.decorators import action from rest_framework.exceptions import PermissionDenied, ValidationError from rest_framework.filters import BaseFilterBackend @@ -18,8 +18,10 @@ from ami.base.pagination import LimitOffsetPaginationWithPermissions from ami.base.permissions import ObjectPermission +from ami.base.serializers import SingleParamSerializer from ami.base.views import ProjectMixin -from ami.jobs.schemas import ids_only_param, incomplete_only_param +from ami.jobs.models import JOB_LOGS_MAX_LIMIT +from ami.jobs.schemas import ids_only_param, incomplete_only_param, logs_limit_param from ami.jobs.serializers import ( MLJobResultsRequestSerializer, MLJobResultsResponseSerializer, @@ -178,6 +180,9 @@ def filter_queryset(self, request, queryset, view): return queryset +@extend_schema_view( + retrieve=extend_schema(parameters=[logs_limit_param]), +) class JobViewSet(DefaultViewSet, ProjectMixin): """ API endpoint that allows jobs to be viewed or edited. @@ -234,6 +239,21 @@ def get_serializer_class(self): else: return JobSerializer + def get_serializer_context(self): + # Validate ``?logs_limit=`` once at the view boundary so a bad value + # raises DRF ``ValidationError`` (→ HTTP 400) before the serializer + # runs. ``get_logs`` then reads the cleaned value off context. Pattern + # mirrors ``ami.base.views.get_active_project`` and the other + # ``SingleParamSerializer`` callers in ``ami/main/api/views.py``. + context = super().get_serializer_context() + if self.action == "retrieve" and self.request is not None: + context["logs_limit"] = SingleParamSerializer[int].clean( + param_name="logs_limit", + field=serializers.IntegerField(required=False, min_value=1, max_value=JOB_LOGS_MAX_LIMIT), + data=self.request.query_params, + ) + return context + @action(detail=True, methods=["post"], name="run") def run(self, request, pk=None): """ @@ -298,8 +318,13 @@ def get_queryset(self) -> QuerySet: project = self.get_active_project() if project: jobs = jobs.filter(project=project) - cutoff_hours = IntegerField(required=False, min_value=0).clean( - self.request.query_params.get("cutoff_hours", Job.FAILED_JOBS_DISPLAY_MAX_HOURS) + # Validate via SingleParamSerializer so a bad value 400s instead of + # 500ing through django.forms.IntegerField (raises django.core + # ValidationError, which DRF's default handler does NOT convert). + cutoff_hours = SingleParamSerializer[int].clean( + param_name="cutoff_hours", + field=serializers.IntegerField(required=False, min_value=0, default=Job.FAILED_JOBS_DISPLAY_MAX_HOURS), + data=self.request.query_params, ) # Filter out completed jobs that have not been updated in the last X hours cutoff_datetime = timezone.now() - timezone.timedelta(hours=cutoff_hours) diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 18549d8cb..0eaa0a237 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -279,9 +279,9 @@ def test_repeated_registration_updates_last_seen(self): class TestPipelineWithProcessingService(TestCase): def test_run_pipeline_with_errors_from_processing_service(self): """ - Run a real pipeline and verify that if an error occurs for one image, the error is logged in job.logs.stderr. + Run a real pipeline and verify that if an error occurs for one image, the error is logged to JobLog. """ - from ami.jobs.models import Job + from ami.jobs.models import Job, JobLog # Setup test project, images, and job project, deployment = setup_test_project() @@ -305,11 +305,13 @@ def test_run_pipeline_with_errors_from_processing_service(self): pass # Expected if the backend raises job.refresh_from_db() - stderr_logs = job.logs.stderr + stderr_logs = list( + JobLog.objects.filter(job=job, level__in=["ERROR", "CRITICAL"]).values_list("message", flat=True) + ) # Check that an error message mentioning the failed image is present assert any( "Failed to process" in log for log in stderr_logs - ), f"Expected error message in job.logs.stderr, got: {stderr_logs}" + ), f"Expected error message in job logs, got: {stderr_logs}" def setUp(self): self.project, self.deployment = setup_test_project()