diff --git a/backend/src/api/routes/tasks.py b/backend/src/api/routes/tasks.py
index 42be7abc..0dbe523e 100644
--- a/backend/src/api/routes/tasks.py
+++ b/backend/src/api/routes/tasks.py
@@ -658,7 +658,12 @@ async def split_clip(
 async def merge_clips(
     task_id: str, request: Request, db: AsyncSession = Depends(get_db)
 ):
-    """Merge multiple clips into one clip."""
+    """Merge clips synchronously. Kept for back-compat — prefer /merge_async.
+
+    The ffmpeg concat-encode regularly exceeds the ALB idle timeout for
+    multi-clip composites and surfaces as a 504 here. New callers should
+    use the async variant.
+    """
     try:
         payload = await request.json()
         clip_ids = payload.get("clip_ids") or []
@@ -678,6 +683,159 @@ async def merge_clips(
         raise HTTPException(status_code=500, detail=f"Error merging clips: {str(e)}")
 
 
+@router.post("/{task_id}/clips/merge_async", status_code=202)
+async def merge_clips_async(
+    task_id: str, request: Request, db: AsyncSession = Depends(get_db)
+):
+    """Enqueue a merge job and return immediately.
+
+    Poll status via GET /tasks/{task_id}/clips/merge_jobs/{merge_job_id}.
+    Validation (body shape, ownership, clip existence) runs synchronously
+    so bad requests fail fast instead of burning a worker slot.
+    """
+    try:
+        try:
+            payload = await request.json()
+        except json.JSONDecodeError as exc:
+            # Without this, the bare `except Exception` below converts
+            # client malformed-body errors into 500s.
+            raise HTTPException(status_code=400, detail="Malformed JSON body") from exc
+        if not isinstance(payload, dict):
+            # JSON arrays / scalars are syntactically valid but don't
+            # have .get() — without this guard payload.get below
+            # AttributeErrors into the 500 fallback.
+            raise HTTPException(status_code=400, detail="JSON body must be an object")
+        clip_ids = payload.get("clip_ids") or []
+        if not isinstance(clip_ids, list):
+            raise HTTPException(status_code=400, detail="clip_ids must be an array")
+        if not all(isinstance(clip_id, str) and clip_id for clip_id in clip_ids):
+            # The worker signature is list[str]; reject other JSON types
+            # here instead of letting them reach the repository lookup or
+            # the queue payload.
+            raise HTTPException(
+                status_code=400, detail="clip_ids must contain non-empty strings"
+            )
+        if len(clip_ids) < 2:
+            raise HTTPException(
+                status_code=400, detail="At least two clips are required to merge"
+            )
+
+        task_service = TaskService(db)
+        await _require_task_owner(request, task_service, db, task_id)
+
+        for clip_id in clip_ids:
+            clip = await task_service.clip_repo.get_clip_by_id(db, clip_id)
+            if not clip or clip["task_id"] != task_id:
+                raise HTTPException(
+                    status_code=404, detail=f"Clip {clip_id} not found on task"
+                )
+
+        merge_job_id = await JobQueue.enqueue_job(
+            "merge_clips_job", task_id, clip_ids
+        )
+        logger.info(
+            f"Enqueued merge job {merge_job_id} task={task_id} clips={len(clip_ids)}"
+        )
+        return {"merge_job_id": merge_job_id, "status": "queued"}
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error enqueueing merge: {e}")
+        raise HTTPException(
+            status_code=500, detail=f"Error enqueueing merge: {str(e)}"
+        )
+
+
+@router.get("/{task_id}/clips/merge_jobs/{merge_job_id}")
+async def get_merge_job(
+    task_id: str,
+    merge_job_id: str,
+    request: Request,
+    db: AsyncSession = Depends(get_db),
+):
+    """Poll a queued merge.
+
+    Status values mirror arq's JobStatus enum: `deferred | queued |
+    in_progress | complete`. Missing jobs (unknown id, or job whose
+    Redis state has expired) are returned as **HTTP 404**, not a
+    `not_found` status — clients should treat 404 as the
+    job-doesn't-exist signal.
+
+    On `complete` the response carries either `clip_id` + `message`
+    (success) or `error` (worker exception, surfaced as the str() of
+    the raised exception).
+    """
+    try:
+        task_service = TaskService(db)
+        await _require_task_owner(request, task_service, db, task_id)
+
+        # Bind the job to the path task before exposing its status/result.
+        # Owning *any* task isn't enough — without this check, a caller who
+        # guessed a merge_job_id could probe a job that belongs to a
+        # different task (and learn its merged clip_id). We verify via
+        # arq's stored JobDef which carries the original args we passed
+        # at enqueue (task_id, clip_ids) — no extra persistence layer
+        # needed.
+        info = await JobQueue.get_job_info(merge_job_id)
+        if info is None:
+            raise HTTPException(
+                status_code=404, detail=f"Merge job {merge_job_id} not found"
+            )
+        if info.function != "merge_clips_job":
+            # Wrong fn => not a merge job at all; treat as not-found rather
+            # than leak which functions exist.
+            raise HTTPException(
+                status_code=404, detail=f"Merge job {merge_job_id} not found"
+            )
+        # args == (task_id, clip_ids) per merge_clips_job signature.
+        # Mismatch means the caller is asking about a job that belongs
+        # to a different task — pretend it doesn't exist.
+        if not info.args or info.args[0] != task_id:
+            raise HTTPException(
+                status_code=404, detail=f"Merge job {merge_job_id} not found"
+            )
+
+        # get_job_status normalises arq's JobStatus enum to a lowercase
+        # string (and returns None for not_found), so we can consume the
+        # value directly.
+        status_str = await JobQueue.get_job_status(merge_job_id)
+        if status_str is None:
+            # Race: arq evicted the job's status entry between our info()
+            # call and now. Treat as not-found.
+            raise HTTPException(
+                status_code=404, detail=f"Merge job {merge_job_id} not found"
+            )
+
+        response: Dict[str, Any] = {
+            "merge_job_id": merge_job_id,
+            "status": status_str,
+        }
+
+        if status_str == "complete":
+            try:
+                result = await JobQueue.get_job_result(merge_job_id)
+                if isinstance(result, dict):
+                    response["clip_id"] = result.get("clip_id")
+                    response["message"] = result.get("message")
+                else:
+                    response["error"] = (
+                        f"Unexpected worker result type: {type(result).__name__}"
+                    )
+            except Exception as exc:
+                # arq raises the original worker exception when the job
+                # ended in failure; expose its string form to the caller.
+                response["error"] = str(exc)
+
+        return response
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error fetching merge job status: {e}")
+        raise HTTPException(
+            status_code=500, detail=f"Error fetching merge job status: {str(e)}"
+        )
+
+
 @router.patch("/{task_id}/clips/{clip_id}/captions")
 async def update_clip_captions(
     task_id: str, clip_id: str, request: Request, db: AsyncSession = Depends(get_db)
diff --git a/backend/src/workers/job_queue.py b/backend/src/workers/job_queue.py
index 636e3edb..421c95ee 100644
--- a/backend/src/workers/job_queue.py
+++ b/backend/src/workers/job_queue.py
@@ -6,6 +6,7 @@ from typing import Optional
 
 from arq import create_pool
 from arq.connections import RedisSettings, ArqRedis
+from arq.jobs import Job
 from ..config import get_config
 
 logger = logging.getLogger(__name__)
@@ -82,20 +83,63 @@ async def enqueue_processing_job(
             function_name, *args, _queue_name=queue_name, **kwargs
         )
 
+    @classmethod
+    def _job(cls, pool: ArqRedis, job_id: str) -> Job:
+        """Construct an arq Job handle for a given id.
+
+        ArqRedis itself has no `.job()` method (despite the obvious
+        name); the public API for looking up an existing job is the
+        Job(job_id=..., redis=pool) constructor. The Job handle is
+        cheap to create — it's just a pair of references — and reading
+        from Redis happens lazily on info()/status()/result().
+
+        NOTE(review): Job() falls back to arq's default queue name, and
+        status() looks the id up in that queue to tell queued/deferred
+        from not_found. If jobs are enqueued with a custom _queue_name
+        (as enqueue_processing_job does), the same name presumably has
+        to be passed here — TODO confirm against enqueue_job.
+        """
+        return Job(job_id=job_id, redis=pool)
+
     @classmethod
     async def get_job_result(cls, job_id: str):
-        """Get the result of a completed job."""
+        """Return the worker function's return value, or re-raise its exception."""
         pool = await cls.get_pool()
-        job = await pool.job(job_id)
-        if job:
-            return await job.result()
-        return None
+        return await cls._job(pool, job_id).result()
 
     @classmethod
     async def get_job_status(cls, job_id: str) -> Optional[str]:
-        """Get the status of a job."""
+        """Return arq's JobStatus as a lowercase string, or None if unknown.
+
+        Normalising the enum -> str at the JobQueue boundary lets route
+        handlers consume the value directly without importing arq
+        internals (and prevents the easy bug of returning the enum object
+        through an Optional[str] signature).
+        """
+        pool = await cls.get_pool()
+        status = await cls._job(pool, job_id).status()
+        if status is None:
+            return None
+        # Prefer the enum's .value over parsing str(status): the rendered
+        # form depends on Enum.__str__. The suffix split stays as a fallback.
+        status_str = str(getattr(status, "value", status)).split(".")[-1].lower()
+        # JobStatus.not_found is how arq signals a missing job — surface
+        # that as None at this boundary too.
+        if status_str == "not_found":
+            return None
+        return status_str
+
+    @classmethod
+    async def get_job_info(cls, job_id: str):
+        """Return the JobDef (function name + args/kwargs) for a job.
+
+        Used to verify a polling request is authorised for the job it
+        names — callers can match args[N] against the path parameter
+        that should own the job, without needing a separate persistence
+        layer for the task↔job association. arq stores the job def in
+        Redis as long as the job exists or its result is still cached.
+
+        Returns None if the job is unknown to Redis.
+        """
         pool = await cls.get_pool()
-        job = await pool.job(job_id)
-        if job:
-            return await job.status()
-        return None
+        return await cls._job(pool, job_id).info()
diff --git a/backend/src/workers/tasks.py b/backend/src/workers/tasks.py
index fc62aeb4..83496d92 100644
--- a/backend/src/workers/tasks.py
+++ b/backend/src/workers/tasks.py
@@ -118,6 +118,42 @@ async def clip_ready_callback(
         # Error will be caught by arq and task status will be updated
         raise
 
+async def merge_clips_job(
+    ctx: Dict[str, Any],
+    task_id: str,
+    clip_ids: list[str],
+) -> Dict[str, Any]:
+    """
+    Background worker task to merge clips.
+
+    The synchronous /tasks/{task_id}/clips/merge endpoint blocks the HTTP
+    request for the full ffmpeg concat-encode duration, which routinely
+    exceeds the ALB idle timeout (60s default, 300s after the band-aid
+    bump) and surfaces as a 504 to the caller. This worker variant is
+    enqueued by /tasks/{task_id}/clips/merge_async and polled via
+    /tasks/{task_id}/clips/merge_jobs/{job_id} so callers never hold an
+    HTTP connection open for the encode.
+
+    Returns the same dict shape as TaskService.merge_clips so arq's
+    job result storage carries the merged_clip_id straight to the poller.
+    """
+    from ..database import AsyncSessionLocal
+    from ..runtime_settings import load_runtime_settings_cache
+    from ..services.task_service import TaskService
+
+    set_trace_id(f"merge-{task_id}")
+    logger.info(f"Worker merging {len(clip_ids)} clips for task {task_id}")
+
+    async with AsyncSessionLocal() as db:
+        await load_runtime_settings_cache(db)
+        task_service = TaskService(db)
+        result = await task_service.merge_clips(task_id, clip_ids)
+        logger.info(
+            f"Merge complete task={task_id} merged_clip_id={result.get('clip_id')}"
+        )
+        return result
+
+
 # Worker configuration for arq
 class WorkerSettings:
     """Configuration for arq worker."""
@@ -128,7 +164,7 @@ class WorkerSettings:
     config = Config()
 
     # Functions to run
-    functions = [process_video_task]
+    functions = [process_video_task, merge_clips_job]
     queue_name = "supoclip_tasks"
 
     # Redis settings from environment
diff --git a/backend/tests/integration/test_health_and_tasks.py b/backend/tests/integration/test_health_and_tasks.py
index 5366f905..1c1cb5f6 100644
--- a/backend/tests/integration/test_health_and_tasks.py
+++ b/backend/tests/integration/test_health_and_tasks.py
@@ -1,6 +1,23 @@
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, patch
+
 import pytest
 
-from tests.fixtures.factories import create_source, create_task, create_user
+from tests.fixtures.factories import (
+    create_clip,
+    create_source,
+    create_task,
+    create_user,
+)
+
+
+def _job_info(task_id: str, clip_ids: list[str] | None = None, function: str = "merge_clips_job"):
+    """Stand-in for arq's JobDef used by JobQueue.get_job_info.
+
+    Only the attributes the endpoint actually reads (function, args)
+    need to be populated; everything else stays out of the test surface.
+    """
+    return SimpleNamespace(function=function, args=(task_id, clip_ids or []))
 
 
 @pytest.mark.asyncio
@@ -99,3 +116,269 @@ async def test_upload_video_uses_runtime_config_temp_dir(
     payload = response.json()
     saved_name = payload["video_path"].removeprefix("upload://")
     assert (tmp_path / "uploads" / saved_name).exists()
+
+
+@pytest.mark.asyncio
+async def test_merge_async_enqueues_and_returns_job_id(
+    client, db_session, auth_headers
+):
+    await create_user(db_session, user_id="user-1", email="owner@example.com")
+    source = await create_source(db_session, title="Owner source")
+    task = await create_task(db_session, user_id="user-1", source_id=source["id"])
+    clip_a = await create_clip(db_session, task_id=task["id"])
+    clip_b = await create_clip(db_session, task_id=task["id"])
+
+    with patch(
+        "src.api.routes.tasks.JobQueue.enqueue_job",
+        new=AsyncMock(return_value="merge-job-xyz"),
+    ) as enqueue:
+        response = await client.post(
+            f"/tasks/{task['id']}/clips/merge_async",
+            headers=auth_headers,
+            json={"clip_ids": [clip_a["id"], clip_b["id"]]},
+        )
+
+    assert response.status_code == 202
+    payload = response.json()
+    assert payload == {"merge_job_id": "merge-job-xyz", "status": "queued"}
+    enqueue.assert_awaited_once_with(
+        "merge_clips_job", task["id"], [clip_a["id"], clip_b["id"]]
+    )
+
+
+@pytest.mark.asyncio
+async def test_merge_async_rejects_unknown_clip(client, db_session, auth_headers):
+    await create_user(db_session, user_id="user-1", email="owner@example.com")
+    source = await create_source(db_session, title="Owner source")
+    task = await create_task(db_session, user_id="user-1", source_id=source["id"])
+    clip = await create_clip(db_session, task_id=task["id"])
+
+    # Don't even hit the queue if validation fails — guards against
+    # burning a worker slot to discover a typo.
+    with patch(
+        "src.api.routes.tasks.JobQueue.enqueue_job",
+        new=AsyncMock(return_value="should-not-be-called"),
+    ) as enqueue:
+        response = await client.post(
+            f"/tasks/{task['id']}/clips/merge_async",
+            headers=auth_headers,
+            json={"clip_ids": [clip["id"], "ghost-clip-id"]},
+        )
+
+    assert response.status_code == 404
+    assert "ghost-clip-id" in response.json()["detail"]
+    enqueue.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_merge_async_rejects_malformed_json(client, db_session, auth_headers):
+    await create_user(db_session, user_id="user-1", email="owner@example.com")
+    source = await create_source(db_session, title="Owner source")
+    task = await create_task(db_session, user_id="user-1", source_id=source["id"])
+
+    response = await client.post(
+        f"/tasks/{task['id']}/clips/merge_async",
+        headers={**auth_headers, "Content-Type": "application/json"},
+        content=b"{not valid json",
+    )
+
+    assert response.status_code == 400
+    assert "JSON" in response.json()["detail"]
+
+
+@pytest.mark.asyncio
+async def test_merge_async_rejects_non_object_body(client, db_session, auth_headers):
+    """JSON arrays / scalars are syntactically valid but have no .get()."""
+    await create_user(db_session, user_id="user-1", email="owner@example.com")
+    source = await create_source(db_session, title="Owner source")
+    task = await create_task(db_session, user_id="user-1", source_id=source["id"])
+
+    response = await client.post(
+        f"/tasks/{task['id']}/clips/merge_async",
+        headers=auth_headers,
+        json=["clip-a", "clip-b"],  # array, not object
+    )
+
+    assert response.status_code == 400
+    assert "object" in response.json()["detail"].lower()
+
+
+@pytest.mark.asyncio
+async def test_merge_async_rejects_non_string_clip_ids(
+    client, db_session, auth_headers
+):
+    """Non-string clip ids are a client error, not a repository lookup."""
+    await create_user(db_session, user_id="user-1", email="owner@example.com")
+    source = await create_source(db_session, title="Owner source")
+    task = await create_task(db_session, user_id="user-1", source_id=source["id"])
+
+    response = await client.post(
+        f"/tasks/{task['id']}/clips/merge_async",
+        headers=auth_headers,
+        json={"clip_ids": [{"id": "clip-a"}, 42]},
+    )
+
+    assert response.status_code == 400
+    assert "strings" in response.json()["detail"]
+
+
+@pytest.mark.asyncio
+async def test_merge_async_rejects_single_clip(client, db_session, auth_headers):
+    await create_user(db_session, user_id="user-1", email="owner@example.com")
+    source = await create_source(db_session, title="Owner source")
+    task = await create_task(db_session, user_id="user-1", source_id=source["id"])
+    clip = await create_clip(db_session, task_id=task["id"])
+
+    response = await client.post(
+        f"/tasks/{task['id']}/clips/merge_async",
+        headers=auth_headers,
+        json={"clip_ids": [clip["id"]]},
+    )
+
+    assert response.status_code == 400
+    assert "two" in response.json()["detail"].lower()
+
+
+@pytest.mark.asyncio
+async def test_get_merge_job_returns_completion_result(
+    client, db_session, auth_headers
+):
+    await create_user(db_session, user_id="user-1", email="owner@example.com")
+    source = await create_source(db_session, title="Owner source")
+    task = await create_task(db_session, user_id="user-1", source_id=source["id"])
+
+    with patch(
+        "src.api.routes.tasks.JobQueue.get_job_info",
+        new=AsyncMock(return_value=_job_info(task["id"])),
+    ), patch(
+        "src.api.routes.tasks.JobQueue.get_job_status",
+        new=AsyncMock(return_value="complete"),
+    ), patch(
+        "src.api.routes.tasks.JobQueue.get_job_result",
+        new=AsyncMock(return_value={"clip_id": "merged-1", "message": "ok"}),
+    ):
+        response = await client.get(
+            f"/tasks/{task['id']}/clips/merge_jobs/job-abc",
+            headers=auth_headers,
+        )
+
+    assert response.status_code == 200
+    payload = response.json()
+    assert payload == {
+        "merge_job_id": "job-abc",
+        "status": "complete",
+        "clip_id": "merged-1",
+        "message": "ok",
+    }
+
+
+@pytest.mark.asyncio
+async def test_get_merge_job_surfaces_worker_error(client, db_session, auth_headers):
+    await create_user(db_session, user_id="user-1", email="owner@example.com")
+    source = await create_source(db_session, title="Owner source")
+    task = await create_task(db_session, user_id="user-1", source_id=source["id"])
+
+    with patch(
+        "src.api.routes.tasks.JobQueue.get_job_info",
+        new=AsyncMock(return_value=_job_info(task["id"])),
+    ), patch(
+        "src.api.routes.tasks.JobQueue.get_job_status",
+        new=AsyncMock(return_value="complete"),
+    ), patch(
+        "src.api.routes.tasks.JobQueue.get_job_result",
+        new=AsyncMock(side_effect=RuntimeError("ffmpeg exit 254")),
+    ):
+        response = await client.get(
+            f"/tasks/{task['id']}/clips/merge_jobs/job-bad",
+            headers=auth_headers,
+        )
+
+    assert response.status_code == 200
+    payload = response.json()
+    assert payload["status"] == "complete"
+    assert "ffmpeg exit 254" in payload["error"]
+
+
+@pytest.mark.asyncio
+async def test_get_merge_job_returns_404_when_unknown(
+    client, db_session, auth_headers
+):
+    await create_user(db_session, user_id="user-1", email="owner@example.com")
+    source = await create_source(db_session, title="Owner source")
+    task = await create_task(db_session, user_id="user-1", source_id=source["id"])
+
+    with patch(
+        "src.api.routes.tasks.JobQueue.get_job_info",
+        new=AsyncMock(return_value=None),
+    ):
+        response = await client.get(
+            f"/tasks/{task['id']}/clips/merge_jobs/ghost",
+            headers=auth_headers,
+        )
+
+    assert response.status_code == 404
+
+
+@pytest.mark.asyncio
+async def test_get_merge_job_rejects_cross_task_probe(
+    client, db_session, auth_headers
+):
+    """A job owned by another task must surface as 404, never leak status.
+
+    Guards the path CodeRabbit flagged: with `await _require_task_owner`
+    on task_id alone, a user who legitimately owns task A could probe
+    a merge_job_id belonging to task B by hitting
+    /tasks/{A}/clips/merge_jobs/{B-job}. The args-binding check in the
+    endpoint catches it before status leaks.
+    """
+    await create_user(db_session, user_id="user-1", email="owner@example.com")
+    source = await create_source(db_session, title="Owner source")
+    own_task = await create_task(db_session, user_id="user-1", source_id=source["id"])
+
+    # Job exists in arq but its first arg names a *different* task —
+    # endpoint must 404, not 200 + status.
+    foreign_info = _job_info("a-different-task-id")
+
+    with patch(
+        "src.api.routes.tasks.JobQueue.get_job_info",
+        new=AsyncMock(return_value=foreign_info),
+    ), patch(
+        "src.api.routes.tasks.JobQueue.get_job_status",
+        new=AsyncMock(return_value="complete"),
+    ) as status_mock:
+        response = await client.get(
+            f"/tasks/{own_task['id']}/clips/merge_jobs/foreign-job",
+            headers=auth_headers,
+        )
+
+    assert response.status_code == 404
+    # Verify we short-circuited before fetching status — no leak even
+    # via timing or error shape.
+    status_mock.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_get_merge_job_rejects_wrong_function(client, db_session, auth_headers):
+    """A job id that exists but belongs to a different worker function 404s.
+
+    Prevents callers from using the merge polling endpoint as a generic
+    job introspection oracle (e.g. probing process_video_task ids).
+    """
+    await create_user(db_session, user_id="user-1", email="owner@example.com")
+    source = await create_source(db_session, title="Owner source")
+    task = await create_task(db_session, user_id="user-1", source_id=source["id"])
+
+    with patch(
+        "src.api.routes.tasks.JobQueue.get_job_info",
+        new=AsyncMock(return_value=_job_info(task["id"], function="process_video_task")),
+    ), patch(
+        "src.api.routes.tasks.JobQueue.get_job_status",
+        new=AsyncMock(return_value="complete"),
+    ) as status_mock:
+        response = await client.get(
+            f"/tasks/{task['id']}/clips/merge_jobs/wrong-fn-job",
+            headers=auth_headers,
+        )
+
+    assert response.status_code == 404
+    status_mock.assert_not_awaited()