From e4a8709a146a64ad4687723a9e036d3268491cf1 Mon Sep 17 00:00:00 2001 From: Conrad Date: Fri, 1 May 2026 13:50:32 -0400 Subject: [PATCH 1/7] build: Add boto3 and moto for ECS Fargate runtime The ECS Fargate worker fleet needs boto3 at runtime for the S3 cache backend, the ECS provisioner, and the worker discovery loop. moto is added as a dev-only dependency so the unit tests can exercise the boto3 code paths without standing up LocalStack or hitting real AWS. --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6dda16c..52308b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ classifiers = [ ] dependencies = [ "aiohttp", + "boto3", "click", "debugpy", "fastapi", @@ -30,7 +31,7 @@ readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.11" [project.optional-dependencies] -dev = ["allpairspy", "debugpy", "httpx", "hypothesis", "mongomock-motor", "pytest-asyncio", "pytest-mock", "ruff"] +dev = ["allpairspy", "debugpy", "httpx", "hypothesis", "mongomock-motor", "moto[ecs,s3]", "pytest-asyncio", "pytest-mock", "ruff"] [project.scripts] cfdb = "cfdb.cli:cli" From 57402bc380807756cf2b1c14f603691d5ad36ebd Mon Sep 17 00:00:00 2001 From: Conrad Date: Fri, 1 May 2026 13:50:44 -0400 Subject: [PATCH 2/7] feat: Add S3 cache backend for workflow artifacts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit S3Cache is a CacheBackend implementation over boto3 with the same range-aware semantics as LocalFsCache (head_object, get_object with a Range header, upload_file, delete_object). Production points it at real S3; LocalStack-backed dev points it at the LocalStack endpoint via AWS_ENDPOINT_URL — only the endpoint differs. The backend supports an optional key prefix for sharing a single bucket across environments, rejects path-traversal segments, and yields an empty iterator on missing objects so router code can treat cache misses uniformly across backends. --- src/cfdb/workflows/cache.py | 327 +++++++++++++++++++++++++- tests/test_workflows/test_cache_s3.py | 240 +++++++++++++++++++ 2 files changed, 557 insertions(+), 10 deletions(-) create mode 100644 tests/test_workflows/test_cache_s3.py diff --git a/src/cfdb/workflows/cache.py b/src/cfdb/workflows/cache.py index 5787038..f675aa6 100644 --- a/src/cfdb/workflows/cache.py +++ b/src/cfdb/workflows/cache.py @@ -1,9 +1,10 @@ """Pluggable cache backend for workflow artifacts. The cache is a byte-range-aware content store keyed by strings produced by -``workflows.keys.cache_key``. ``LocalFsCache`` is the concrete backend used -in local development and for the initial CVH rollout; an S3-backed -implementation with the same interface is planned for production. +``workflows.keys.cache_key``. Two concrete backends ship: ``LocalFsCache`` +for development and unit tests that don't need a network round-trip, and +``S3Cache`` for production (and for LocalStack-backed dev that mirrors +production end-to-end via boto3). Range-aware reads matter because Gosling's client-side fetchers (BAM/tabix families, bbi) issue ``Range: bytes=…`` requests against the artifact URL. @@ -19,7 +20,10 @@ from collections.abc import AsyncIterator from dataclasses import dataclass from pathlib import Path -from typing import Optional +from typing import Any, Optional + +import boto3 +from botocore.exceptions import BotoCoreError, ClientError @dataclass(frozen=True) @@ -74,17 +78,29 @@ async def delete(self, key: str) -> bool: _CHUNK_SIZE = 1 << 16 # 64 KiB -def _safe_key_path(root: Path, key: str) -> Path: - """Resolve a cache key to an absolute path under ``root``. +def _validate_cache_key(key: str) -> None: + """Reject path-traversal segments in a cache key. Rejects empty keys and keys containing path-traversal segments so that malformed input cannot escape the cache root or collapse onto the - root directory itself. + root directory itself. Shared by both backends so the rule (and its + error message) stays in one place. """ if not key or not key.strip("/"): raise ValueError(f"Cache key must be non-empty: {key!r}") if ".." in key.split("/"): raise ValueError(f"Invalid cache key: {key!r}") + if key.startswith("/"): + raise ValueError(f"Cache key must not start with '/': {key!r}") + + +def _safe_key_path(root: Path, key: str) -> Path: + """Resolve a cache key to an absolute path under ``root``. + + Validates the key shape and verifies the resolved path stays under + ``root`` so a malformed key cannot escape the cache root. + """ + _validate_cache_key(key) path = (root / key).resolve() root_resolved = root.resolve() if root_resolved not in path.parents: @@ -95,8 +111,9 @@ def _safe_key_path(root: Path, key: str) -> Path: class LocalFsCache(CacheBackend): """Local-filesystem backend. Keys map directly to relative paths. - ``put`` writes via a temp sibling + atomic rename. ``get`` supports - byte-range reads and chunks by 64 KiB to keep streaming memory-bounded. + ``put`` atomically renames the caller's source file into place via + ``os.replace``. ``get`` supports byte-range reads and chunks by 64 + KiB to keep streaming memory-bounded. """ def __init__(self, root: Path) -> None: @@ -143,6 +160,293 @@ async def delete(self, key: str) -> bool: return False +class S3Cache(CacheBackend): + """Boto3-backed cache for production (and LocalStack-backed dev). + + The same code targets real S3 and LocalStack — the only difference + is the ``endpoint_url`` passed to ``boto3.client("s3")``. Keys are + stored as object keys (optionally under a configurable prefix); + range reads use S3's ``Range`` header verbatim, so the fetcher + semantics are identical to ``LocalFsCache``. + + Args: + bucket: Bucket name. Must already exist (LocalStack and prod + both treat bucket creation as an out-of-band concern). + prefix: Optional key prefix; useful for sharing a single bucket + across multiple environments. Empty string by default. + client: Optional pre-built boto3 ``s3`` client. When omitted, + one is constructed via :func:`_build_s3_client` with the + ``endpoint_url`` / ``region_name`` kwargs threaded through. + Tests inject a moto-backed client through this argument. + endpoint_url: Boto3 ``endpoint_url``. Passed to + :func:`_build_s3_client` when ``client`` is omitted. The + lifespan plumbs :data:`cfdb.api.AWS_ENDPOINT_URL` here so + LocalStack vs production differ only at this seam. + region_name: Boto3 ``region_name``. Plumbed analogously. + chunk_size: Streaming chunk size for ``get`` reads. Defaults to + 64 KiB to match the local filesystem backend. + """ + + def __init__( + self, + bucket: str, + *, + prefix: str = "", + client: Optional[Any] = None, + endpoint_url: Optional[str] = None, + region_name: Optional[str] = None, + chunk_size: int = _CHUNK_SIZE, + ) -> None: + if not bucket: + raise ValueError("S3Cache requires a non-empty bucket name") + self._bucket = bucket + # Normalize so callers can pass either ``"prefix"`` or ``"prefix/"``. + self._prefix = prefix.strip("/") + self._endpoint_url = endpoint_url + self._region_name = region_name + self._client = ( + client + if client is not None + else _build_s3_client(endpoint_url=endpoint_url, region_name=region_name) + ) + self._chunk_size = chunk_size + + def _object_key(self, key: str) -> str: + """Apply the configured prefix to a validated cache key.""" + _validate_cache_key(key) + return f"{self._prefix}/{key}" if self._prefix else key + + def __getstate__(self) -> dict[str, Any]: + """Strip the boto3 client for pickling. + + ``S3Cache`` is dispatched across the cloudpickle boundary into + Wool worker processes; botocore's ``BaseClient`` cannot be + pickled. ``__setstate__`` rebuilds it via ``_build_s3_client`` + during unpickling, threading the originally-supplied + ``endpoint_url`` / ``region_name`` through so the worker + targets the same backend as the API process. + """ + state = self.__dict__.copy() + state["_client"] = None + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + """Restore state and rebuild the boto3 client on the worker.""" + self.__dict__.update(state) + if self._client is None: + self._client = _build_s3_client( + endpoint_url=self._endpoint_url, + region_name=self._region_name, + ) + + async def head(self, key: str) -> Optional[CacheEntry]: + """Return cache metadata for ``key``, or None if the object is absent. + + ``ClientError`` covers HTTP-level S3 errors with structured + response codes; ``BotoCoreError`` covers transport / cred + failures with no ``.response``. ``_is_not_found`` returns + False for the latter family, so transport failures correctly + re-raise rather than masquerade as cache miss. + + Note: S3 ``HEAD`` responses carry no body, so a missing + bucket is indistinguishable from a missing object at this + endpoint — both surface as a bare ``404``. The lifespan + startup probes the bucket separately (see + ``check_s3_bucket_or_raise``) so a typo in + ``WORKFLOW_S3_BUCKET`` fails fast at boot rather than as a + cascade of "permanent cache miss" symptoms. + """ + object_key = self._object_key(key) + try: + response = await asyncio.to_thread( + self._client.head_object, + Bucket=self._bucket, + Key=object_key, + ) + except (ClientError, BotoCoreError) as exc: + if _is_not_found(exc): + return None + raise + return CacheEntry(key=key, size=int(response["ContentLength"])) + + def get( + self, key: str, byte_range: Optional[tuple[int, int]] = None + ) -> AsyncIterator[bytes]: + """Stream cached bytes from S3, optionally restricted to a byte range. + + ``GetObject`` accepts an inclusive ``Range`` header; we forward + the tuple verbatim. A missing object yields an empty iterator, + matching ``LocalFsCache`` semantics so router code can treat + cache misses uniformly across backends. + """ + return _stream_s3_object( + self._client, + self._bucket, + self._object_key(key), + byte_range, + self._chunk_size, + ) + + async def put(self, key: str, source_path: Path) -> CacheEntry: + """Upload ``source_path`` to S3 under the configured key. + + ``upload_file`` is atomic from the reader's perspective — + readers either see the prior object (if any) or the new one, + never a partial write. Boto3 streams the file in a thread so + the event loop isn't blocked. Size is taken from the source + file rather than a follow-up ``HEAD`` to avoid the round-trip; + when the source file has been torn down between upload + completion and the local ``stat`` (workdir cleanup races + finalization), we fall back to a ``head_object`` round-trip + rather than surface ``FileNotFoundError`` for an upload that + is already committed. + """ + object_key = self._object_key(key) + await asyncio.to_thread( + self._client.upload_file, + str(source_path), + self._bucket, + object_key, + ) + try: + size = await asyncio.to_thread(source_path.stat) + except FileNotFoundError: + response = await asyncio.to_thread( + self._client.head_object, + Bucket=self._bucket, + Key=object_key, + ) + return CacheEntry(key=key, size=int(response["ContentLength"])) + return CacheEntry(key=key, size=size.st_size) + + async def delete(self, key: str) -> bool: + """Delete the cache entry. Returns True when the object existed. + + ``DeleteObject`` returns 204 whether or not the key existed, so + we probe with HEAD first to give callers the existence signal. + The HEAD/DELETE pair is non-atomic — a concurrent ``put`` between + them returns ``False`` ("did not exist") yet erases the freshly- + uploaded object. Callers MUST serialize ``put``/``delete`` for the + same key via the workflow mutex; the cache backend itself does + not arbitrate concurrent writers. + """ + object_key = self._object_key(key) + existed = await self.head(key) is not None + await asyncio.to_thread( + self._client.delete_object, + Bucket=self._bucket, + Key=object_key, + ) + return existed + + +def _build_s3_client( + *, endpoint_url: Optional[str] = None, region_name: Optional[str] = None +) -> Any: + """Construct a boto3 ``s3`` client with explicit endpoint/region. + + The caller (typically :class:`S3Cache` or the API lifespan) is the + single source of truth for ``endpoint_url`` and ``region_name`` — + we no longer reach into :mod:`cfdb.api` for fallback values. Pass + :data:`cfdb.api.AWS_ENDPOINT_URL` / :data:`cfdb.api.AWS_REGION` + from the lifespan; leave both ``None`` to let boto3's default + session resolver chain pick them up from the environment. + """ + return boto3.client( + "s3", + endpoint_url=endpoint_url, + region_name=region_name, + ) + + +async def check_s3_bucket_or_raise(bucket: str, *, client: Optional[Any] = None) -> None: + """Verify ``bucket`` is reachable; raise on missing/inaccessible. + + Run from the API lifespan so a typo in ``WORKFLOW_S3_BUCKET`` (or + a missing IAM grant) fails fast at boot rather than masquerading + as a permanent cache-miss cascade once workflows start. ``HEAD`` + on a missing bucket returns a bare ``404`` that ``S3Cache.head`` + cannot distinguish from a missing object — this probe asks + ``head_bucket`` directly, which gets a structured response. + """ + cli = client if client is not None else _build_s3_client() + try: + await asyncio.to_thread(cli.head_bucket, Bucket=bucket) + except (ClientError, BotoCoreError) as exc: + raise RuntimeError( + f"S3 bucket {bucket!r} is not reachable: {type(exc).__name__}: {exc}" + ) from exc + + +#: Object-level "missing" codes only. ``NoSuchBucket`` is deliberately +#: NOT in this set: a missing bucket is a configuration failure (typo +#: in WORKFLOW_S3_BUCKET, missing IAM, region mismatch) that should +#: surface as an exception rather than silently masquerade as a +#: permanent cache miss. +_NOT_FOUND_CODES = frozenset({"404", "NoSuchKey", "NotFound"}) + + +def _is_not_found(exc: BaseException) -> bool: + """Return True when a boto3 exception indicates a missing object. + + ``ClientError`` carries a structured ``response`` dict with both an + ``Error.Code`` (string) and ``ResponseMetadata.HTTPStatusCode`` + (int). We check both: head_object responses use the bare 404 code, + get_object uses the named ``NoSuchKey``. ``BotoCoreError`` (network + / credential failures) has no ``response`` and is correctly + classified as not-a-not-found, so it propagates. + """ + response = getattr(exc, "response", None) + if not isinstance(response, dict): + return False + code = (response.get("Error") or {}).get("Code") + if code in _NOT_FOUND_CODES: + return True + status = (response.get("ResponseMetadata") or {}).get("HTTPStatusCode") + return status == 404 + + +async def _stream_s3_object( + client: Any, + bucket: str, + object_key: str, + byte_range: Optional[tuple[int, int]], + chunk_size: int, +) -> AsyncIterator[bytes]: + """Async generator yielding chunks from an S3 object. + + Diverges from ``_stream_file`` on mid-stream object disappearance: + a deletion between the ``GetObject`` response and the body read + raises a ``ClientError`` to the consumer rather than truncating + silently. Callers serialize put/delete via the workflow mutex + (see ``S3Cache.delete``), so the divergence is benign in practice. + """ + kwargs: dict[str, Any] = {"Bucket": bucket, "Key": object_key} + if byte_range is not None: + start, end = byte_range + kwargs["Range"] = f"bytes={start}-{end}" + try: + response = await asyncio.to_thread(client.get_object, **kwargs) + except (ClientError, BotoCoreError) as exc: + if _is_not_found(exc): + return + raise + body = response["Body"] + try: + while True: + chunk = await asyncio.to_thread(body.read, chunk_size) + if not chunk: + break + yield chunk + finally: + close = getattr(body, "close", None) + if close is not None: + # Shield the close so a CancelledError in this finally + # (consumer disconnect mid-stream, chained cancel) cannot + # skip the underlying urllib3 connection release. + await asyncio.shield(asyncio.to_thread(close)) + + async def _stream_file( path: Path, byte_range: Optional[tuple[int, int]] ) -> AsyncIterator[bytes]: @@ -188,4 +492,7 @@ def _open_and_seek(): remaining -= len(chunk) yield chunk finally: - await asyncio.to_thread(fh.close) + # Shield the close so a CancelledError in this finally (consumer + # disconnect mid-stream, chained cancel) cannot skip the + # underlying file-descriptor release. + await asyncio.shield(asyncio.to_thread(fh.close)) diff --git a/tests/test_workflows/test_cache_s3.py b/tests/test_workflows/test_cache_s3.py new file mode 100644 index 0000000..aff997e --- /dev/null +++ b/tests/test_workflows/test_cache_s3.py @@ -0,0 +1,240 @@ +"""Tests for the moto-backed S3Cache implementation.""" + +from __future__ import annotations + +from pathlib import Path + +import boto3 +import pytest +from moto import mock_aws + +from cfdb.workflows.cache import CacheEntry, S3Cache + + +_BUCKET = "cfdb-test-cache" + + +def _write(path: Path, data: bytes) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(data) + + +async def _collect(stream) -> bytes: + chunks: list[bytes] = [] + async for chunk in stream: + chunks.append(chunk) + return b"".join(chunks) + + +@pytest.fixture() +def s3_client(): + """Return a moto-backed boto3 S3 client with one created bucket.""" + with mock_aws(): + client = boto3.client("s3", region_name="us-east-1") + client.create_bucket(Bucket=_BUCKET) + yield client + + +@pytest.fixture() +def cache(s3_client) -> S3Cache: + """Return an S3Cache wired up to the moto-backed client.""" + return S3Cache(bucket=_BUCKET, client=s3_client) + + +class TestS3Cache: + def test___init___with_empty_bucket_name(self, s3_client): + """Test that S3Cache rejects an empty bucket name. + + Given: + A moto-backed S3 client and an empty bucket name. + When: + S3Cache is constructed. + Then: + It should raise ValueError so misconfigurations fail fast. + """ + # Act & assert + with pytest.raises(ValueError, match="bucket"): + S3Cache(bucket="", client=s3_client) + + @pytest.mark.asyncio + async def test_head_with_absent_key(self, cache): + """Test that head reports a cache miss as None. + + Given: + An empty S3Cache. + When: + head is awaited for an unknown key. + Then: + It should return None so the router can dispatch a workflow. + """ + # Act + entry = await cache.head("encode/x/data/aa-v0") + + # Assert + assert entry is None + + @pytest.mark.asyncio + async def test_put_then_head_with_known_payload(self, cache, tmp_path): + """Test that put commits the artifact and head reports its size. + + Given: + An S3Cache and a source file with known contents. + When: + put commits the artifact and head is then awaited. + Then: + head should return a CacheEntry carrying the exact byte size. + """ + # Arrange + source = tmp_path / "src" + _write(source, b"hello world") + + # Act + await cache.put("encode/x/data/aa-v0", source) + entry = await cache.head("encode/x/data/aa-v0") + + # Assert + assert entry == CacheEntry(key="encode/x/data/aa-v0", size=11) + + @pytest.mark.asyncio + async def test_get_with_full_object(self, cache, tmp_path): + """Test that get streams the full artifact without a byte range. + + Given: + An S3Cache containing a multi-chunk artifact. + When: + get is iterated without a byte_range argument. + Then: + It should yield the complete bytes. + """ + # Arrange + source = tmp_path / "src" + payload = b"0123456789" * 20_000 + _write(source, payload) + await cache.put("encode/x/data/aa-v0", source) + + # Act + collected = await _collect(cache.get("encode/x/data/aa-v0")) + + # Assert + assert collected == payload + + @pytest.mark.asyncio + async def test_get_with_inclusive_byte_range(self, cache, tmp_path): + """Test that get forwards an inclusive byte range to S3. + + Given: + An S3Cache containing a known artifact. + When: + get is iterated with byte_range=(5, 9). + Then: + It should yield exactly bytes 5..9 inclusive (5 bytes total). + """ + # Arrange + source = tmp_path / "src" + _write(source, b"0123456789ABCDEF") + await cache.put("encode/x/data/aa-v0", source) + + # Act + collected = await _collect(cache.get("encode/x/data/aa-v0", (5, 9))) + + # Assert + assert collected == b"56789" + + @pytest.mark.asyncio + async def test_get_with_absent_key(self, cache): + """Test that get yields nothing for an absent key. + + Given: + An empty S3Cache. + When: + get is iterated for a missing key. + Then: + It should yield no chunks rather than raise. + """ + # Act + collected = await _collect(cache.get("encode/x/data/aa-v0")) + + # Assert + assert collected == b"" + + @pytest.mark.asyncio + async def test_delete_with_present_key(self, cache, tmp_path): + """Test that delete reports True and removes the artifact. + + Given: + An S3Cache containing one entry. + When: + delete is awaited for that key. + Then: + It should return True and a subsequent head should miss. + """ + # Arrange + source = tmp_path / "src" + _write(source, b"x") + await cache.put("encode/x/data/aa-v0", source) + + # Act + deleted = await cache.delete("encode/x/data/aa-v0") + + # Assert + assert deleted is True + assert await cache.head("encode/x/data/aa-v0") is None + + @pytest.mark.asyncio + async def test_delete_with_absent_key(self, cache): + """Test that delete on a missing key is idempotent. + + Given: + An empty S3Cache. + When: + delete is awaited for an unknown key. + Then: + It should return False rather than raise. + """ + # Act & assert + assert await cache.delete("encode/x/data/aa-v0") is False + + @pytest.mark.asyncio + async def test_put_with_traversal_segment_key(self, cache, tmp_path): + """Test that put refuses keys containing path-traversal segments. + + Given: + An S3Cache and a key with a ``..`` segment. + When: + put is awaited with that key. + Then: + It should raise ValueError rather than write under a parent prefix. + """ + # Arrange + source = tmp_path / "src" + _write(source, b"x") + + # Act & assert + with pytest.raises(ValueError): + await cache.put("../oops", source) + + @pytest.mark.asyncio + async def test_put_then_head_with_configured_prefix(self, s3_client, tmp_path): + """Test that the configured prefix is applied to every operation. + + Given: + An S3Cache configured with a non-empty prefix. + When: + put writes a key and head reads it back. + Then: + The object should land under ``/`` and head should + report its size correctly. + """ + # Arrange + cache = S3Cache(bucket=_BUCKET, prefix="env/dev", client=s3_client) + source = tmp_path / "src" + _write(source, b"abc") + + # Act + await cache.put("encode/x/data/aa-v0", source) + entry = await cache.head("encode/x/data/aa-v0") + + # Assert — moto stores under the prefixed key, not the raw key + assert entry == CacheEntry(key="encode/x/data/aa-v0", size=3) + listing = s3_client.list_objects_v2(Bucket=_BUCKET).get("Contents", []) + assert {item["Key"] for item in listing} == {"env/dev/encode/x/data/aa-v0"} From 7a038365b473d9a041104c4f5df5aa96f3eec675 Mon Sep 17 00:00:00 2001 From: Conrad Date: Fri, 1 May 2026 13:51:01 -0400 Subject: [PATCH 3/7] feat: Add ECS provisioner and worker discovery EcsProvisioner is a thin boto3 RunTask wrapper that launches an ephemeral worker container per workflow, with awsvpc network configuration, concurrent-call dedup keyed on the workflow mutex key (so a burst of ensure_workflow calls on the same source doesn't fan out into multiple tasks), and a semaphore guarding the ~20 req/s RunTask rate limit. CapacityException covers both ClientError- and failures[].reason-shaped capacity / ENI errors so the executor can retry rather than hang. EcsDiscovery is a Wool DiscoveryLike poll-and-diff over list_tasks + describe_tasks: it filters on healthStatus HEALTHY, extracts the awsvpc IP, and emits worker-added / worker-dropped events to non- blocking subscribers via per-subscriber asyncio.Queue. State replay on subscribe means subscribers attached after startup observe the existing healthy fleet. --- src/cfdb/workflows/constants.py | 17 + src/cfdb/workflows/discovery.py | 523 +++++++++++++++++++++++ src/cfdb/workflows/provisioner.py | 387 +++++++++++++++++ tests/test_workflows/test_discovery.py | 343 +++++++++++++++ tests/test_workflows/test_provisioner.py | 398 +++++++++++++++++ 5 files changed, 1668 insertions(+) create mode 100644 src/cfdb/workflows/constants.py create mode 100644 src/cfdb/workflows/discovery.py create mode 100644 src/cfdb/workflows/provisioner.py create mode 100644 tests/test_workflows/test_discovery.py create mode 100644 tests/test_workflows/test_provisioner.py diff --git a/src/cfdb/workflows/constants.py b/src/cfdb/workflows/constants.py new file mode 100644 index 0000000..4d942eb --- /dev/null +++ b/src/cfdb/workflows/constants.py @@ -0,0 +1,17 @@ +"""Constants shared across cfdb.workflows.* modules. + +Lives in its own leaf module so ``discovery`` can read +``DEFAULT_WORKER_PORT`` without importing ``worker_main`` — importing +``worker_main`` would drag aiohttp into ``discovery``'s import graph +and impose a runtime dependency that none of ``discovery``'s actual +consumers (the API lifespan, tests) need. +""" + +from __future__ import annotations + +#: Wool gRPC port the worker container binds and clients dial. +#: ``EcsDiscovery`` uses it to assemble worker addresses; +#: ``worker_main`` uses it as the bind port. The two MUST agree, and +#: the ECS task definition's ``portMappings`` / ``healthCheck`` target +#: MUST match. +DEFAULT_WORKER_PORT = 50051 diff --git a/src/cfdb/workflows/discovery.py b/src/cfdb/workflows/discovery.py new file mode 100644 index 0000000..b160222 --- /dev/null +++ b/src/cfdb/workflows/discovery.py @@ -0,0 +1,523 @@ +"""Worker discovery driven by the ECS control plane. + +ECS already owns the lifecycle of every worker task — registration, IP, +status, health — so we use it as the discovery substrate directly rather +than maintaining a parallel registry. ``EcsDiscovery`` implements Wool's +``DiscoveryLike`` protocol with a poll-and-diff loop: + +1. ``ecs.list_tasks`` enumerates every task in the cluster matching the + worker task family with ``desiredStatus="RUNNING"``. +2. ``ecs.describe_tasks`` (batched 100 ARNs at a time) hydrates each + into its full state, including ``healthStatus`` and ``attachments``. +3. We filter for ``lastStatus == "RUNNING"`` and (where reported) + ``healthStatus == "HEALTHY"`` — tasks whose container definition + omits a ``healthCheck`` are accepted as healthy so a deployment that + hasn't yet wired up health checks still surfaces workers. +4. The poller diffs the resolved set against the previous one and + publishes ``worker-added`` / ``worker-dropped`` events to a Wool + ``DiscoveryPublisherLike``. + +The worker side does nothing for discovery — no heartbeat thread, no +Mongo write, no registration RPC. ECS reports ``healthStatus: HEALTHY`` +once the container's health check passes, which is the "ready to +dispatch" signal the discovery filters on. The worker task definition +MUST declare a ``healthCheck`` against the gRPC port; without one +ECS reports ``healthStatus: UNKNOWN`` indefinitely and the worker is +never advertised. +""" + +from __future__ import annotations + +import asyncio +import logging +import uuid +from collections.abc import AsyncIterator, Iterable +from typing import Any, Optional + +from wool.runtime.discovery.base import ( + Discovery, + DiscoveryEvent, + DiscoveryEventType, + DiscoveryPublisherLike, + DiscoverySubscriberLike, + PredicateFunction, + WorkerMetadata, +) + +from cfdb.workflows.constants import DEFAULT_WORKER_PORT +from cfdb.workflows.provisioner import build_ecs_client + +logger = logging.getLogger(__name__) + +#: ECS ``DescribeTasks`` accepts up to 100 ARNs per call. +_DESCRIBE_BATCH_SIZE = 100 + +#: Default poll cadence. ECS ``ListTasks`` is a Cluster resource read +#: action: burst 100, sustained 20 req/s per region (shared bucket with +#: ``DescribeTasks``). 5s leaves ample headroom for many concurrent +#: clusters. +DEFAULT_POLL_INTERVAL_SECONDS = 5.0 + + +class EcsDiscovery(Discovery): + """Wool-compatible discovery service backed by ECS list/describe. + + Args: + cluster: ECS cluster name or ARN to poll. + task_definition_family: Worker task definition family (sans + revision). Used as the ``family`` filter on ``ListTasks``. + client: Optional pre-built boto3 ``ecs`` client. When omitted, + one is constructed via :func:`build_ecs_client` with the + ``endpoint_url`` / ``region_name`` kwargs threaded through. + endpoint_url: Boto3 ``endpoint_url``. Passed to + :func:`build_ecs_client` when ``client`` is omitted. The + lifespan plumbs :data:`cfdb.api.AWS_ENDPOINT_URL` here. + region_name: Boto3 ``region_name``. Plumbed analogously. + poll_interval: Seconds between successive ``ListTasks`` polls. + worker_port: gRPC port the worker binds — used to construct + the address string passed to Wool. Shares the worker_main + default so a deployment that changes one changes both. + version: Wool worker version string passed through into + ``WorkerMetadata``. Free-form; useful for filtering. + + Lifecycle: enter ``async with EcsDiscovery(...)`` to start the + background poller. Exiting the context cancels the poller and + releases the publisher. + """ + + def __init__( + self, + *, + cluster: str, + task_definition_family: str, + client: Optional[Any] = None, + endpoint_url: Optional[str] = None, + region_name: Optional[str] = None, + poll_interval: float = DEFAULT_POLL_INTERVAL_SECONDS, + worker_port: int = DEFAULT_WORKER_PORT, + version: str = "0", + ) -> None: + if not cluster: + raise ValueError("EcsDiscovery requires a cluster name") + if not task_definition_family: + raise ValueError("EcsDiscovery requires a task_definition_family") + if poll_interval <= 0: + raise ValueError("poll_interval must be positive") + + self._cluster = cluster + self._task_definition_family = task_definition_family + self._client = ( + client + if client is not None + else build_ecs_client(endpoint_url=endpoint_url, region_name=region_name) + ) + self._poll_interval = poll_interval + self._worker_port = worker_port + self._version = version + + self._subscribers: list[_EcsSubscriber] = [] + self._known: dict[str, WorkerMetadata] = {} + self._poll_task: Optional[asyncio.Task[None]] = None + self._lock = asyncio.Lock() + # Separate from ``_lock`` so subscriber registration is not + # blocked on the AWS round-trip held by an in-flight poll. + # Serializes the entire ``poll_once`` body so two concurrent + # callers (test fixture + background loop, mostly) cannot + # interleave their list/describe snapshots into the diff and + # regress ``_known``. + self._poll_lock = asyncio.Lock() + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + async def __aenter__(self) -> "EcsDiscovery": + # Re-entry would orphan the prior poll task and diff against + # stale ``_known`` state. ``__aexit__`` nulls ``_poll_task``; + # the assert pairs with that to make the misuse loud. + assert self._poll_task is None, ( + "EcsDiscovery context entered twice; create a fresh instance" + ) + # Run an immediate scan so subscribers attached right after + # __aenter__ see the initial set of workers without waiting a + # full poll interval. + await self.poll_once() + self._poll_task = asyncio.create_task(self._poll_loop()) + return self + + async def __aexit__(self, *_exc: Any) -> None: + if self._poll_task is not None: + self._poll_task.cancel() + try: + await self._poll_task + except asyncio.CancelledError: + pass + self._poll_task = None + # Wake any consumer parked on ``await self._queue.get()`` so + # ``async for event in subscriber:`` exits cleanly rather than + # blocking forever on a queue nothing will publish to again. + # The sentinel ``None`` traverses each subscriber's queue and + # ``_iter`` breaks its loop on it. + async with self._lock: + for sub in self._subscribers: + await sub._queue.put(None) + # Clear cached state so a stray re-subscribe after exit + # doesn't replay stale workers. + self._known.clear() + self._subscribers.clear() + + # ------------------------------------------------------------------ + # Wool DiscoveryLike protocol + # ------------------------------------------------------------------ + @property + def publisher(self) -> DiscoveryPublisherLike: + """``EcsDiscovery`` is read-only — workers register implicitly via ECS. + + Returns a guard-rail publisher that raises on ``publish``; + production code should never reach for this property. + """ + return _RaisingPublisher() + + @property + def subscriber(self) -> DiscoverySubscriberLike: + """A subscriber with no filter (sees every healthy worker).""" + return self.subscribe() + + def subscribe( + self, filter: Optional[PredicateFunction] = None + ) -> DiscoverySubscriberLike: + """Create a subscriber, optionally filtered by a worker predicate. + + The filter predicate runs inline on the dispatch path while the + registration lock is held — keep it cheap and non-blocking. + Slow predicates stall delivery to every other subscriber and + block concurrent registrations. + """ + return _EcsSubscriber(self, filter) + + # ------------------------------------------------------------------ + # Polling + # ------------------------------------------------------------------ + async def poll_once(self) -> tuple[list[DiscoveryEvent], dict[str, WorkerMetadata]]: + """Run one list/describe cycle and emit diff events. + + Returns the events emitted in this cycle and the resolved set of + currently-known healthy workers (keyed by UUID hex string). + Callers do not normally need to use the return value — it + exists for tests so that polling can be exercised step-by-step + without the background loop. + + ``self._poll_lock`` serializes the entire body so two concurrent + callers (test fixture + background loop) cannot interleave + their list/describe snapshots into the diff and regress + ``_known``. ``self._lock`` separately serializes the + diff/mutate/dispatch tail against subscriber registration; see + ``_register_subscriber`` for the pairing. + + On ``list_tasks`` / ``describe_tasks`` failure the previous + ``_known`` snapshot is retained — workers added during a multi- + cycle ECS outage are invisible until recovery, and workers + terminated during the outage continue to be advertised as + healthy until the next successful poll. Acceptable degradation + next to the cold-start budget; preferable to flapping the whole + fleet on a single transient failure. + """ + async with self._poll_lock: + try: + arns = await asyncio.to_thread(self._list_task_arns) + except Exception: + logger.exception( + "EcsDiscovery: list_tasks failed for cluster=%s", + self._cluster, + ) + async with self._lock: + return [], dict(self._known) + + resolved: dict[str, WorkerMetadata] = {} + if arns: + try: + tasks = await asyncio.to_thread(self._describe_tasks_batched, arns) + except Exception: + logger.exception( + "EcsDiscovery: describe_tasks failed for cluster=%s", + self._cluster, + ) + async with self._lock: + return [], dict(self._known) + for task in tasks: + metadata = self._task_to_metadata(task) + if metadata is not None: + resolved[str(metadata.uid)] = metadata + + # Diff, mutate, and dispatch under ``self._lock`` so a + # concurrent ``_register_subscriber`` cannot replay a + # snapshot that already contains a freshly-added worker AND + # then receive the worker-added event for it from this + # dispatch (duplicate delivery). Either the registration + # runs entirely before this mutation (replay sees the OLD + # known, this dispatch sees the new subscriber) or entirely + # after (replay sees the NEW known, this dispatch's events + # were already delivered). + async with self._lock: + events = list(_diff(self._known, resolved)) + self._known = resolved + if events: + for sub in self._subscribers: + for event in events: + await sub._publish(event) + + return events, dict(resolved) + + async def _poll_loop(self) -> None: + """Background poll loop. Cancellation-safe.""" + while True: + try: + await asyncio.sleep(self._poll_interval) + await self.poll_once() + except asyncio.CancelledError: + raise + except Exception: + logger.exception("EcsDiscovery poll loop iteration failed") + + # ------------------------------------------------------------------ + # ECS API plumbing + # ------------------------------------------------------------------ + def _list_task_arns(self) -> list[str]: + """Page through ``ListTasks`` and return every task ARN.""" + arns: list[str] = [] + paginator_kwargs = { + "cluster": self._cluster, + "family": self._task_definition_family, + "desiredStatus": "RUNNING", + } + next_token: Optional[str] = None + while True: + kwargs = dict(paginator_kwargs) + if next_token: + kwargs["nextToken"] = next_token + response = self._client.list_tasks(**kwargs) + arns.extend(response.get("taskArns") or []) + next_token = response.get("nextToken") + if not next_token: + break + return arns + + def _describe_tasks_batched(self, arns: list[str]) -> list[dict[str, Any]]: + """Call ``DescribeTasks`` in batches of ``_DESCRIBE_BATCH_SIZE``.""" + out: list[dict[str, Any]] = [] + for i in range(0, len(arns), _DESCRIBE_BATCH_SIZE): + batch = arns[i : i + _DESCRIBE_BATCH_SIZE] + response = self._client.describe_tasks( + cluster=self._cluster, + tasks=batch, + ) + out.extend(response.get("tasks") or []) + return out + + def _task_to_metadata(self, task: dict[str, Any]) -> Optional[WorkerMetadata]: + """Convert an ECS task description to a Wool ``WorkerMetadata``. + + Returns None when the task is not RUNNING + HEALTHY or we + cannot extract a usable IP address. Worker task definitions + MUST declare a ``healthCheck``; tasks without one surface as + ``healthStatus: UNKNOWN`` and are filtered out, since their + gRPC readiness is unknowable. Tasks whose ``healthStatus`` is + absent from the describe-tasks response are also filtered — + same reason. + """ + if task.get("lastStatus") != "RUNNING": + return None + if task.get("healthStatus") != "HEALTHY": + return None + + ip = _extract_eni_ip(task) + if ip is None: + return None + + # ECS task ARNs end in ``/`` where ```` is a + # 32-char hex string. We use it as a stable UUID for the + # worker so successive polls produce identical metadata + # (otherwise diff would emit add+drop on every cycle). A + # missing ARN means we cannot identify the task — drop it + # rather than collide every ARN-less task on the same UUID. + task_arn = task.get("taskArn") or "" + if not task_arn: + return None + task_id = task_arn.rsplit("/", 1)[-1] + try: + uid = uuid.UUID(task_id) + except (ValueError, AttributeError): + uid = uuid.uuid5(uuid.NAMESPACE_URL, task_arn) + + return WorkerMetadata( + uid=uid, + address=f"{ip}:{self._worker_port}", + pid=0, + version=self._version, + ) + + async def _register_subscriber(self, sub: "_EcsSubscriber") -> None: + """Append a subscriber and prime its queue with the current snapshot. + + Both steps happen under ``self._lock``, paired with + ``poll_once`` mutating ``_known`` and dispatching under the + same lock: a concurrent poll either runs entirely before the + registration (replay sees the OLD ``_known``, the poll's + events go to the new subscriber via dispatch) or entirely + after (replay sees the NEW ``_known``, no dispatch overlaps). + No interleave produces a missed event or a duplicate ``worker- + added`` for the same UID. + """ + async with self._lock: + self._subscribers.append(sub) + for metadata in self._known.values(): + await sub._publish(DiscoveryEvent("worker-added", metadata=metadata)) + + async def _unregister_subscriber(self, sub: "_EcsSubscriber") -> None: + """Remove ``sub`` from the subscriber list under the lock. + + Silently ignores already-removed subscribers so the iteration + ``finally`` is idempotent. + """ + async with self._lock: + try: + self._subscribers.remove(sub) + except ValueError: + pass + + +class _EcsSubscriber: + """Discovery subscriber backed by an ``asyncio.Queue``. + + The discovery instance pushes events into the queue; consumers + iterate via ``async for``. Subscribers are single-use — a second + iteration after the first has finished raises ``RuntimeError``; + create a fresh subscriber via ``EcsDiscovery.subscribe()`` instead. + """ + + def __init__( + self, + owner: EcsDiscovery, + filter: Optional[PredicateFunction], + ) -> None: + self._owner = owner + self._filter = filter + # The queue carries discovery events plus a single ``None`` + # sentinel pushed by ``EcsDiscovery.__aexit__`` to wake a + # consumer that would otherwise block forever on ``get()`` + # after the discovery has shut down. + self._queue: asyncio.Queue[Optional[DiscoveryEvent]] = asyncio.Queue() + self._exhausted = False + + def __aiter__(self) -> AsyncIterator[DiscoveryEvent]: + return self._iter() + + async def _publish(self, event: DiscoveryEvent) -> None: + """Deliver ``event`` to the subscriber's queue, or drop it. + + Called by ``EcsDiscovery.poll_once`` while holding the + discovery lock. The filter (if any) runs inline; events that + don't pass it are dropped silently — the subscriber asked not + to see this worker. Otherwise the event is enqueued on the + subscriber's unbounded ``asyncio.Queue`` and surfaces via + ``_iter`` on the next consumer ``__anext__``. + """ + if self._filter is not None and not self._filter(event.metadata): + return + await self._queue.put(event) + + async def _iter(self) -> AsyncIterator[DiscoveryEvent]: + """Yield discovery events for the lifetime of this subscriber. + + Single-use: a second call after the iterator has been consumed + raises ``RuntimeError``. ``EcsDiscovery.subscribe()`` returns + a fresh subscriber per call, so multiple independent consumers + are supported by creating multiple subscribers. + + Registers with the owning discovery on the first iteration so + the snapshot of currently-known workers is replayed before + any new poll events arrive. The ``finally`` always + unregisters — ``_unregister_subscriber`` silently ignores + already-removed entries, so a partial-registration failure + does not leak a stale subscriber. + + The ``_exhausted`` flag is flipped synchronously *before* the + first ``await`` so two ``__aiter__()`` calls on the same + subscriber (e.g. a framework that calls ``aiter()`` for + introspection before iterating) cannot both register and + share the underlying queue. CPython's single-threaded asyncio + guarantees the first generator's body runs up to its first + await before any other generator starts; the second + generator's check fires on its first ``__anext__``. + """ + if self._exhausted: + raise RuntimeError( + "EcsDiscovery subscriber already iterated; " + "call EcsDiscovery.subscribe() for a fresh one" + ) + self._exhausted = True + try: + await self._owner._register_subscriber(self) + while True: + event = await self._queue.get() + if event is None: + # Sentinel from ``EcsDiscovery.__aexit__``; exit + # cleanly so the consumer's ``async for`` ends. + break + yield event + finally: + await self._owner._unregister_subscriber(self) + + +class _RaisingPublisher: + """Guard-rail publisher returned by ``EcsDiscovery.publisher``. + + ECS owns worker lifecycle, so nothing in cfdb has cause to call + ``publish``; raising here surfaces the misuse loudly rather than + silently no-op'ing. + """ + + async def publish( + self, type: DiscoveryEventType, metadata: WorkerMetadata + ) -> None: + raise RuntimeError( + "EcsDiscovery is read-only — workers register implicitly via ECS" + ) + + +def _extract_eni_ip(task: dict[str, Any]) -> Optional[str]: + """Return the awsvpc private IPv4 from an ECS task description. + + ECS attaches one ENI per Fargate awsvpc task; its ``details`` list + carries a ``privateIPv4Address`` entry. Older API responses use + ``networkInterfaces`` on each container instead — we honor either. + When both forms are populated the ``attachments`` value wins; for + Fargate awsvpc both report the same IP, so the precedence is + moot in practice. + """ + for attachment in task.get("attachments") or []: + if attachment.get("type") not in (None, "ElasticNetworkInterface"): + continue + for detail in attachment.get("details") or []: + if detail.get("name") == "privateIPv4Address" and detail.get("value"): + return detail["value"] + for container in task.get("containers") or []: + for nic in container.get("networkInterfaces") or []: + ipv4 = nic.get("privateIpv4Address") + if ipv4: + return ipv4 + return None + + +def _diff( + cached: dict[str, WorkerMetadata], + discovered: dict[str, WorkerMetadata], +) -> Iterable[DiscoveryEvent]: + """Yield Wool events describing the cached→discovered transition.""" + for uid, metadata in discovered.items(): + if uid not in cached: + yield DiscoveryEvent("worker-added", metadata=metadata) + elif cached[uid].address != metadata.address: + yield DiscoveryEvent("worker-updated", metadata=metadata) + for uid, metadata in cached.items(): + if uid not in discovered: + yield DiscoveryEvent("worker-dropped", metadata=metadata) diff --git a/src/cfdb/workflows/provisioner.py b/src/cfdb/workflows/provisioner.py new file mode 100644 index 0000000..a71ec58 --- /dev/null +++ b/src/cfdb/workflows/provisioner.py @@ -0,0 +1,387 @@ +"""ECS on-demand worker provisioning via the ``RunTask`` API. + +``EcsProvisioner`` wraps a single boto3 ``RunTask`` call with the cluster, +task-definition family, and awsvpc networking configuration the workflow +subsystem needs. The same code targets LocalStack in development and real +AWS in production — only ``AWS_ENDPOINT_URL`` differs at the boto3 client. + +Concurrent ``request`` calls sharing the same ``dedup_key`` (typically +the workflow key) attach to a single in-flight ``RunTask`` task and +observe the same outcome. ``asyncio.shield`` insulates each caller's +cancellation from the others, so a request abandoned by one caller does +not poison the result for the rest. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Iterable +from typing import Any, Literal, Optional, get_args + +import boto3 +from botocore.exceptions import BotoCoreError, ClientError + +logger = logging.getLogger(__name__) + + +#: Acceptable values for ``RunTask`` ``assignPublicIp`` — ECS rejects +#: anything else. The runtime-validation set is derived from the +#: ``Literal`` so the two definitions can't drift. +AssignPublicIp = Literal["ENABLED", "DISABLED"] +_ASSIGN_PUBLIC_IP_VALUES = frozenset(get_args(AssignPublicIp)) + + +class RetryableProvisionerError(RuntimeError): + """Provisioner failure the caller should resubmit later. + + Covers both ECS-side capacity / ENI exhaustion and transport-level + transients (connection timeouts, endpoint unavailability, + throttling). The executor's response is identical for both — mark + the workflow ``FAILED`` with a retryable error string — so they + share one exception type. + """ + + +class EcsProvisioner: + """Launch ephemeral worker tasks on ECS Fargate. + + Args: + cluster: ECS cluster name or ARN. + task_definition: Task definition family (or family:revision) + for the worker container. + subnets: Awsvpc subnet IDs to place worker ENIs into. + security_groups: Awsvpc security group IDs to attach. + assign_public_ip: ``"ENABLED"`` or ``"DISABLED"`` — whether the + ENI gets a public IP. Production should usually leave this + disabled and rely on VPC endpoints; LocalStack accepts + either value. + client: Optional pre-built boto3 ``ecs`` client. When omitted, + one is constructed via :func:`build_ecs_client` with the + ``endpoint_url`` / ``region_name`` kwargs threaded through. + endpoint_url: Boto3 ``endpoint_url``. Passed to + :func:`build_ecs_client` when ``client`` is omitted. The + lifespan plumbs :data:`cfdb.api.AWS_ENDPOINT_URL` here. + region_name: Boto3 ``region_name``. Plumbed analogously. + max_in_flight: Soft cap on concurrent ``RunTask`` calls. ECS's + ``RunTask`` API is rate-limited to ~20 req/s per account; + this guard keeps us well under it. + """ + + def __init__( + self, + *, + cluster: str, + task_definition: str, + subnets: Iterable[str], + security_groups: Iterable[str] = (), + assign_public_ip: AssignPublicIp = "DISABLED", + client: Optional[Any] = None, + endpoint_url: Optional[str] = None, + region_name: Optional[str] = None, + max_in_flight: int = 16, + ) -> None: + if not cluster: + raise ValueError("EcsProvisioner requires a cluster name") + if not task_definition: + raise ValueError("EcsProvisioner requires a task_definition") + subnet_list = list(subnets) + if not subnet_list: + raise ValueError("EcsProvisioner requires at least one subnet") + if assign_public_ip not in _ASSIGN_PUBLIC_IP_VALUES: + raise ValueError( + f"assign_public_ip must be one of {sorted(_ASSIGN_PUBLIC_IP_VALUES)}; " + f"got {assign_public_ip!r}" + ) + + self._cluster = cluster + self._task_definition = task_definition + self._subnets = subnet_list + self._security_groups = list(security_groups) + self._assign_public_ip = assign_public_ip + self._client = ( + client + if client is not None + else build_ecs_client(endpoint_url=endpoint_url, region_name=region_name) + ) + self._semaphore = asyncio.Semaphore(max_in_flight) + # Concurrent ``request`` calls sharing a key attach to the + # same in-flight Task. ``_run_task_owned``'s ``finally`` block + # always pops the entry on completion (success or failure) so + # a fresh request after the previous one finishes spawns a new + # launch. + self._in_flight: dict[str, asyncio.Task[list[str]]] = {} + self._in_flight_lock = asyncio.Lock() + + async def request(self, *, dedup_key: str) -> list[str]: + """Launch a worker task, returning its ARN(s). + + Concurrent callers sharing ``dedup_key`` share one ``RunTask``: + only the first launches; the rest await the same result. + + Args: + dedup_key: Typically the workflow mutex key. Two callers + holding the same workflow-level mutex should not + independently launch two workers. + + Returns: + List of task ARNs corresponding to the launched worker(s). + + Raises: + RetryableProvisionerError: Capacity / ENI exhaustion, + throttling, or connection-level transients. Callers + should mark the workflow as failed with a retryable + status so the client can resubmit. + """ + if not dedup_key: + raise ValueError("EcsProvisioner.request requires a non-empty dedup_key") + + async with self._in_flight_lock: + existing = self._in_flight.get(dedup_key) + # Self-heal: if the cached task already finished and the + # finally-block release was skipped (CancelledError during + # event-loop teardown could re-raise out of the awaited + # shielded pop before the inner coroutine ran), treat the + # stale entry as absent and launch a fresh task. This + # collapses every "the pop never ran" cancellation race + # into one self-healing read at registration time. + if existing is not None and existing.done(): + existing = None + self._in_flight.pop(dedup_key, None) + if existing is None: + existing = asyncio.create_task( + self._run_task_owned(dedup_key), + name=f"ecs-runtask-{dedup_key}", + ) + # Retrieve the eventual exception/result so a Task + # abandoned by every caller (universal cancellation) + # doesn't trigger asyncio's "Task exception was never + # retrieved" warning when it's garbage-collected. + existing.add_done_callback(_consume_task_outcome) + self._in_flight[dedup_key] = existing + + # Shield protects concurrent callers sharing this dedup_key + # from each other's cancellation: if one caller is cancelled + # mid-await, only that caller sees CancelledError; the + # underlying task continues and other callers still observe + # the real result. + return await asyncio.shield(existing) + + async def _run_task_owned(self, dedup_key: str) -> list[str]: + """Body of the dedup-protected RunTask call. + + Split out from ``request`` so the dedup-registration lock is + not held across the boto3 thread-pool round-trip. The + ``finally`` always pops the in-flight slot so the next + ``request`` for the same key can launch a fresh worker. The + pop is shielded so a CancelledError delivered to the task + itself (event-loop teardown, explicit cancel) cannot skip the + cleanup and leak a stale dedup entry pointing at a cancelled + task. + """ + + async def _release_dedup_slot() -> None: + async with self._in_flight_lock: + self._in_flight.pop(dedup_key, None) + + try: + async with self._semaphore: + return await self._run_task() + finally: + await asyncio.shield(_release_dedup_slot()) + + async def _run_task(self) -> list[str]: + """Single ``RunTask`` invocation translated to a list of task ARNs. + + ``count`` is hardcoded to 1 because the dedup contract is one + worker per workflow key: two concurrent ``request(dedup_key=K)`` + calls share the same task, and once that task finishes a fresh + request for the same key launches a fresh worker. The + failure-vs-success branching below relies on this — any entry + in ``response["failures"]`` means the launch did not happen, so + the call raises rather than returning a mix. + """ + kwargs: dict[str, Any] = { + "cluster": self._cluster, + "taskDefinition": self._task_definition, + "launchType": "FARGATE", + "count": 1, + "networkConfiguration": { + "awsvpcConfiguration": { + "subnets": list(self._subnets), + "securityGroups": list(self._security_groups), + "assignPublicIp": self._assign_public_ip, + } + }, + } + + # ``ClientError`` covers HTTP-level ECS errors with structured + # response codes; ``BotoCoreError`` covers transport / cred + # failures with no ``.response``. Both should be classified + # together so transient connection issues surface as retryable. + try: + response = await asyncio.to_thread(self._client.run_task, **kwargs) + except (ClientError, BotoCoreError) as exc: + if _is_retryable_error(exc): + raise RetryableProvisionerError( + f"{type(exc).__name__}: {exc}" + ) from exc + raise + + failures = response.get("failures") or [] + tasks = response.get("tasks") or [] + arns = [t["taskArn"] for t in tasks if t.get("taskArn")] + + if failures and not arns: + # No ARN to fall back on — the launch did not happen. + reasons = ", ".join(f.get("reason", "?") for f in failures) + if any(_is_retryable_failure(f) for f in failures): + raise RetryableProvisionerError( + f"ECS RunTask failures: {reasons}" + ) + raise RuntimeError(f"ECS RunTask failures: {reasons}") + + if failures and arns: + # Partial-success edge: ECS occasionally surfaces a + # secondary placement warning alongside a successfully + # launched task. Log and keep the ARN rather than + # discarding a worker that's already running (and + # billing) to chase the warning. + reasons = ", ".join(f.get("reason", "?") for f in failures) + logger.warning( + "ECS RunTask succeeded with %d task(s) but reported failures: %s", + len(arns), + reasons, + ) + + if not arns: + # ECS returned neither a launched task nor a failure entry. + # Treat as a retryable transient — the caller has nothing + # to dispatch to and the executor's failed-with-retryable + # path is the right response. + raise RetryableProvisionerError( + "ECS RunTask returned no tasks and no failures" + ) + logger.info( + "ECS RunTask launched %d task(s) on cluster=%s family=%s", + len(arns), + self._cluster, + self._task_definition, + ) + return arns + + async def aclose(self) -> None: + """Cancel every in-flight ``RunTask`` and clear the dedup map. + + Called from the API lifespan's ``finally`` so a shutdown while + a ``RunTask`` round-trip is mid-flight doesn't leave the task + unrequested-but-billed. ``gather(return_exceptions=True)`` + absorbs the ``CancelledError`` and any ``RetryableProvisionerError`` + already in flight; the done-callback handles "Task exception + never retrieved" warnings for the cases we don't await here. + """ + async with self._in_flight_lock: + tasks = list(self._in_flight.values()) + self._in_flight.clear() + for task in tasks: + task.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + +def build_ecs_client( + *, endpoint_url: Optional[str] = None, region_name: Optional[str] = None +) -> Any: + """Construct a boto3 ``ecs`` client with explicit endpoint/region. + + The caller (typically :class:`EcsProvisioner`, :class:`EcsDiscovery`, + or the API lifespan) is the single source of truth for + ``endpoint_url`` and ``region_name`` — we no longer reach into + :mod:`cfdb.api` for fallback values. Pass + :data:`cfdb.api.AWS_ENDPOINT_URL` / :data:`cfdb.api.AWS_REGION` + from the lifespan; leave both ``None`` to let boto3's default + session resolver chain pick them up from the environment. + """ + return boto3.client( + "ecs", + endpoint_url=endpoint_url, + region_name=region_name, + ) + + +# Codes ECS returns as ``ClientError.response["Error"]["Code"]`` for a +# ``RunTask`` failure that a retry can plausibly fix. Capacity-style +# placement reasons (``RESOURCE:ENI`` / ``RESOURCE:CPU`` / +# ``RESOURCE:MEMORY`` / ``AWS.ECS.PlacementError``) are NOT in this +# set because ECS surfaces them via ``failures[].reason``, not as a +# top-level error code; ``_RETRYABLE_REASON_TOKENS`` catches them on +# the failure-payload path. +_RETRYABLE_ERROR_CODES = frozenset( + { + "Capacity", + "CapacityProviderException", + "ClusterCapacityProviderException", + "ThrottlingException", + "Throttling", + "RequestLimitExceeded", + "ServerException", + "ServiceUnavailableException", + "ServiceUnavailable", + } +) +# Substring match (``token in reason``). ``THROTTL`` is deliberately +# truncated so it covers any throttling-derived reason — ``THROTTLED``, +# ``THROTTLING``, ``THROTTLINGEXCEPTION`` — without enumerating every +# variant AWS might emit. +_RETRYABLE_REASON_TOKENS = ("CAPACITY", "RESOURCE:", "THROTTL") + + +def _is_retryable_error(exc: BaseException) -> bool: + """Return True when an ECS exception is a retryable transient. + + Handles both ``ClientError`` (with a structured response dict) and + ``BotoCoreError`` subclasses (transport-level failures without a + response) — connection timeouts, endpoint unavailability, and + other transport transients are exactly the cases the caller should + resubmit. + """ + response = getattr(exc, "response", None) + if isinstance(response, dict): + code = (response.get("Error") or {}).get("Code") + if code in _RETRYABLE_ERROR_CODES: + return True + # BotoCoreError subclasses (EndpointConnectionError, + # ConnectTimeoutError, ReadTimeoutError, HTTPClientError, etc.) + # carry no .response; classify the whole family as retryable. + return isinstance(exc, BotoCoreError) + + +def _is_retryable_failure(failure: dict[str, Any]) -> bool: + """Return True when a ``RunTask`` ``failures`` entry is retryable.""" + reason = (failure.get("reason") or "").upper() + return any(token in reason for token in _RETRYABLE_REASON_TOKENS) + + +def _consume_task_outcome(task: asyncio.Task[Any]) -> None: + """Retrieve a Task's outcome so asyncio doesn't log "never retrieved". + + Attached as a done-callback on the in-flight provisioner Task so + that universal-cancellation (every caller cancels before the Task + completes) doesn't surface as a noisy unretrieved-exception + warning at GC time. ``RetryableProvisionerError`` is a routine + outcome under capacity / throttling pressure, so abandoned ones + are logged at DEBUG; unexpected exception types stay at WARNING. + """ + if task.cancelled(): + return + exc = task.exception() + if exc is None: + return + if isinstance(exc, RetryableProvisionerError): + logger.debug( + "ECS RunTask completed with abandoned retryable error: %r", exc + ) + else: + logger.warning( + "ECS RunTask completed with abandoned exception: %r", exc + ) diff --git a/tests/test_workflows/test_discovery.py b/tests/test_workflows/test_discovery.py new file mode 100644 index 0000000..471dbe9 --- /dev/null +++ b/tests/test_workflows/test_discovery.py @@ -0,0 +1,343 @@ +"""Tests for EcsDiscovery.""" + +from __future__ import annotations + +import asyncio +import uuid +from typing import Any + +import pytest + +from cfdb.workflows.discovery import EcsDiscovery + + +def _task_arn(task_id: str | None = None) -> str: + """Build a stable ECS task ARN for tests.""" + return f"arn:aws:ecs:us-east-1:123:task/cluster/{task_id or uuid.uuid4().hex}" + + +def _running_task( + task_id: str, + *, + ip: str = "10.0.0.5", + health: str = "HEALTHY", + status: str = "RUNNING", +) -> dict[str, Any]: + """Construct a fake ECS DescribeTasks entry.""" + return { + "taskArn": _task_arn(task_id), + "lastStatus": status, + "healthStatus": health, + "attachments": [ + { + "type": "ElasticNetworkInterface", + "details": [ + {"name": "subnetId", "value": "subnet-1"}, + {"name": "privateIPv4Address", "value": ip}, + ], + } + ], + } + + +class _FakeEcsClient: + """Fake ECS client whose responses can be re-set per call.""" + + def __init__(self) -> None: + self.task_arns: list[str] = [] + self.tasks: list[dict[str, Any]] = [] + + def list_tasks(self, **_kwargs): + return {"taskArns": list(self.task_arns)} + + def describe_tasks(self, *, cluster: str, tasks: list[str]): + wanted = set(tasks) + return { + "tasks": [t for t in self.tasks if t["taskArn"] in wanted], + } + + +class TestEcsDiscovery: + def test___init___without_cluster(self): + """Test that EcsDiscovery rejects an empty cluster argument. + + Given: + An empty cluster name. + When: + EcsDiscovery is constructed. + Then: + It should raise ValueError so misconfigurations fail fast. + """ + # Act & assert + with pytest.raises(ValueError, match="cluster"): + EcsDiscovery( + cluster="", + task_definition_family="worker", + client=_FakeEcsClient(), + ) + + def test___init___without_task_definition_family(self): + """Test that EcsDiscovery rejects an empty task_definition_family. + + Given: + A cluster but no task_definition_family. + When: + EcsDiscovery is constructed. + Then: + It should raise ValueError to surface the misconfiguration. + """ + # Act & assert + with pytest.raises(ValueError, match="task_definition_family"): + EcsDiscovery( + cluster="c", + task_definition_family="", + client=_FakeEcsClient(), + ) + + @pytest.mark.asyncio + async def test_poll_once_with_initial_healthy_workers(self): + """Test that the first poll emits worker-added events. + + Given: + A fake ECS client returning one RUNNING + HEALTHY task. + When: + poll_once is awaited. + Then: + It should emit exactly one ``worker-added`` event whose metadata + carries the task's private IP and the configured worker port. + """ + # Arrange + client = _FakeEcsClient() + task_id = uuid.uuid4().hex + client.task_arns = [_task_arn(task_id)] + client.tasks = [_running_task(task_id)] + # The fake's task ARN must align with the seed. + client.tasks[0]["taskArn"] = client.task_arns[0] + discovery = EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + worker_port=4242, + ) + + # Act + events, resolved = await discovery.poll_once() + + # Assert + assert len(events) == 1 + assert events[0].type == "worker-added" + assert events[0].metadata.address == "10.0.0.5:4242" + assert len(resolved) == 1 + + @pytest.mark.asyncio + async def test_poll_once_with_unhealthy_task_filtered(self): + """Test that UNHEALTHY tasks never surface as workers. + + Given: + A fake ECS client whose task is RUNNING but UNHEALTHY. + When: + poll_once is awaited. + Then: + No events should be emitted and resolved should be empty. + """ + # Arrange + client = _FakeEcsClient() + task_id = uuid.uuid4().hex + client.task_arns = [_task_arn(task_id)] + client.tasks = [_running_task(task_id, health="UNHEALTHY")] + client.tasks[0]["taskArn"] = client.task_arns[0] + discovery = EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + ) + + # Act + events, resolved = await discovery.poll_once() + + # Assert + assert events == [] + assert resolved == {} + + @pytest.mark.asyncio + async def test_poll_once_with_dropped_task_after_initial_seen(self): + """Test that a vanished task surfaces as worker-dropped. + + Given: + A fake ECS client that initially reports one task and then + reports an empty cluster on the next poll. + When: + poll_once is awaited twice. + Then: + The second poll should emit one worker-dropped event for the + previously-known task. + """ + # Arrange + client = _FakeEcsClient() + task_id = uuid.uuid4().hex + client.task_arns = [_task_arn(task_id)] + client.tasks = [_running_task(task_id)] + client.tasks[0]["taskArn"] = client.task_arns[0] + discovery = EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + ) + await discovery.poll_once() + + # Act — second poll observes the cluster gone empty + client.task_arns = [] + client.tasks = [] + events, resolved = await discovery.poll_once() + + # Assert + assert len(events) == 1 + assert events[0].type == "worker-dropped" + assert resolved == {} + + @pytest.mark.asyncio + async def test_poll_once_with_idempotent_steady_state(self): + """Test that repeated polls of the same set emit no extra events. + + Given: + A fake ECS client that consistently returns the same task. + When: + poll_once is awaited twice. + Then: + The second poll should emit no events because the diff is empty. + """ + # Arrange + client = _FakeEcsClient() + task_id = uuid.uuid4().hex + client.task_arns = [_task_arn(task_id)] + client.tasks = [_running_task(task_id)] + client.tasks[0]["taskArn"] = client.task_arns[0] + discovery = EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + ) + await discovery.poll_once() + + # Act + events, _ = await discovery.poll_once() + + # Assert + assert events == [] + + @pytest.mark.asyncio + async def test_subscribe_with_filter_excluding_worker(self): + """Test that a subscriber's filter suppresses non-matching events. + + Given: + A discovery instance with one healthy worker and a subscriber + whose filter rejects every metadata. + When: + poll_once is awaited. + Then: + The subscriber's queue should remain empty. + """ + # Arrange + client = _FakeEcsClient() + task_id = uuid.uuid4().hex + client.task_arns = [_task_arn(task_id)] + client.tasks = [_running_task(task_id)] + client.tasks[0]["taskArn"] = client.task_arns[0] + discovery = EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + ) + sub = discovery.subscribe(filter=lambda _meta: False) + + async def _drain(): + async for event in sub: # pragma: no cover — should never iterate + return event + return None + + # Act — start subscriber, then poll + consumer = asyncio.create_task(_drain()) + await asyncio.sleep(0) # let consumer register + await discovery.poll_once() + await asyncio.sleep(0.05) + + # Assert + consumer.cancel() + try: + await consumer + except asyncio.CancelledError: + pass + + @pytest.mark.asyncio + async def test_subscribe_with_two_iter_calls_raises_on_second(self): + """Test that double-iteration of one subscriber is refused. + + Given: + A subscriber whose async iterator has already been opened. + When: + A second async-for loop tries to drive the same subscriber. + Then: + It should raise RuntimeError on the second __anext__ rather + than silently double-register and duplicate event delivery. + """ + # Arrange + client = _FakeEcsClient() + discovery = EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + ) + sub = discovery.subscribe() + iter_one = sub.__aiter__() + iter_two = sub.__aiter__() + # Start iter_one so it flips _exhausted before iter_two runs. + first_consumer = asyncio.create_task(iter_one.__anext__()) + await asyncio.sleep(0) + + # Act & assert + with pytest.raises(RuntimeError, match="already iterated"): + await iter_two.__anext__() + + # Cleanup + first_consumer.cancel() + try: + await first_consumer + except asyncio.CancelledError: + pass + + @pytest.mark.asyncio + async def test_aexit_wakes_parked_consumer(self): + """Test that __aexit__ unblocks consumers parked on queue.get(). + + Given: + A discovery context with a subscriber consuming events. + When: + The discovery context exits while the consumer is parked. + Then: + The consumer's async-for loop ends cleanly within a short + timeout rather than blocking on a queue that nothing will + publish to again. + """ + # Arrange + client = _FakeEcsClient() + events_seen: list[Any] = [] + + async def _consume(sub): + async for event in sub: + events_seen.append(event) + + # Act + async with EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + ) as discovery: + consumer = asyncio.create_task(_consume(discovery.subscribe())) + # Let the consumer register and park on queue.get(). + await asyncio.sleep(0.05) + # Exiting the context should push the sentinel; the consumer + # task ends shortly after. + await asyncio.wait_for(consumer, timeout=1.0) + + # Assert — the consumer ended cleanly without exception. + assert consumer.done() and consumer.exception() is None diff --git a/tests/test_workflows/test_provisioner.py b/tests/test_workflows/test_provisioner.py new file mode 100644 index 0000000..a44c37b --- /dev/null +++ b/tests/test_workflows/test_provisioner.py @@ -0,0 +1,398 @@ +"""Tests for EcsProvisioner.""" + +from __future__ import annotations + +import asyncio + +import pytest + +from cfdb.workflows.provisioner import RetryableProvisionerError, EcsProvisioner + + +class _FakeEcsClient: + """In-memory ECS client recording RunTask calls for assertions.""" + + def __init__(self, *, response: dict | None = None, raise_on_call: Exception | None = None) -> None: + self.calls: list[dict] = [] + self._response = response or { + "tasks": [{"taskArn": "arn:aws:ecs:::task/cluster/abc"}], + "failures": [], + } + self._raise = raise_on_call + + def run_task(self, **kwargs): + self.calls.append(kwargs) + if self._raise is not None: + raise self._raise + return self._response + + +class _SimpleGatedClient(_FakeEcsClient): + """run_task blocks on a threading.Event until released by the test.""" + + def __init__(self) -> None: + super().__init__() + import threading + + self._gate = threading.Event() + self._call_started = threading.Event() + + def run_task(self, **kwargs): + self.calls.append(kwargs) + self._call_started.set() + # Block the worker thread until the test releases the gate. + self._gate.wait() + return self._response + + def release(self) -> None: + self._gate.set() + + +def _client_error(code: str) -> Exception: + """Construct a botocore ClientError with a structured response dict.""" + from botocore.exceptions import ClientError + + return ClientError( + {"Error": {"Code": code, "Message": "simulated"}, "ResponseMetadata": {"HTTPStatusCode": 500}}, + "RunTask", + ) + + +class TestEcsProvisioner: + def test___init___without_cluster(self): + """Test that constructing EcsProvisioner without a cluster fails fast. + + Given: + No cluster name. + When: + EcsProvisioner is constructed. + Then: + It should raise ValueError so misconfiguration is caught at boot. + """ + # Act & assert + with pytest.raises(ValueError, match="cluster"): + EcsProvisioner( + cluster="", + task_definition="worker", + subnets=["subnet-1"], + client=_FakeEcsClient(), + ) + + def test___init___without_task_definition(self): + """Test that omitting task_definition raises ValueError. + + Given: + A cluster but no task definition. + When: + EcsProvisioner is constructed. + Then: + It should raise ValueError to surface the misconfiguration. + """ + # Act & assert + with pytest.raises(ValueError, match="task_definition"): + EcsProvisioner( + cluster="c", + task_definition="", + subnets=["subnet-1"], + client=_FakeEcsClient(), + ) + + def test___init___without_subnets(self): + """Test that constructing without subnets raises ValueError. + + Given: + A cluster and task definition but no subnets. + When: + EcsProvisioner is constructed. + Then: + It should raise ValueError because awsvpc requires at least one. + """ + # Act & assert + with pytest.raises(ValueError, match="subnet"): + EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=[], + client=_FakeEcsClient(), + ) + + @pytest.mark.asyncio + async def test_request_with_single_caller(self): + """Test that request returns the launched task ARNs. + + Given: + A provisioner backed by a fake client returning one task ARN. + When: + request is awaited once. + Then: + It should return the list of ARNs and have invoked RunTask once + with the configured cluster, family, and awsvpc network config. + """ + # Arrange + client = _FakeEcsClient() + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + security_groups=["sg-1"], + client=client, + ) + + # Act + arns = await provisioner.request(dedup_key="wf-1") + + # Assert + assert arns == ["arn:aws:ecs:::task/cluster/abc"] + assert len(client.calls) == 1 + call = client.calls[0] + assert call["cluster"] == "c" + assert call["taskDefinition"] == "worker" + assert call["launchType"] == "FARGATE" + awsvpc = call["networkConfiguration"]["awsvpcConfiguration"] + assert awsvpc["subnets"] == ["subnet-1"] + assert awsvpc["securityGroups"] == ["sg-1"] + + @pytest.mark.asyncio + async def test_request_with_concurrent_dedup_key_collisions(self): + """Test that concurrent calls with the same dedup_key share one RunTask. + + Given: + A provisioner whose underlying RunTask call is gated on a + threading.Event so the first caller stays in-flight. + When: + Two coroutines call request concurrently with the same dedup_key. + Then: + Only one RunTask invocation should be issued, and both callers + should observe identical ARNs. + """ + # Arrange + client = _SimpleGatedClient() + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + + # Act + first = asyncio.create_task(provisioner.request(dedup_key="wf-1")) + # Wait until the first call is mid-flight before starting the + # second so the dedup map is populated. + await asyncio.to_thread(client._call_started.wait) + second = asyncio.create_task(provisioner.request(dedup_key="wf-1")) + # Release the gated client so both callers can return. + client.release() + first_arns, second_arns = await asyncio.gather(first, second) + + # Assert + assert first_arns == second_arns == ["arn:aws:ecs:::task/cluster/abc"] + assert len(client.calls) == 1 + + @pytest.mark.asyncio + async def test_request_with_distinct_dedup_keys(self): + """Test that distinct dedup keys produce independent RunTask calls. + + Given: + A provisioner backed by a fake client. + When: + request is awaited twice with different dedup_keys. + Then: + Two RunTask invocations should be issued. + """ + # Arrange + client = _FakeEcsClient() + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + + # Act + await provisioner.request(dedup_key="wf-a") + await provisioner.request(dedup_key="wf-b") + + # Assert + assert len(client.calls) == 2 + + @pytest.mark.asyncio + async def test_request_with_capacity_client_error(self): + """Test that capacity ClientErrors map to RetryableProvisionerError. + + Given: + A provisioner whose underlying client raises a Capacity error. + When: + request is awaited. + Then: + It should raise RetryableProvisionerError so callers can surface it as a + retryable terminal failure. + """ + # Arrange + client = _FakeEcsClient(raise_on_call=_client_error("CapacityProviderException")) + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + + # Act & assert + with pytest.raises(RetryableProvisionerError): + await provisioner.request(dedup_key="wf-1") + + @pytest.mark.asyncio + async def test_request_with_capacity_failure_in_response(self): + """Test that capacity failures inside the RunTask response also raise. + + Given: + A provisioner whose RunTask response carries a failures entry + whose reason includes RESOURCE: tokens. + When: + request is awaited. + Then: + It should raise RetryableProvisionerError rather than treating the call + as a success with zero ARNs. + """ + # Arrange + client = _FakeEcsClient( + response={ + "tasks": [], + "failures": [{"reason": "RESOURCE:CPU"}], + } + ) + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + + # Act & assert + with pytest.raises(RetryableProvisionerError): + await provisioner.request(dedup_key="wf-1") + + @pytest.mark.asyncio + async def test_request_with_partial_failure_preserves_arn(self): + """Test that a launched ARN is preserved even when failures[] is non-empty. + + Given: + A provisioner whose RunTask response carries both a launched + taskArn and a non-retryable failures entry. + When: + request is awaited. + Then: + It should return the launched ARN and log the failure rather + than discarding the worker that is already running. + """ + # Arrange + client = _FakeEcsClient( + response={ + "tasks": [{"taskArn": "arn:aws:ecs:::task/cluster/abc"}], + "failures": [{"reason": "secondary placement warning"}], + } + ) + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + + # Act + arns = await provisioner.request(dedup_key="wf-1") + + # Assert + assert arns == ["arn:aws:ecs:::task/cluster/abc"] + + @pytest.mark.asyncio + async def test_request_with_done_cached_task_launches_fresh_run(self): + """Test that a stale completed in-flight entry is replaced on next request. + + Given: + A provisioner where the first request completed and the dedup + slot still holds the done task (simulating a release-slot race + where the post-completion pop was skipped). + When: + A second request with the same dedup_key is awaited. + Then: + A fresh RunTask is issued rather than the second caller + attaching to the already-finished task. + """ + # Arrange + client = _FakeEcsClient() + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + first = await provisioner.request(dedup_key="wf-shared") + # Simulate the post-completion pop having been skipped (e.g. by a + # CancelledError that re-raised out of the shielded release): put + # the done task back into the dedup map. + done_task = asyncio.create_task(_resolved(first)) + await done_task + provisioner._in_flight["wf-shared"] = done_task + + # Act + second = await provisioner.request(dedup_key="wf-shared") + + # Assert + assert second == first + assert len(client.calls) == 2 + + @pytest.mark.asyncio + async def test_aclose_cancels_in_flight_and_clears_map(self): + """Test that aclose cancels in-flight tasks and clears the dedup map. + + Given: + A provisioner with one in-flight RunTask whose underlying + boto3 call is mid-flight on the ``asyncio.to_thread`` worker. + When: + aclose is awaited. + Then: + The asyncio task is cancelled, the dedup map is empty, and + the original caller observes CancelledError rather than a hang. + """ + # Arrange — use a client that simulates a slow round-trip via a + # short ``time.sleep`` in the worker thread. ``asyncio.to_thread`` + # cannot cancel the running thread, but the wrapping task can be + # cancelled — which is what ``aclose`` does. The short sleep + # makes the thread exit promptly so pytest's executor join at + # session teardown does not hit a 300 s timeout. + import time + + class _SlowClient(_FakeEcsClient): + def run_task(self, **kwargs): + self.calls.append(kwargs) + time.sleep(0.5) + return self._response + + client = _SlowClient() + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + caller = asyncio.create_task(provisioner.request(dedup_key="wf-1")) + # Yield so the dedup-registration task is scheduled and the + # boto thread call has begun. + await asyncio.sleep(0.05) + + # Act + await provisioner.aclose() + + # Assert + assert provisioner._in_flight == {} + with pytest.raises(asyncio.CancelledError): + await caller + + +async def _resolved(value): + """Helper coroutine that immediately returns ``value``. + + Used by the dedup-self-heal test to construct a real done task that + holds the same return value as the first request. + """ + return value From 45e00da838dcf953196b2a76cf1608dee8ed2c42 Mon Sep 17 00:00:00 2001 From: Conrad Date: Fri, 1 May 2026 13:51:16 -0400 Subject: [PATCH 4/7] feat: Add worker container entrypoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit worker_main.py is the entrypoint baked into the worker container image. It starts a wool.LocalWorker so the API can dispatch routines to the task, exposes a tiny aiohttp /health endpoint that returns 503 during drain — so ECS marks the task unhealthy before stop_task kills the gRPC port — and installs SIGTERM/SIGINT handlers so a stop_task issued by the API or the Fargate scheduler shuts down cleanly. The idle-shutdown timeout is configurable so tasks self-terminate when their workflow completes. The entrypoint deliberately does not register itself with discovery; EcsDiscovery polls ECS directly and surfaces healthy tasks to the pool, so the worker only needs to be listening and HEALTHY. --- src/cfdb/workflows/worker_main.py | 292 +++++++++++++++++++++++ tests/test_workflows/test_worker_main.py | 133 +++++++++++ 2 files changed, 425 insertions(+) create mode 100644 src/cfdb/workflows/worker_main.py create mode 100644 tests/test_workflows/test_worker_main.py diff --git a/src/cfdb/workflows/worker_main.py b/src/cfdb/workflows/worker_main.py new file mode 100644 index 0000000..4fb58f2 --- /dev/null +++ b/src/cfdb/workflows/worker_main.py @@ -0,0 +1,292 @@ +"""ECS Fargate worker entrypoint. + +This module is the ``CMD`` for the worker container image. It boots a +``wool.LocalWorker`` on a known port, exposes a tiny HTTP health endpoint +the ECS health check probes, handles SIGTERM cleanly so ``ecs.stop_task`` +cycles drain in-flight work, and self-terminates after a configurable +maximum lifetime so workers don't accumulate when the dispatch rate falls. + +No discovery registration code lives here. ``EcsDiscovery`` polls ECS's +own task state to surface running workers; the worker only needs to +bind its gRPC port and respond ``200 OK`` to the health probe. +""" + +from __future__ import annotations + +import asyncio +import logging +import signal +from collections.abc import Callable +from typing import TYPE_CHECKING, Optional + +import click +import wool + +from cfdb.workflows.constants import DEFAULT_WORKER_PORT + +if TYPE_CHECKING: + from aiohttp import web + +logger = logging.getLogger(__name__) + +__all__ = ["main", "serve"] + +#: Default health probe HTTP port — distinct from the gRPC port so +#: ``healthCheck`` can ``curl`` it without speaking gRPC. +DEFAULT_HEALTH_PORT = 8080 + +#: Default maximum wall-clock lifetime of a worker process. Wool exposes +#: no per-job activity hook today, so the worker can't tell idle from +#: busy; this is a hard ceiling — ECS replaces the task after this long. +#: Sized one hour above :data:`cfdb.workflows.WORKFLOW_DURATION_CAP_S` +#: (default 4 h) so a worker started shortly before a long sort can +#: still outlive the job. Note: max-lifetime expiry exits without the +#: drain-grace window the SIGTERM path provides — operators that want +#: a cleaner handoff should rely on ECS rolling tasks via service +#: updates rather than waiting for max-lifetime to fire. +DEFAULT_MAX_LIFETIME_SECONDS = 5 * 60 * 60 + +#: How long to keep returning ``503`` on ``/health`` after SIGTERM, +#: giving ECS a chance to observe ``unhealthy`` and drain at the load +#: balancer before we tear the gRPC port down. ECS health checks +#: default to ~30 s interval × 3 unhealthy retries = ~90 s worst case +#: to mark the task unhealthy; we hold for 120 s by default to leave +#: real margin even when the task definition uses ECS-default health +#: check cadence. Operators tightening ``healthCheck.interval`` to +#: ~10-15 s can lower this further via ``WORKER_DRAIN_GRACE_SECONDS``. +DEFAULT_DRAIN_GRACE_SECONDS = 120.0 + +#: Cadence of the main loop's wakeup when waiting on ``stop_event``. +#: One second balances precision (max-lifetime checks fire within ~1 s +#: of the target) against CPU churn over a multi-hour worker lifetime +#: (~3600 wakeups/h vs. ~3.6 M at 1 ms). Sub-second granularity is +#: unnecessary because the upstream coarse-grained signals +#: (SIGTERM-on-stop_task, ~hours-scale max-lifetime) don't need it. +_STOP_POLL_INTERVAL_SECONDS = 1.0 + + +async def serve( + *, + worker_port: int = DEFAULT_WORKER_PORT, + health_port: int = DEFAULT_HEALTH_PORT, + max_lifetime_seconds: float = DEFAULT_MAX_LIFETIME_SECONDS, + drain_grace_seconds: float = DEFAULT_DRAIN_GRACE_SECONDS, +) -> int: + """Run the worker until SIGTERM or maximum lifetime elapses. + + Returns ``0`` on clean shutdown (SIGTERM, SIGINT, or max-lifetime). + Bind failures and other early-startup errors raise out — ``main`` + propagates them and the process exits with a Python traceback, + which surfaces the cause in container logs more clearly than a + silent non-zero status. + + A second SIGTERM during drain short-circuits the grace window + (operator impatience or ECS escalating before SIGKILL); the worker + stops immediately. + """ + stop_event = asyncio.Event() + force_stop_event = asyncio.Event() + loop = asyncio.get_running_loop() + started_at = loop.time() + + def _signal_handler() -> None: + if stop_event.is_set(): + logger.info("Second termination signal — exiting drain immediately") + force_stop_event.set() + else: + logger.info("Received termination signal — draining worker") + stop_event.set() + + def _signal_handler_threaded(*_: object) -> None: + # Fallback for platforms (notably Windows in CI) where + # ``loop.add_signal_handler`` raises ``NotImplementedError``. + # Defined once so the closure isn't rebuilt per signal. + loop.call_soon_threadsafe(_signal_handler) + + for sig in (signal.SIGTERM, signal.SIGINT): + try: + loop.add_signal_handler(sig, _signal_handler) + except NotImplementedError: + signal.signal(sig, _signal_handler_threaded) + + health_runner = await _start_health_server( + health_port, lambda: stop_event.is_set() + ) + + worker = wool.LocalWorker(host="0.0.0.0", port=worker_port) + await worker.start() + try: + logger.info( + "Wool worker listening on port %d (health on %d)", + worker_port, + health_port, + ) + while True: + # Check the self-termination path first. Max-lifetime is a + # local hard ceiling — nothing upstream (ECS, LB) is + # waiting on us to drain a multi-second grace window, but + # flipping ``/health`` to 503 here costs nothing and gives + # ``EcsDiscovery`` (and any layered LB) a chance to mark + # the task unhealthy before ``worker.stop()`` actually + # closes the gRPC port in ``finally``. Without this flip + # the health endpoint reports 200 right up to the moment + # the aiohttp runner is torn down. + if ( + max_lifetime_seconds > 0 + and (loop.time() - started_at) >= max_lifetime_seconds + ): + logger.info( + "Max lifetime (%.0fs) reached — exiting", + max_lifetime_seconds, + ) + stop_event.set() + break + if stop_event.is_set(): + # Signal-initiated shutdown. Hold the gRPC port open + # while ECS observes 503 on /health and stops routing + # new dispatches. A second signal flips + # force_stop_event and we exit the wait early. + if drain_grace_seconds > 0: + logger.info( + "Draining for up to %.0fs before exiting", + drain_grace_seconds, + ) + try: + await asyncio.wait_for( + force_stop_event.wait(), + timeout=drain_grace_seconds, + ) + except asyncio.TimeoutError: + pass + break + try: + await asyncio.wait_for( + stop_event.wait(), timeout=_STOP_POLL_INTERVAL_SECONDS + ) + except asyncio.TimeoutError: + continue + return 0 + finally: + try: + await worker.stop() + except Exception: + logger.exception("worker.stop() failed during shutdown") + await _shutdown_health_server(health_runner) + + +async def _start_health_server( + port: int, draining: Callable[[], bool] +) -> "web.AppRunner": + """Start a tiny HTTP server returning ``200 OK`` on ``/health``. + + The container's ECS ``healthCheck`` runs ``curl`` against this + endpoint, so the response shape doesn't matter — only the status + code does. While ``draining`` returns True (i.e. we're shutting + down), the endpoint returns ``503`` so ECS can mark the task + unhealthy and the load balancer / discovery can drain it before + ``ecs.stop_task`` actually kills the gRPC port. + + Cleans up the partial runner on bind failure so the caller's + finally doesn't have a half-initialized object to deal with. + """ + from aiohttp import web + + async def _health(_: web.Request) -> web.Response: + if draining(): + return web.Response(status=503, text="draining") + return web.Response(status=200, text="ok") + + app = web.Application() + app.router.add_get("/health", _health) + runner = web.AppRunner(app) + await runner.setup() + try: + site = web.TCPSite(runner, host="0.0.0.0", port=port) + await site.start() + except Exception: + await runner.cleanup() + raise + return runner + + +async def _shutdown_health_server(runner: Optional["web.AppRunner"]) -> None: + """Tear down the aiohttp ``AppRunner`` started by ``_start_health_server``. + + Tolerates ``runner is None`` so the caller's ``finally`` can run + even when the health server failed to start in the first place. + Cleanup exceptions are logged but swallowed: the worker is already + on its way out, and surfacing a teardown error would mask the + upstream cause (whatever triggered the shutdown). + """ + if runner is None: + return + try: + await runner.cleanup() + except Exception: + logger.exception("health server cleanup failed") + + +@click.command() +@click.option( + "--worker-port", + type=click.IntRange(1, 65535), + envvar="CFDB_WORKER_GRPC_PORT", + default=DEFAULT_WORKER_PORT, + show_default=True, + help="gRPC port the wool worker binds.", +) +@click.option( + "--health-port", + type=click.IntRange(1, 65535), + envvar="CFDB_WORKER_HEALTH_PORT", + default=DEFAULT_HEALTH_PORT, + show_default=True, + help="HTTP port the ECS health-check endpoint binds.", +) +@click.option( + "--max-lifetime-seconds", + type=click.FloatRange(min=0), + envvar="CFDB_WORKER_MAX_LIFETIME_SECONDS", + default=DEFAULT_MAX_LIFETIME_SECONDS, + show_default=True, + help="Hard ceiling on worker uptime in seconds; 0 disables.", +) +@click.option( + "--drain-grace-seconds", + type=click.FloatRange(min=0), + envvar="CFDB_WORKER_DRAIN_GRACE_SECONDS", + default=DEFAULT_DRAIN_GRACE_SECONDS, + show_default=True, + help=( + "Seconds to keep returning 503 on /health after SIGTERM before " + "tearing down the gRPC port. A second SIGTERM short-circuits." + ), +) +def main( + worker_port: int, + health_port: int, + max_lifetime_seconds: float, + drain_grace_seconds: float, +) -> None: + """ECS Fargate worker entrypoint — invoked by the container CMD. + + Boots a wool gRPC worker, exposes /health for ECS to probe, and + self-terminates after the max-lifetime ceiling. SIGTERM begins a + drain grace window during which /health returns 503 so the load + balancer can drop the worker before the gRPC port closes. + """ + logging.basicConfig(level=logging.INFO) + raise SystemExit( + asyncio.run( + serve( + worker_port=worker_port, + health_port=health_port, + max_lifetime_seconds=max_lifetime_seconds, + drain_grace_seconds=drain_grace_seconds, + ) + ) + ) + + +if __name__ == "__main__": # pragma: no cover + main() diff --git a/tests/test_workflows/test_worker_main.py b/tests/test_workflows/test_worker_main.py new file mode 100644 index 0000000..e70c3e8 --- /dev/null +++ b/tests/test_workflows/test_worker_main.py @@ -0,0 +1,133 @@ +"""Tests for the ECS worker entrypoint argument parsing.""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest +from click.testing import CliRunner + +from cfdb.workflows import worker_main + + +def _invoke(args: list[str]) -> tuple[int, dict[str, object]]: + """Run ``worker_main.main`` with ``args``, capturing the ``serve`` kwargs. + + Returns ``(exit_code, captured_kwargs)``. The real ``serve`` is patched + out — the entrypoint always calls ``raise SystemExit(asyncio.run(serve(...)))``, + so swapping ``asyncio.run`` for ``lambda coro: coro.close() or 0`` plus + stubbing ``serve`` keeps the test from actually booting a worker. + """ + captured: dict[str, object] = {} + + async def _fake_serve(**kwargs: object) -> int: + captured.update(kwargs) + return 0 + + runner = CliRunner() + with patch.object(worker_main, "serve", _fake_serve): + result = runner.invoke(worker_main.main, args, standalone_mode=True) + return result.exit_code, captured + + +class TestMainCli: + def test_main_uses_documented_defaults_when_no_args_or_env(self, monkeypatch): + """Test that bare invocation surfaces the documented defaults. + + Given: + No CLI arguments and no overriding environment variables. + When: + ``worker_main.main`` is invoked. + Then: + ``serve`` should be called with the documented worker port, + health port, max lifetime, and drain-grace defaults so a + bare container ``CMD`` works. + """ + # Arrange — clear any env overrides so defaults apply + for var in ( + "CFDB_WORKER_GRPC_PORT", + "CFDB_WORKER_HEALTH_PORT", + "CFDB_WORKER_MAX_LIFETIME_SECONDS", + "CFDB_WORKER_DRAIN_GRACE_SECONDS", + ): + monkeypatch.delenv(var, raising=False) + + # Act + exit_code, captured = _invoke([]) + + # Assert + assert exit_code == 0 + assert captured["worker_port"] == worker_main.DEFAULT_WORKER_PORT + assert captured["health_port"] == worker_main.DEFAULT_HEALTH_PORT + assert captured["max_lifetime_seconds"] == worker_main.DEFAULT_MAX_LIFETIME_SECONDS + assert captured["drain_grace_seconds"] == worker_main.DEFAULT_DRAIN_GRACE_SECONDS + + def test_main_with_env_overrides(self, monkeypatch): + """Test that environment variables override the defaults. + + Given: + ``CFDB_WORKER_GRPC_PORT`` and friends set to non-default values. + When: + ``worker_main.main`` is invoked with no CLI flags. + Then: + ``serve`` should receive the env-driven values. + """ + # Arrange + monkeypatch.setenv("CFDB_WORKER_GRPC_PORT", "60001") + monkeypatch.setenv("CFDB_WORKER_HEALTH_PORT", "9001") + monkeypatch.setenv("CFDB_WORKER_MAX_LIFETIME_SECONDS", "1800") + monkeypatch.setenv("CFDB_WORKER_DRAIN_GRACE_SECONDS", "10") + + # Act + exit_code, captured = _invoke([]) + + # Assert + assert exit_code == 0 + assert captured["worker_port"] == 60001 + assert captured["health_port"] == 9001 + assert captured["max_lifetime_seconds"] == 1800.0 + assert captured["drain_grace_seconds"] == 10.0 + + def test_main_cli_flags_override_env_vars(self, monkeypatch): + """Test that CLI flags win over environment variables. + + Given: + ``CFDB_WORKER_GRPC_PORT`` set in the environment. + When: + ``worker_main.main`` is invoked with an explicit + ``--worker-port`` CLI flag. + Then: + ``serve`` should receive the CLI-supplied value. + """ + # Arrange + monkeypatch.setenv("CFDB_WORKER_GRPC_PORT", "60001") + + # Act + exit_code, captured = _invoke(["--worker-port", "55555"]) + + # Assert + assert exit_code == 0 + assert captured["worker_port"] == 55555 + + def test_main_rejects_out_of_range_worker_port(self, monkeypatch): + """Test that a port outside [1, 65535] is rejected at parse time. + + Given: + A ``--worker-port`` value above 65535. + When: + ``worker_main.main`` is invoked. + Then: + Click should exit non-zero before reaching ``serve``, so the + failure surfaces clearly rather than as an opaque bind error + later. + """ + # Arrange + for var in ("CFDB_WORKER_GRPC_PORT",): + monkeypatch.delenv(var, raising=False) + + # Act + exit_code, captured = _invoke(["--worker-port", "99999"]) + + # Assert + assert exit_code != 0 + assert "worker_port" not in captured From b784fe7f52279d3b8fbb695eb360d2ffd9151219 Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 20 May 2026 12:29:54 -0400 Subject: [PATCH 5/7] feat(workflows): Raise default duration cap to 4 h Multi-hour preprocessing runs (samtools sort + index on a multi-GB BAM, tabix on a large interval file) routinely exceed the previous 1200 s (20 min) default and trip the asyncio.timeout in _run_workflow. The new 14400 s (4 h) default sizes the cap for the real workload profile without requiring every operator to set CFDB_WORKFLOW_DURATION_CAP_S explicitly. The cap remains env-driven so fixture-bound dev setups can lower it to keep test runs snappy. The accompanying docstring in executor.py is updated to match the new rationale. --- README.md | 2 +- src/cfdb/workflows/__init__.py | 7 +++++-- src/cfdb/workflows/executor.py | 9 +++++---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 0bfa885..640b1f3 100644 --- a/README.md +++ b/README.md @@ -569,7 +569,7 @@ Required environment variables: Optional tunables (with defaults): -- `CFDB_WORKFLOW_DURATION_CAP_S` — per-workflow wall-clock cap (default `1200`). +- `CFDB_WORKFLOW_DURATION_CAP_S` — per-workflow wall-clock cap (default `14400`, i.e. 4 h — sized for multi-hour preprocessing runs; lower it for fixture-bound dev). - `CFDB_WORKFLOW_DISPATCH_WAIT_S` — how long `ensure_workflow` waits for a free worker before giving up (default `60`). - `CFDB_WORKFLOW_HEARTBEAT_INTERVAL_S` — cadence at which the wool routine emits heartbeat events during quiet stages so the API can refresh `JobRecord.updated_at` (default `300`). The stale-reclaim threshold below is sized as `2 × heartbeat + safety_margin`; lowering this knob without also lowering the threshold widens the false-reclaim window. - `CFDB_WORKFLOW_STALE_THRESHOLD_S` — `updated_at` age beyond which an active row is reclaimable (default `900`; sized as `2 × heartbeat_interval + safety_margin` so a single missed heartbeat does not falsely reclaim a healthy worker). diff --git a/src/cfdb/workflows/__init__.py b/src/cfdb/workflows/__init__.py index 0776c68..5088e82 100644 --- a/src/cfdb/workflows/__init__.py +++ b/src/cfdb/workflows/__init__.py @@ -27,7 +27,10 @@ - ``CFDB_WORKFLOW_DURATION_CAP_S`` — per-workflow wall-clock cap, enforced via ``asyncio.timeout`` on the API side while consuming the routine's - event stream. Default ``1200`` (20 min). + event stream. Default ``14400`` (4 h) — sized for multi-hour + preprocessing runs (e.g., a ``samtools sort`` on a multi-GB BAM + followed by ``samtools index``). Operators running on bounded fixtures + in dev should lower this via env. - ``CFDB_WORKFLOW_DISPATCH_WAIT_S`` — how long ``ensure_workflow`` waits for a wool worker to become available before giving up. Default ``60``. - ``CFDB_WORKFLOW_HEARTBEAT_INTERVAL_S`` — how often the routine emits a @@ -109,7 +112,7 @@ def _positive_int(name: str, value: str, *, minimum: int = 0) -> int: # look like a hard failure). WORKFLOW_DURATION_CAP_S: Final = _positive_int( "CFDB_WORKFLOW_DURATION_CAP_S", - os.getenv("CFDB_WORKFLOW_DURATION_CAP_S", "1200"), + os.getenv("CFDB_WORKFLOW_DURATION_CAP_S", "14400"), minimum=1, ) WORKFLOW_DISPATCH_WAIT_S: Final = _positive_int( diff --git a/src/cfdb/workflows/executor.py b/src/cfdb/workflows/executor.py index 974facb..75c4ace 100644 --- a/src/cfdb/workflows/executor.py +++ b/src/cfdb/workflows/executor.py @@ -72,10 +72,11 @@ #: change to the workflow orchestration itself). PIPELINE_VERSION = 1 -#: Default job-runtime cap. 20 minutes covers a sort+index on a multi-GB -#: BAM comfortably; exceptional files that need longer should be -#: addressed individually. Sourced from the env var so deployments can -#: override without a code change. +#: Default job-runtime cap. 4 hours covers multi-hour preprocessing runs +#: (e.g., a ``samtools sort`` on a multi-GB BAM followed by ``samtools +#: index``); exceptional files that need longer should be addressed +#: individually. Sourced from the env var so deployments can override +#: without a code change — fixture-bound dev setups should lower it. DEFAULT_WORKFLOW_DURATION_CAP_SECONDS = WORKFLOW_DURATION_CAP_S #: Total wall-clock budget waiting for a leased worker to surface during From 0bf6617855aa1ef47d24cbdf8506ef471695eb4b Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 20 May 2026 12:30:11 -0400 Subject: [PATCH 6/7] feat(workflows): Wire EcsProvisioner.request into WoolExecutor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ensure_workflow now requests a worker from an externally-injected EcsProvisioner on a fresh claim, dedup-keyed on the workflow mutex so two concurrent claims for the same source file share one RunTask and one worker. The request is awaited in _run_workflow between mark_running and the routine-stream open so a capacity / ENI / throttling failure (surfaced as RetryableProvisionerError) routes through the same FAILED terminal path as a stream-open failure, with a "provisioner:" prefix preserved on the persisted error so the operator can tell the two apart in /jobs/{id}. The provisioner ctor arg defaults to None — the PoC dev profile that relies on manually-started wool workers via LanDiscovery is unchanged. The EcsProvisioner type is imported under TYPE_CHECKING so executor.py imports stay boto3-free when the provisioner isn't in use. TestWoolExecutorWithProvisioner covers the three observable shapes: request issued on fresh claim with the workflow_key dedup_key, request suppressed on attach to an already-claimed workflow, and capacity failures landing as FAILED with the provisioner error preserved. --- src/cfdb/workflows/executor.py | 81 ++++++++++- tests/test_workflows/test_executor.py | 192 ++++++++++++++++++++++++++ 2 files changed, 270 insertions(+), 3 deletions(-) diff --git a/src/cfdb/workflows/executor.py b/src/cfdb/workflows/executor.py index 75c4ace..e4bf907 100644 --- a/src/cfdb/workflows/executor.py +++ b/src/cfdb/workflows/executor.py @@ -42,7 +42,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator from pathlib import Path -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import wool @@ -64,6 +64,10 @@ from cfdb.workflows.models import ArtifactKind, JobRecord, JobStatus from cfdb.workflows.processors.base import Processor from cfdb.workflows.processors.registry import ProcessorRegistry +from cfdb.workflows.provisioner import RetryableProvisionerError + +if TYPE_CHECKING: + from cfdb.workflows.provisioner import EcsProvisioner logger = logging.getLogger(__name__) @@ -224,7 +228,34 @@ class instance, ``file_meta`` is a plain dict (B1 strips Mongo's class WoolExecutor(JobExecutor): - """Executor backed by a ``wool.WorkerPool`` and a Mongo jobs collection.""" + """Executor backed by a ``wool.WorkerPool`` and a Mongo jobs collection. + + When configured with an :class:`EcsProvisioner`, the executor also + issues a ``RunTask`` on each fresh claim before opening the routine + stream, so a Fargate worker boots in parallel with the dispatch + retry loop. Without a provisioner (PoC dev profile) the executor + relies on workers already published into the discovery namespace. + + A :class:`RetryableProvisionerError` from the provisioner surfaces + as a terminal ``FAILED`` job with a ``capacity:``-prefixed error + string; any other provisioner failure surfaces with a + ``provisioner:`` prefix. Clients parse the prefix to decide + whether to resubmit. + + Args: + db: Motor database handle holding the ``jobs`` collection. + cache: Cache backend for processor artifacts. + cache_root: Directory used by ``LocalFsCache``; ignored for S3. + registry: Processor registry mapping artifact kinds to runners. + workdir_root: Parent directory under which per-job workdirs land. + pipeline_version: Embedded in workflow keys; bump to invalidate + every in-flight workflow. + workflow_duration_cap_seconds: Per-workflow wall-clock cap. + Defaults to ``DEFAULT_WORKFLOW_DURATION_CAP_SECONDS``. + provisioner: Optional :class:`EcsProvisioner`. When ``None`` + (PoC profile), no ``RunTask`` is issued and dispatch + relies on pre-existing workers. + """ def __init__( self, @@ -236,6 +267,7 @@ def __init__( workdir_root: Path, pipeline_version: int = PIPELINE_VERSION, workflow_duration_cap_seconds: int = DEFAULT_WORKFLOW_DURATION_CAP_SECONDS, + provisioner: Optional[EcsProvisioner] = None, ) -> None: self._db = db self._cache = cache @@ -245,6 +277,12 @@ def __init__( self._workdir_root.mkdir(parents=True, exist_ok=True) self._pipeline_version = pipeline_version self._workflow_duration_cap_seconds = workflow_duration_cap_seconds + #: External worker provisioner. When set, ``_run_workflow`` issues + #: a ``RunTask``-equivalent request to it on a fresh claim so a + #: Fargate worker boots in parallel with the dispatch retry loop. + #: When unset (PoC dev profile), the executor relies entirely on + #: workers already published into the discovery namespace. + self._provisioner = provisioner self._pending_tasks: set[asyncio.Task] = set() #: Finalize tasks (release_workflow + workdir cleanup) created by #: _run_workflow's `finally`. Tracked separately so drain() can @@ -288,7 +326,9 @@ async def ensure_workflow( ) if fresh: - task = asyncio.create_task(self._run_workflow(record, processor, file_meta)) + task = asyncio.create_task( + self._run_workflow(record, processor, file_meta, wf_key) + ) self._pending_tasks.add(task) task.add_done_callback(self._pending_tasks.discard) task.add_done_callback(_log_unexpected_exception) @@ -356,6 +396,7 @@ async def _run_workflow( record: JobRecord, processor: Processor, file_meta: dict[str, Any], + workflow_key: str, ) -> None: """Background coroutine: mark running, consume the routine's event stream, and release a terminal status. @@ -390,6 +431,40 @@ async def _run_workflow( final_error = f"mark_running failed: {exc}" return + # Request a worker via the external provisioner (e.g. ECS + # ``RunTask``) before opening the routine stream. The + # provisioner dedup-keys on the workflow mutex so two + # concurrent fresh claims for the same source file + # share one ``RunTask`` and one worker. A capacity / + # ENI / throttling failure surfaces as + # ``RetryableProvisionerError`` and is routed to a + # terminal ``FAILED`` via the same error path as a + # stream-open failure below. + if self._provisioner is not None: + try: + await self._provisioner.request(dedup_key=workflow_key) + except asyncio.CancelledError: + final_error = "Workflow cancelled (worker shutdown)" + raise + except RetryableProvisionerError as exc: + # Capacity / ENI exhaustion / throttling / boto + # transport transients. ``capacity:`` is the + # stable on-wire prefix clients parse to decide + # whether to resubmit the job. + logger.warning( + "Provisioner reported retryable capacity error for %s: %s", + record.job_id, + exc, + ) + final_error = f"capacity: {exc}" + return + except Exception as exc: + logger.exception( + "Provisioner request failed for %s", record.job_id + ) + final_error = f"provisioner: {exc}" + return + try: stream = await self._open_stream_with_retry( processor, file_meta, workdir diff --git a/tests/test_workflows/test_executor.py b/tests/test_workflows/test_executor.py index ab9e346..9c3e355 100644 --- a/tests/test_workflows/test_executor.py +++ b/tests/test_workflows/test_executor.py @@ -11,6 +11,7 @@ import wool from cfdb.workflows import executor as executor_module +from cfdb.workflows import keys as key_utils from cfdb.workflows.executor import ( PIPELINE_VERSION, ExecutorDraining, @@ -22,6 +23,7 @@ from cfdb.workflows.models import ACTIVE_STATUSES, ArtifactKind, JobStatus from cfdb.workflows.processors.base import Processor from cfdb.workflows.processors.registry import ProcessorRegistry +from cfdb.workflows.provisioner import EcsProvisioner, RetryableProvisionerError from tests.test_workflows import FIXTURE_MD5 #: Canonical artifact keys the stubs emit. Built from FIXTURE_MD5 so the @@ -1241,3 +1243,193 @@ async def test_ensure_workflow_should_persist_file_meta_snapshot_on_insert( final = await get_job(mock_db, record.job_id) assert final is not None assert final.file_meta_snapshot == meta + + +class TestWoolExecutorWithProvisioner: + @pytest.mark.asyncio + async def test_ensure_workflow_should_request_worker_when_provisioner_set( + self, mock_db, tmp_cache, tmp_workdir, no_wool_dispatch, mocker + ): + """Test that a fresh claim dispatches a provisioner request. + + Given: + A WoolExecutor wired with a stub EcsProvisioner. + When: + ``ensure_workflow`` is awaited on a previously-unseen file. + Then: + The provisioner's ``request`` should be awaited exactly once + with ``dedup_key`` set to the workflow mutex key for the file. + """ + # Arrange + _install_jobs_index(mock_db) + registry = ProcessorRegistry() + registry.register(_StubProcessor()) + provisioner = mocker.AsyncMock(spec=EcsProvisioner) + provisioner.request.return_value = ["arn:fake:task/abc"] + executor = WoolExecutor( + mock_db, + tmp_cache, + tmp_cache.root, + registry, + workdir_root=tmp_workdir, + provisioner=provisioner, + ) + meta = _file_meta() + dcc, local_id, md5 = extract_identity(meta) + expected_key = key_utils.workflow_key( + dcc=dcc, local_id=local_id, md5=md5, pipeline_version=PIPELINE_VERSION + ) + + # Act + record, _ = await executor.ensure_workflow(meta) + await _wait_for_terminal(mock_db, record.job_id) + + # Assert + provisioner.request.assert_awaited_once_with(dedup_key=expected_key) + + @pytest.mark.asyncio + async def test_ensure_workflow_should_skip_provisioner_when_attaching_to_existing_job( + self, mock_db, tmp_cache, tmp_workdir, no_wool_dispatch, mocker + ): + """Test that the provisioner is only invoked on fresh claims. + + Given: + An executor with a stub provisioner and a blocking processor + so the first claim stays active when the second call arrives. + When: + ``ensure_workflow`` is awaited twice for the same file_meta. + Then: + ``provisioner.request`` should be awaited exactly once — the + attach path must not double-spend a Fargate worker against + an already-claimed workflow. + """ + # Arrange + _install_jobs_index(mock_db) + release = asyncio.Event() + + class _BlockingProcessor(_StubProcessor): + async def run(self, file_meta, workdir, cache_root): + self.run_calls += 1 + await release.wait() + for kind, key in self.artifacts.items(): + yield {"event": "stage_complete", "kind": kind, "key": key} + yield {"event": "complete", "artifacts": dict(self.artifacts)} + + registry = ProcessorRegistry() + registry.register(_BlockingProcessor()) + provisioner = mocker.AsyncMock(spec=EcsProvisioner) + provisioner.request.return_value = ["arn:fake:task/abc"] + executor = WoolExecutor( + mock_db, + tmp_cache, + tmp_cache.root, + registry, + workdir_root=tmp_workdir, + provisioner=provisioner, + ) + + # Act + record_a, fresh_a = await executor.ensure_workflow(_file_meta()) + record_b, fresh_b = await executor.ensure_workflow(_file_meta()) + release.set() + await _wait_for_terminal(mock_db, record_a.job_id) + + # Assert + assert fresh_a is True + assert fresh_b is False + assert record_a.job_id == record_b.job_id + provisioner.request.assert_awaited_once() + + @pytest.mark.asyncio + async def test_ensure_workflow_should_record_capacity_failure_when_provisioner_raises( + self, mock_db, tmp_cache, tmp_workdir, no_wool_dispatch, mocker + ): + """Test that a provisioner failure routes to a clean FAILED status. + + Given: + A provisioner whose ``request`` raises + ``RetryableProvisionerError`` (capacity / ENI exhaustion / + throttling). + When: + ``ensure_workflow`` is awaited. + Then: + The job should reach ``FAILED`` with the provisioner error + string captured on the record so the next client retry can + see why the workflow didn't run. + """ + # Arrange + _install_jobs_index(mock_db) + registry = ProcessorRegistry() + registry.register(_StubProcessor()) + provisioner = mocker.AsyncMock(spec=EcsProvisioner) + provisioner.request.side_effect = RetryableProvisionerError( + "ClientError: capacity-unavailable" + ) + executor = WoolExecutor( + mock_db, + tmp_cache, + tmp_cache.root, + registry, + workdir_root=tmp_workdir, + provisioner=provisioner, + ) + + # Act + record, _ = await executor.ensure_workflow(_file_meta()) + await _wait_for_terminal(mock_db, record.job_id) + + # Assert + final = await get_job(mock_db, record.job_id) + assert final is not None + assert final.status == JobStatus.FAILED + assert final.error is not None + assert final.error.startswith("capacity: "), ( + f"expected 'capacity: ' prefix, got {final.error!r}" + ) + assert "capacity-unavailable" in final.error + + @pytest.mark.asyncio + async def test_ensure_workflow_should_record_provisioner_prefix_for_generic_exception( + self, mock_db, tmp_cache, tmp_workdir, no_wool_dispatch, mocker + ): + """Test that non-retryable provisioner failures use the 'provisioner: ' prefix. + + Given: + A provisioner whose ``request`` raises a generic + ``RuntimeError`` (a misconfiguration or unexpected boto + failure, not a retryable transient). + When: + ``ensure_workflow`` is awaited. + Then: + The job should reach ``FAILED`` with an error string that + starts with ``"provisioner: "`` so clients can distinguish + "retry me later" from "the provisioner crashed". + """ + # Arrange + _install_jobs_index(mock_db) + registry = ProcessorRegistry() + registry.register(_StubProcessor()) + provisioner = mocker.AsyncMock(spec=EcsProvisioner) + provisioner.request.side_effect = RuntimeError("misconfigured cluster") + executor = WoolExecutor( + mock_db, + tmp_cache, + tmp_cache.root, + registry, + workdir_root=tmp_workdir, + provisioner=provisioner, + ) + + # Act + record, _ = await executor.ensure_workflow(_file_meta()) + await _wait_for_terminal(mock_db, record.job_id) + + # Assert + final = await get_job(mock_db, record.job_id) + assert final is not None + assert final.status == JobStatus.FAILED + assert final.error is not None + assert final.error.startswith("provisioner: "), ( + f"expected 'provisioner: ' prefix, got {final.error!r}" + ) + assert "misconfigured cluster" in final.error From 950686b11b3bf3e4f0a2ebd173a90b8db46006b4 Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 20 May 2026 12:30:28 -0400 Subject: [PATCH 7/7] feat(api): Wire S3 cache and ECS provisioner into the lifespan MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the AWS / ECS env config (AWS_ENDPOINT_URL, AWS_REGION, WORKFLOW_S3_BUCKET, WORKFLOW_S3_PREFIX, ECS_CLUSTER, ECS_WORKER_TASK_DEFINITION, ECS_WORKER_TASK_FAMILY, ECS_WORKER_SUBNETS, ECS_WORKER_SECURITY_GROUPS, ECS_WORKER_ASSIGN_PUBLIC_IP) to cfdb.api and switch the lifespan to pick the right runtime backends per env state. Three helpers gate the selection so each concern stays separable: _build_cache returns S3Cache when WORKFLOW_S3_BUCKET is set and LocalFsCache otherwise; _maybe_build_provisioner returns an EcsProvisioner when ECS_CLUSTER / task-def / subnets are all set and None otherwise; _build_discovery wraps EcsDiscovery in its async context (or yields LanDiscovery unchanged) so the lifespan's WorkerPool block opens against either discovery with the same shape. ECS_WORKER_TASK_FAMILY defaults to ECS_WORKER_TASK_DEFINITION with any :revision suffix stripped — ListTasks only accepts the family, RunTask accepts family[:revision]. The bare PoC profile (no AWS env set) keeps producing LocalFsCache + LanDiscovery + no provisioner, identical to the path before this change. The EXDEV cross-filesystem check is now gated on the LocalFsCache branch since S3 has no rename-atomicity precondition. Startup log reports the resolved cache / discovery / provisioner types instead of the LAN namespace so operators can see at a glance which profile activated. --- README.md | 17 ++ src/cfdb/api/__init__.py | 135 +++++++++++++++- src/cfdb/api/main.py | 240 +++++++++++++++++++++-------- tests/test_api/__init__.py | 0 tests/test_api/test_env_parsers.py | 75 +++++++++ 5 files changed, 403 insertions(+), 64 deletions(-) create mode 100644 tests/test_api/__init__.py create mode 100644 tests/test_api/test_env_parsers.py diff --git a/README.md b/README.md index 640b1f3..3bbb78a 100644 --- a/README.md +++ b/README.md @@ -578,6 +578,23 @@ Optional tunables (with defaults): Required tools on `PATH` for the **worker pool** (not the API): `samtools`, `bgzip`, `tabix`, `bcftools`, `gffread`, `bigBedToBed`. The `api` Docker image already installs all of these — the simplest local-dev / single-host deployment is to reuse `Dockerfile.api` as the worker image and override the `CMD` (or run the wool worker entrypoint via `python -m wool`). On the worker host, set `WORKFLOW_POOL_NAMESPACE` to match the API's value. +#### ECS Fargate profile + +When the API runs on ECS Fargate (or LocalStack-backed dev that mirrors prod end-to-end), the lifespan switches from `LocalFsCache` + `LanDiscovery` to `S3Cache` + `EcsDiscovery` + `EcsProvisioner`. The selection is env-driven; with none of the variables below set the API runs the local PoC profile unchanged. + +- `AWS_ENDPOINT_URL` — boto3 endpoint override. Unset in production (boto3 hits real AWS); set to `http://localstack:4566` (or similar) for LocalStack-backed dev. The same application code runs in both environments — only this variable differs. +- `AWS_REGION` — AWS region for the boto3 client (default `us-east-1`). +- `WORKFLOW_S3_BUCKET` — when set, the lifespan instantiates `S3Cache` instead of `LocalFsCache`. The bucket must already exist (creation is out of band). When unset, the API stays on the local filesystem cache. +- `WORKFLOW_S3_PREFIX` — optional key prefix the S3 backend prepends to every cache key (default empty). Lets a single bucket host multiple environments (`dev/`, `staging/`, `prod/`) without collisions. +- `ECS_CLUSTER` — ECS cluster name or ARN. Gates the ECS-backed provisioner and discovery profile; unset means the PoC profile stays on `LanDiscovery` with no provisioner. +- `ECS_WORKER_TASK_DEFINITION` — task definition for the worker container, as a family name (`cfdb-worker`) or `family:revision`. The provisioner passes it through to `RunTask` verbatim. +- `ECS_WORKER_TASK_FAMILY` — family used by `EcsDiscovery` to filter `ListTasks`. Defaults to `ECS_WORKER_TASK_DEFINITION` with any `:revision` suffix stripped; set explicitly only when the discovery family differs from the provisioner task-def family (rare). +- `ECS_WORKER_SUBNETS` — comma-separated awsvpc subnet IDs the worker ENIs land in. Required for the ECS profile; an empty list with `ECS_CLUSTER` set is a misconfiguration. +- `ECS_WORKER_SECURITY_GROUPS` — comma-separated awsvpc security group IDs. Optional — when empty, ECS applies the VPC default SG. +- `ECS_WORKER_ASSIGN_PUBLIC_IP` — `ENABLED` or `DISABLED` (default `DISABLED`). Production should leave this disabled and reach AWS via VPC endpoints; LocalStack accepts either value. + +The worker container's `CMD` is `python -m cfdb.workflows.worker_main`. Worker-side knobs (gRPC port, health port, max lifetime, drain grace) are documented under `--help` on that command; their env vars are `CFDB_WORKER_GRPC_PORT`, `CFDB_WORKER_HEALTH_PORT`, `CFDB_WORKER_MAX_LIFETIME_SECONDS`, and `CFDB_WORKER_DRAIN_GRACE_SECONDS`. The worker task definition MUST declare a `healthCheck` against the gRPC port; without one ECS reports `healthStatus: UNKNOWN` indefinitely and the worker is never advertised to discovery. + #### Running a local worker pool For single-host development, start a wool worker pool in a separate process before launching the API, with `WORKFLOW_POOL_NAMESPACE` matching what the API uses. The API connects via LAN discovery (zeroconf/mDNS) and dispatches workflows to whatever workers are publishing under that namespace. With no worker pool running, `/data` and `/index` requests for processable formats will hang on the dispatch retry budget (60s by default) before failing with `NoWorkersAvailable`. diff --git a/src/cfdb/api/__init__.py b/src/cfdb/api/__init__.py index e5cbb54..8b2760d 100644 --- a/src/cfdb/api/__init__.py +++ b/src/cfdb/api/__init__.py @@ -4,6 +4,8 @@ from motor.motor_asyncio import AsyncIOMotorDatabase +from cfdb.workflows import WORKFLOW_DURATION_CAP_S + if TYPE_CHECKING: from cfdb.workflows.cache import CacheBackend from cfdb.workflows.executor import JobExecutor @@ -79,10 +81,139 @@ def _parse_int_env(name: str, default: int, *, minimum: int = 0) -> int: #: same string; otherwise the API's discovery service won't see the #: worker registrations and the pool will start with zero leasable #: workers. The default ``"cfdb-workers"`` matches what the worker pool -#: needs to publish under; an ECS-aware variant may supplant LAN -#: discovery in a future PR. +#: needs to publish under. Ignored when the ECS-discovery profile is +#: active (see ``ECS_CLUSTER`` / ``ECS_WORKER_TASK_DEFINITION``); the +#: ECS path discovers workers by polling the ECS control plane directly +#: and does not need a shared zeroconf namespace. WORKFLOW_POOL_NAMESPACE: Final = os.getenv("WORKFLOW_POOL_NAMESPACE", "cfdb-workers") +# AWS / ECS profile. These knobs are optional — when none of them are +# set the API runs the PoC profile (``LocalFsCache`` + ``LanDiscovery`` +# + no worker provisioner) and behaves exactly as it did before the +# Fargate work landed. Production / LocalStack-backed dev sets the +# bucket + cluster + task-def + subnets to activate the ECS-backed +# profile; the same code path serves both because boto3 honors +# ``AWS_ENDPOINT_URL`` to redirect at LocalStack. + +#: boto3 endpoint override. In production this is unset and boto3 +#: targets real AWS endpoints; LocalStack-backed dev sets it to +#: ``http://localstack:4566`` so the same code talks to the local +#: container. Threaded through to ``_build_s3_client`` / +#: ``build_ecs_client`` (in ``cfdb.workflows.cache`` / +#: ``cfdb.workflows.provisioner``) via the boto3 ``Session`` default +#: chain, so no per-client wiring is needed here. +AWS_ENDPOINT_URL: Final = os.getenv("AWS_ENDPOINT_URL") + +#: AWS region. Defaults to ``us-east-1`` so a missing ``AWS_REGION`` in +#: dev doesn't surface as an opaque boto3 ``NoRegionError`` at first +#: request — operators get a working default and override it in +#: production deployments. +AWS_REGION: Final = os.getenv("AWS_REGION", "us-east-1") + +#: When set, the lifespan instantiates ``S3Cache`` instead of +#: ``LocalFsCache``. Unset means the API stays on the local +#: filesystem cache. +WORKFLOW_S3_BUCKET: Final = os.getenv("WORKFLOW_S3_BUCKET") + +#: Optional key prefix the S3 backend prepends to every cache key. +#: Lets a single bucket host multiple environments (``dev/``, +#: ``staging/``, ``prod/``) without collisions. +WORKFLOW_S3_PREFIX: Final = os.getenv("WORKFLOW_S3_PREFIX", "") + +#: ECS cluster name or ARN. Gates the ECS-backed provisioner and +#: discovery profile; unset means the PoC profile stays on +#: ``LanDiscovery`` with no provisioner. +ECS_CLUSTER: Final = os.getenv("ECS_CLUSTER") + +#: ECS worker task definition. Accepts either a family name +#: (``cfdb-worker``) or a ``family:revision`` string. The provisioner +#: passes it through to ``RunTask`` verbatim; the discovery loop +#: strips any ``:revision`` suffix to derive its +#: ``family`` filter (see ``ECS_WORKER_TASK_FAMILY`` override below). +ECS_WORKER_TASK_DEFINITION: Final = os.getenv("ECS_WORKER_TASK_DEFINITION") + + +def _ecs_default_task_family() -> str | None: + """Derive the discovery ``family`` filter from the task definition. + + ``RunTask`` accepts ``family[:revision]``; ``ListTasks`` accepts + only the family (no revision). The default split strips the + revision when present, with ``ECS_WORKER_TASK_FAMILY`` available as + an explicit override for environments that pin a non-default + family name. + """ + explicit = os.getenv("ECS_WORKER_TASK_FAMILY") + if explicit: + return explicit + if ECS_WORKER_TASK_DEFINITION: + return ECS_WORKER_TASK_DEFINITION.split(":", 1)[0] + return None + + +#: Family used by ``EcsDiscovery`` to filter ``ListTasks``. Derived +#: from ``ECS_WORKER_TASK_DEFINITION`` by default; set explicitly via +#: ``ECS_WORKER_TASK_FAMILY`` only when the discovery family differs +#: from the provisioner task-def family (rare). +ECS_WORKER_TASK_FAMILY: Final = _ecs_default_task_family() + + +def _parse_csv_env(name: str, default: str = "") -> list[str]: + """Parse a comma-separated env var into a list of trimmed strings. + + Empty strings are dropped so a trailing comma or double comma + doesn't propagate as an empty subnet/SG entry that boto3 would + later reject with a less informative error. + """ + raw = os.getenv(name, default) + return [item.strip() for item in raw.split(",") if item.strip()] + + +#: Awsvpc subnet IDs the worker ENIs land in. Required for the ECS +#: profile; an empty list with ``ECS_CLUSTER`` set is a misconfiguration +#: that the lifespan refuses to start under. +ECS_WORKER_SUBNETS: Final = _parse_csv_env("ECS_WORKER_SUBNETS") + +#: Awsvpc security groups attached to the worker ENIs. Optional — +#: when empty, ECS applies the VPC default SG. +ECS_WORKER_SECURITY_GROUPS: Final = _parse_csv_env("ECS_WORKER_SECURITY_GROUPS") + +_ASSIGN_PUBLIC_IP_VALUES = frozenset({"ENABLED", "DISABLED"}) + + +def _parse_assign_public_ip(name: str, default: str) -> str: + """Parse an ECS ``assignPublicIp`` env var with explicit validation. + + ECS rejects anything other than ``ENABLED`` / ``DISABLED``; we + surface the misconfiguration at module-import time so the lifespan + doesn't get to Mongo + S3 init before tripping on it. Pairs with + :data:`cfdb.workflows.provisioner._ASSIGN_PUBLIC_IP_VALUES`. + """ + raw = os.getenv(name) + if raw is None or raw == "": + return default + if raw not in _ASSIGN_PUBLIC_IP_VALUES: + raise ImportError( + f"Environment variable {name}={raw!r} must be one of " + f"{sorted(_ASSIGN_PUBLIC_IP_VALUES)}" + ) + return raw + + +#: Whether the worker ENI gets a public IPv4 address. Production +#: should leave this DISABLED and reach AWS via VPC endpoints; +#: LocalStack accepts either value. +ECS_WORKER_ASSIGN_PUBLIC_IP: Final = _parse_assign_public_ip( + "ECS_WORKER_ASSIGN_PUBLIC_IP", "DISABLED" +) + +#: Per-deployment override for the workflow runtime cap. Threads +#: through to :class:`cfdb.workflows.executor.WoolExecutor` via the +#: lifespan. Defaults to :data:`cfdb.workflows.WORKFLOW_DURATION_CAP_S` +#: (the env-driven workflows-package default, currently 4 h). +WORKFLOW_DURATION_CAP_SECONDS: Final = _parse_int_env( + "WORKFLOW_DURATION_CAP_SECONDS", default=WORKFLOW_DURATION_CAP_S, minimum=1 +) + db: AsyncIOMotorDatabase | None = None cache: "CacheBackend | None" = None executor: "JobExecutor | None" = None diff --git a/src/cfdb/api/main.py b/src/cfdb/api/main.py index 62ca7c6..bf31c8b 100644 --- a/src/cfdb/api/main.py +++ b/src/cfdb/api/main.py @@ -4,6 +4,7 @@ import re from contextlib import asynccontextmanager from pathlib import Path +from typing import AsyncIterator, Optional import wool from wool.runtime.discovery.lan import LanDiscovery @@ -19,12 +20,19 @@ from cfdb.api.routers.index import router as index_router from cfdb.api.routers.jobs import router as jobs_router from cfdb.api.routers.sync import router as sync_router -from cfdb.workflows.cache import LocalFsCache +from cfdb.workflows.cache import ( + CacheBackend, + LocalFsCache, + S3Cache, + check_s3_bucket_or_raise, +) +from cfdb.workflows.discovery import EcsDiscovery from cfdb.workflows.executor import WoolExecutor from cfdb.workflows.models import ACTIVE_STATUSES from cfdb.workflows.processors.bam import BamIndexProcessor from cfdb.workflows.processors.registry import default_registry from cfdb.workflows.processors.tabix import TabixIntervalProcessor +from cfdb.workflows.provisioner import EcsProvisioner logging.basicConfig(level=logging.INFO) @@ -86,6 +94,91 @@ async def _assert_jobs_indexes(db, log: logging.Logger) -> None: ) +async def _build_cache(cache_root: Path) -> CacheBackend: + """Pick the cache backend per env config. + + Returns ``S3Cache`` when ``WORKFLOW_S3_BUCKET`` is set (production + and LocalStack-backed dev) and probes the bucket via + :func:`check_s3_bucket_or_raise` so a typo / missing IAM grant + fails fast at boot rather than as a permanent-cache-miss cascade + once workflows start. Falls back to ``LocalFsCache`` otherwise — + the PoC profile that ships before the deployment substrate is + wired up. ``cache_root`` is only meaningful on the LocalFsCache + branch but is accepted unconditionally so the caller doesn't need + to branch on the env state. + """ + if api.WORKFLOW_S3_BUCKET: + cache = S3Cache( + bucket=api.WORKFLOW_S3_BUCKET, + prefix=api.WORKFLOW_S3_PREFIX, + endpoint_url=api.AWS_ENDPOINT_URL, + region_name=api.AWS_REGION, + ) + # Reuse the freshly-built client so the probe targets the same + # endpoint as subsequent reads/writes. + await check_s3_bucket_or_raise(api.WORKFLOW_S3_BUCKET, client=cache._client) + return cache + return LocalFsCache(cache_root) + + +def _maybe_build_provisioner() -> Optional[EcsProvisioner]: + """Build an ``EcsProvisioner`` when the ECS env is fully configured. + + Returns ``None`` when any required knob is missing so the executor's + PoC profile (manually-started wool workers, no provisioner) is + preserved. The lifespan logs which path is selected so an + incomplete config doesn't silently degrade. + + The completeness check requires all three of cluster, task + definition, and a non-empty subnet list — these are the minimum + ``RunTask`` needs to launch a worker into the awsvpc network. A + partial config (e.g. cluster + task-def but no subnets) is treated + as the PoC profile rather than failing fast because deployment + pipelines may roll out the env vars in stages. + """ + if not (api.ECS_CLUSTER and api.ECS_WORKER_TASK_DEFINITION and api.ECS_WORKER_SUBNETS): + return None + return EcsProvisioner( + cluster=api.ECS_CLUSTER, + task_definition=api.ECS_WORKER_TASK_DEFINITION, + subnets=api.ECS_WORKER_SUBNETS, + security_groups=api.ECS_WORKER_SECURITY_GROUPS, + assign_public_ip=api.ECS_WORKER_ASSIGN_PUBLIC_IP, + endpoint_url=api.AWS_ENDPOINT_URL, + region_name=api.AWS_REGION, + ) + + +@asynccontextmanager +async def _build_discovery() -> AsyncIterator[object]: + """Pick the worker-discovery backend per env config. + + Two paths: + * ECS — ``EcsDiscovery`` polls ``ListTasks`` + ``DescribeTasks`` + on the configured cluster. Async-context-managed: the background + poller starts on ``__aenter__`` and is cancelled on + ``__aexit__``. Selected when ``ECS_CLUSTER`` and the worker + task definition are both set. + * LAN — ``LanDiscovery`` over zeroconf/mDNS. The wool pool + accepts it as-is; the wrapper context manager is a no-op so + the caller's ``async with`` works either way. + + Yields the same shape (something wool's ``WorkerPool`` accepts via + its ``discovery=`` arg) so the lifespan can wire it through without + branching at the call site. + """ + if api.ECS_CLUSTER and api.ECS_WORKER_TASK_FAMILY: + async with EcsDiscovery( + cluster=api.ECS_CLUSTER, + task_definition_family=api.ECS_WORKER_TASK_FAMILY, + endpoint_url=api.AWS_ENDPOINT_URL, + region_name=api.AWS_REGION, + ) as discovery: + yield discovery + return + yield LanDiscovery(api.WORKFLOW_POOL_NAMESPACE) + + def create_mongodb_client() -> AsyncIOMotorClient: """Create MongoDB client with optional TLS authentication.""" log = logging.getLogger(__name__) @@ -146,79 +239,102 @@ async def lifespan(_: FastAPI): # same filesystem (otherwise the kernel raises # ``OSError(EXDEV)``). Verify the precondition at startup so # a multi-volume deployment fails fast with a clear message - # instead of dying mid-pipeline on the first cache.put. - cache_st = os.stat(cache_root) - workdir_st = os.stat(workdir_root) - if cache_st.st_dev != workdir_st.st_dev: - raise RuntimeError( - "SYNC_DATA_DIR subdirectories must share a filesystem " - f"(cache={cache_root!s} st_dev={cache_st.st_dev}, " - f"workdir={workdir_root!s} st_dev={workdir_st.st_dev}). " - "LocalFsCache.put relies on os.replace atomicity; " - "cross-device renames raise OSError(EXDEV). Mount both " - "paths under a single volume or set SYNC_DATA_DIR to a " - "parent that contains both." - ) + # instead of dying mid-pipeline on the first cache.put. Only + # the LocalFsCache branch needs this — S3Cache.put goes + # over the network and has no rename-atomicity requirement. + if not api.WORKFLOW_S3_BUCKET: + cache_st = os.stat(cache_root) + workdir_st = os.stat(workdir_root) + if cache_st.st_dev != workdir_st.st_dev: + raise RuntimeError( + "SYNC_DATA_DIR subdirectories must share a filesystem " + f"(cache={cache_root!s} st_dev={cache_st.st_dev}, " + f"workdir={workdir_root!s} st_dev={workdir_st.st_dev}). " + "LocalFsCache.put relies on os.replace atomicity; " + "cross-device renames raise OSError(EXDEV). Mount both " + "paths under a single volume or set SYNC_DATA_DIR to a " + "parent that contains both." + ) - api.cache = LocalFsCache(cache_root) + api.cache = await _build_cache(cache_root) api.processor_registry = default_registry() api.processor_registry.register(BamIndexProcessor()) api.processor_registry.register(TabixIntervalProcessor()) + provisioner = _maybe_build_provisioner() + # Warn loudly when discovery is ECS-mode but the + # provisioner is unavailable — workers polled-for are not + # workers launched-on-demand, so the dispatch retry budget + # eventually exhausts with ``NoWorkersAvailable``. The + # asymmetry is intentional to support staged env rollouts + # but operators should know they're in a half-configured + # deployment. + if api.ECS_CLUSTER and api.ECS_WORKER_TASK_FAMILY and provisioner is None: + log.warning( + "ECS discovery is active but no provisioner — set " + "ECS_WORKER_TASK_DEFINITION and ECS_WORKER_SUBNETS to " + "complete the ECS profile, or unset ECS_CLUSTER to " + "fall back to the PoC profile." + ) + # Lease workers from the surrounding pool rather than spawning - # them in-process. In production the workers run as separate - # ECS tasks discovered via wool's discovery layer; the API's - # job is to dispatch to whatever capacity exists. Scaling the - # ECS service is out of band (e.g., on ``NoWorkersAvailable`` - # bursts or on a queue-depth metric). + # them in-process. In production the workers run as ECS Fargate + # tasks launched on demand by ``EcsProvisioner`` and discovered + # via ``EcsDiscovery``'s poll-and-diff over ``ListTasks`` / + # ``DescribeTasks``. In the PoC dev profile no AWS env is set + # and the API falls back to ``LanDiscovery`` (zeroconf/mDNS) + # against a manually-started wool pool. # # The explicit ``discovery=`` is required to keep wool out of # its default ephemeral mode — ``WorkerPool(lease=N)`` alone # falls into the ``(spawn=None, discovery=None)`` branch which - # spawns CPU-count workers locally. Pairing ``lease=N`` with a - # shared ``LanDiscovery`` namespace puts the pool in - # discovery-only mode (no spawning); the worker-pool process - # publishes workers via the same namespace. ``LanDiscovery`` - # rides over zeroconf/mDNS, which sidesteps the macOS - # ``watchdog``/FSEvents fork-unsafety we hit with - # ``LocalDiscovery``. - async with wool.WorkerPool( - discovery=LanDiscovery(api.WORKFLOW_POOL_NAMESPACE), - lease=api.WORKFLOW_WORKER_COUNT, - ): - # Snapshot the lifespan task's contextvars after the - # pool's ``__aenter__`` has populated wool's internals. - api.wool_context = contextvars.copy_context() - api.executor = WoolExecutor( - api.db, - api.cache, - cache_root, - api.processor_registry, - workdir_root=workdir_root, - ) - executor_handle = api.executor - log.info( - "Workflow subsystem enabled: cache=%s workdir=%s " - "lease=%d namespace=%s", - cache_root, - workdir_root, - api.WORKFLOW_WORKER_COUNT, - api.WORKFLOW_POOL_NAMESPACE, - ) - try: - yield - finally: - drained = await executor_handle.drain( - timeout=SHUTDOWN_DRAIN_TIMEOUT_SECONDS + # spawns CPU-count workers locally. + async with _build_discovery() as discovery: + async with wool.WorkerPool( + discovery=discovery, + lease=api.WORKFLOW_WORKER_COUNT, + ): + # Snapshot the lifespan task's contextvars after the + # pool's ``__aenter__`` has populated wool's internals. + api.wool_context = contextvars.copy_context() + api.executor = WoolExecutor( + api.db, + api.cache, + cache_root, + api.processor_registry, + workdir_root=workdir_root, + workflow_duration_cap_seconds=api.WORKFLOW_DURATION_CAP_SECONDS, + provisioner=provisioner, + ) + executor_handle = api.executor + log.info( + "Workflow subsystem enabled: cache=%s workdir=%s " + "lease=%d discovery=%s provisioner=%s", + type(api.cache).__name__, + workdir_root, + api.WORKFLOW_WORKER_COUNT, + type(discovery).__name__, + "EcsProvisioner" if provisioner is not None else "none", ) - if drained: - log.info( - "Drained %d workflow task(s) on shutdown", drained + try: + yield + finally: + drained = await executor_handle.drain( + timeout=SHUTDOWN_DRAIN_TIMEOUT_SECONDS ) - api.executor = None - api.cache = None - api.processor_registry = None - api.wool_context = None + if drained: + log.info( + "Drained %d workflow task(s) on shutdown", drained + ) + # Cancel in-flight ``RunTask`` calls so a + # shutdown mid-launch doesn't leave a task + # un-requested-but-billed. + if provisioner is not None: + await provisioner.aclose() + api.executor = None + api.cache = None + api.processor_registry = None + api.wool_context = None else: log.info( "SYNC_DATA_DIR unset — workflow subsystem disabled; " diff --git a/tests/test_api/__init__.py b/tests/test_api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_api/test_env_parsers.py b/tests/test_api/test_env_parsers.py new file mode 100644 index 0000000..c90017a --- /dev/null +++ b/tests/test_api/test_env_parsers.py @@ -0,0 +1,75 @@ +"""Tests for the env-var parsers in :mod:`cfdb.api`.""" + +from __future__ import annotations + +import pytest + +from cfdb.api import _parse_assign_public_ip + + +class TestParseAssignPublicIp: + def test__parse_assign_public_ip_with_unset_var(self, monkeypatch): + """Test that an unset env var returns the supplied default. + + Given: + No environment variable set for the given name. + When: + ``_parse_assign_public_ip`` is called with a default value. + Then: + It should return the default rather than raising so the PoC + profile (no ECS env) still imports cleanly. + """ + # Arrange + monkeypatch.delenv("CFDB_TEST_ASSIGN_PUBLIC_IP", raising=False) + + # Act + result = _parse_assign_public_ip( + "CFDB_TEST_ASSIGN_PUBLIC_IP", default="DISABLED" + ) + + # Assert + assert result == "DISABLED" + + def test__parse_assign_public_ip_with_valid_value(self, monkeypatch): + """Test that an in-set value is returned as-is. + + Given: + ``CFDB_TEST_ASSIGN_PUBLIC_IP=ENABLED``. + When: + ``_parse_assign_public_ip`` is called. + Then: + It should return ``"ENABLED"`` unchanged so the lifespan can + hand it to ``EcsProvisioner`` verbatim. + """ + # Arrange + monkeypatch.setenv("CFDB_TEST_ASSIGN_PUBLIC_IP", "ENABLED") + + # Act + result = _parse_assign_public_ip( + "CFDB_TEST_ASSIGN_PUBLIC_IP", default="DISABLED" + ) + + # Assert + assert result == "ENABLED" + + def test__parse_assign_public_ip_with_invalid_value(self, monkeypatch): + """Test that an out-of-set value surfaces as an ImportError. + + Given: + ``CFDB_TEST_ASSIGN_PUBLIC_IP`` set to an invalid value + (lowercase, typo, anything outside {ENABLED, DISABLED}). + When: + ``_parse_assign_public_ip`` is called. + Then: + It should raise ``ImportError`` so the misconfiguration + surfaces at module load rather than crashing the lifespan + mid-bootstrap. + """ + # Arrange + monkeypatch.setenv("CFDB_TEST_ASSIGN_PUBLIC_IP", "enabled") + + # Act & assert + with pytest.raises(ImportError, match="must be one of"): + _parse_assign_public_ip( + "CFDB_TEST_ASSIGN_PUBLIC_IP", default="DISABLED" + )