diff --git a/README.md b/README.md index 0bfa885..3bbb78a 100644 --- a/README.md +++ b/README.md @@ -569,7 +569,7 @@ Required environment variables: Optional tunables (with defaults): -- `CFDB_WORKFLOW_DURATION_CAP_S` — per-workflow wall-clock cap (default `1200`). +- `CFDB_WORKFLOW_DURATION_CAP_S` — per-workflow wall-clock cap (default `14400`, i.e. 4 h — sized for multi-hour preprocessing runs; lower it for fixture-bound dev). - `CFDB_WORKFLOW_DISPATCH_WAIT_S` — how long `ensure_workflow` waits for a free worker before giving up (default `60`). - `CFDB_WORKFLOW_HEARTBEAT_INTERVAL_S` — cadence at which the wool routine emits heartbeat events during quiet stages so the API can refresh `JobRecord.updated_at` (default `300`). The stale-reclaim threshold below is sized as `2 × heartbeat + safety_margin`; lowering this knob without also lowering the threshold widens the false-reclaim window. - `CFDB_WORKFLOW_STALE_THRESHOLD_S` — `updated_at` age beyond which an active row is reclaimable (default `900`; sized as `2 × heartbeat_interval + safety_margin` so a single missed heartbeat does not falsely reclaim a healthy worker). @@ -578,6 +578,23 @@ Optional tunables (with defaults): Required tools on `PATH` for the **worker pool** (not the API): `samtools`, `bgzip`, `tabix`, `bcftools`, `gffread`, `bigBedToBed`. The `api` Docker image already installs all of these — the simplest local-dev / single-host deployment is to reuse `Dockerfile.api` as the worker image and override the `CMD` (or run the wool worker entrypoint via `python -m wool`). On the worker host, set `WORKFLOW_POOL_NAMESPACE` to match the API's value. +#### ECS Fargate profile + +When the API runs on ECS Fargate (or LocalStack-backed dev that mirrors prod end-to-end), the lifespan switches from `LocalFsCache` + `LanDiscovery` to `S3Cache` + `EcsDiscovery` + `EcsProvisioner`. The selection is env-driven; with none of the variables below set the API runs the local PoC profile unchanged. + +- `AWS_ENDPOINT_URL` — boto3 endpoint override. Unset in production (boto3 hits real AWS); set to `http://localstack:4566` (or similar) for LocalStack-backed dev. The same application code runs in both environments — only this variable differs. +- `AWS_REGION` — AWS region for the boto3 client (default `us-east-1`). +- `WORKFLOW_S3_BUCKET` — when set, the lifespan instantiates `S3Cache` instead of `LocalFsCache`. The bucket must already exist (creation is out of band). When unset, the API stays on the local filesystem cache. +- `WORKFLOW_S3_PREFIX` — optional key prefix the S3 backend prepends to every cache key (default empty). Lets a single bucket host multiple environments (`dev/`, `staging/`, `prod/`) without collisions. +- `ECS_CLUSTER` — ECS cluster name or ARN. Gates the ECS-backed provisioner and discovery profile; unset means the PoC profile stays on `LanDiscovery` with no provisioner. +- `ECS_WORKER_TASK_DEFINITION` — task definition for the worker container, as a family name (`cfdb-worker`) or `family:revision`. The provisioner passes it through to `RunTask` verbatim. +- `ECS_WORKER_TASK_FAMILY` — family used by `EcsDiscovery` to filter `ListTasks`. Defaults to `ECS_WORKER_TASK_DEFINITION` with any `:revision` suffix stripped; set explicitly only when the discovery family differs from the provisioner task-def family (rare). +- `ECS_WORKER_SUBNETS` — comma-separated awsvpc subnet IDs the worker ENIs land in. Required for the ECS profile; an empty list with `ECS_CLUSTER` set is a misconfiguration. +- `ECS_WORKER_SECURITY_GROUPS` — comma-separated awsvpc security group IDs. Optional — when empty, ECS applies the VPC default SG. +- `ECS_WORKER_ASSIGN_PUBLIC_IP` — `ENABLED` or `DISABLED` (default `DISABLED`). Production should leave this disabled and reach AWS via VPC endpoints; LocalStack accepts either value. + +The worker container's `CMD` is `python -m cfdb.workflows.worker_main`. Worker-side knobs (gRPC port, health port, max lifetime, drain grace) are documented under `--help` on that command; their env vars are `CFDB_WORKER_GRPC_PORT`, `CFDB_WORKER_HEALTH_PORT`, `CFDB_WORKER_MAX_LIFETIME_SECONDS`, and `CFDB_WORKER_DRAIN_GRACE_SECONDS`. The worker task definition MUST declare a `healthCheck` against the gRPC port; without one ECS reports `healthStatus: UNKNOWN` indefinitely and the worker is never advertised to discovery. + #### Running a local worker pool For single-host development, start a wool worker pool in a separate process before launching the API, with `WORKFLOW_POOL_NAMESPACE` matching what the API uses. The API connects via LAN discovery (zeroconf/mDNS) and dispatches workflows to whatever workers are publishing under that namespace. With no worker pool running, `/data` and `/index` requests for processable formats will hang on the dispatch retry budget (60s by default) before failing with `NoWorkersAvailable`. diff --git a/pyproject.toml b/pyproject.toml index 6dda16c..52308b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ classifiers = [ ] dependencies = [ "aiohttp", + "boto3", "click", "debugpy", "fastapi", @@ -30,7 +31,7 @@ readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.11" [project.optional-dependencies] -dev = ["allpairspy", "debugpy", "httpx", "hypothesis", "mongomock-motor", "pytest-asyncio", "pytest-mock", "ruff"] +dev = ["allpairspy", "debugpy", "httpx", "hypothesis", "mongomock-motor", "moto[ecs,s3]", "pytest-asyncio", "pytest-mock", "ruff"] [project.scripts] cfdb = "cfdb.cli:cli" diff --git a/src/cfdb/api/__init__.py b/src/cfdb/api/__init__.py index e5cbb54..8b2760d 100644 --- a/src/cfdb/api/__init__.py +++ b/src/cfdb/api/__init__.py @@ -4,6 +4,8 @@ from motor.motor_asyncio import AsyncIOMotorDatabase +from cfdb.workflows import WORKFLOW_DURATION_CAP_S + if TYPE_CHECKING: from cfdb.workflows.cache import CacheBackend from cfdb.workflows.executor import JobExecutor @@ -79,10 +81,139 @@ def _parse_int_env(name: str, default: int, *, minimum: int = 0) -> int: #: same string; otherwise the API's discovery service won't see the #: worker registrations and the pool will start with zero leasable #: workers. The default ``"cfdb-workers"`` matches what the worker pool -#: needs to publish under; an ECS-aware variant may supplant LAN -#: discovery in a future PR. +#: needs to publish under. Ignored when the ECS-discovery profile is +#: active (see ``ECS_CLUSTER`` / ``ECS_WORKER_TASK_DEFINITION``); the +#: ECS path discovers workers by polling the ECS control plane directly +#: and does not need a shared zeroconf namespace. WORKFLOW_POOL_NAMESPACE: Final = os.getenv("WORKFLOW_POOL_NAMESPACE", "cfdb-workers") +# AWS / ECS profile. These knobs are optional — when none of them are +# set the API runs the PoC profile (``LocalFsCache`` + ``LanDiscovery`` +# + no worker provisioner) and behaves exactly as it did before the +# Fargate work landed. Production / LocalStack-backed dev sets the +# bucket + cluster + task-def + subnets to activate the ECS-backed +# profile; the same code path serves both because boto3 honors +# ``AWS_ENDPOINT_URL`` to redirect at LocalStack. + +#: boto3 endpoint override. In production this is unset and boto3 +#: targets real AWS endpoints; LocalStack-backed dev sets it to +#: ``http://localstack:4566`` so the same code talks to the local +#: container. Threaded through to ``_build_s3_client`` / +#: ``build_ecs_client`` (in ``cfdb.workflows.cache`` / +#: ``cfdb.workflows.provisioner``) via the boto3 ``Session`` default +#: chain, so no per-client wiring is needed here. +AWS_ENDPOINT_URL: Final = os.getenv("AWS_ENDPOINT_URL") + +#: AWS region. Defaults to ``us-east-1`` so a missing ``AWS_REGION`` in +#: dev doesn't surface as an opaque boto3 ``NoRegionError`` at first +#: request — operators get a working default and override it in +#: production deployments. +AWS_REGION: Final = os.getenv("AWS_REGION", "us-east-1") + +#: When set, the lifespan instantiates ``S3Cache`` instead of +#: ``LocalFsCache``. Unset means the API stays on the local +#: filesystem cache. +WORKFLOW_S3_BUCKET: Final = os.getenv("WORKFLOW_S3_BUCKET") + +#: Optional key prefix the S3 backend prepends to every cache key. +#: Lets a single bucket host multiple environments (``dev/``, +#: ``staging/``, ``prod/``) without collisions. +WORKFLOW_S3_PREFIX: Final = os.getenv("WORKFLOW_S3_PREFIX", "") + +#: ECS cluster name or ARN. Gates the ECS-backed provisioner and +#: discovery profile; unset means the PoC profile stays on +#: ``LanDiscovery`` with no provisioner. +ECS_CLUSTER: Final = os.getenv("ECS_CLUSTER") + +#: ECS worker task definition. Accepts either a family name +#: (``cfdb-worker``) or a ``family:revision`` string. The provisioner +#: passes it through to ``RunTask`` verbatim; the discovery loop +#: strips any ``:revision`` suffix to derive its +#: ``family`` filter (see ``ECS_WORKER_TASK_FAMILY`` override below). +ECS_WORKER_TASK_DEFINITION: Final = os.getenv("ECS_WORKER_TASK_DEFINITION") + + +def _ecs_default_task_family() -> str | None: + """Derive the discovery ``family`` filter from the task definition. + + ``RunTask`` accepts ``family[:revision]``; ``ListTasks`` accepts + only the family (no revision). The default split strips the + revision when present, with ``ECS_WORKER_TASK_FAMILY`` available as + an explicit override for environments that pin a non-default + family name. + """ + explicit = os.getenv("ECS_WORKER_TASK_FAMILY") + if explicit: + return explicit + if ECS_WORKER_TASK_DEFINITION: + return ECS_WORKER_TASK_DEFINITION.split(":", 1)[0] + return None + + +#: Family used by ``EcsDiscovery`` to filter ``ListTasks``. Derived +#: from ``ECS_WORKER_TASK_DEFINITION`` by default; set explicitly via +#: ``ECS_WORKER_TASK_FAMILY`` only when the discovery family differs +#: from the provisioner task-def family (rare). +ECS_WORKER_TASK_FAMILY: Final = _ecs_default_task_family() + + +def _parse_csv_env(name: str, default: str = "") -> list[str]: + """Parse a comma-separated env var into a list of trimmed strings. + + Empty strings are dropped so a trailing comma or double comma + doesn't propagate as an empty subnet/SG entry that boto3 would + later reject with a less informative error. + """ + raw = os.getenv(name, default) + return [item.strip() for item in raw.split(",") if item.strip()] + + +#: Awsvpc subnet IDs the worker ENIs land in. Required for the ECS +#: profile; an empty list with ``ECS_CLUSTER`` set is a misconfiguration +#: that the lifespan refuses to start under. +ECS_WORKER_SUBNETS: Final = _parse_csv_env("ECS_WORKER_SUBNETS") + +#: Awsvpc security groups attached to the worker ENIs. Optional — +#: when empty, ECS applies the VPC default SG. +ECS_WORKER_SECURITY_GROUPS: Final = _parse_csv_env("ECS_WORKER_SECURITY_GROUPS") + +_ASSIGN_PUBLIC_IP_VALUES = frozenset({"ENABLED", "DISABLED"}) + + +def _parse_assign_public_ip(name: str, default: str) -> str: + """Parse an ECS ``assignPublicIp`` env var with explicit validation. + + ECS rejects anything other than ``ENABLED`` / ``DISABLED``; we + surface the misconfiguration at module-import time so the lifespan + doesn't get to Mongo + S3 init before tripping on it. Pairs with + :data:`cfdb.workflows.provisioner._ASSIGN_PUBLIC_IP_VALUES`. + """ + raw = os.getenv(name) + if raw is None or raw == "": + return default + if raw not in _ASSIGN_PUBLIC_IP_VALUES: + raise ImportError( + f"Environment variable {name}={raw!r} must be one of " + f"{sorted(_ASSIGN_PUBLIC_IP_VALUES)}" + ) + return raw + + +#: Whether the worker ENI gets a public IPv4 address. Production +#: should leave this DISABLED and reach AWS via VPC endpoints; +#: LocalStack accepts either value. +ECS_WORKER_ASSIGN_PUBLIC_IP: Final = _parse_assign_public_ip( + "ECS_WORKER_ASSIGN_PUBLIC_IP", "DISABLED" +) + +#: Per-deployment override for the workflow runtime cap. Threads +#: through to :class:`cfdb.workflows.executor.WoolExecutor` via the +#: lifespan. Defaults to :data:`cfdb.workflows.WORKFLOW_DURATION_CAP_S` +#: (the env-driven workflows-package default, currently 4 h). +WORKFLOW_DURATION_CAP_SECONDS: Final = _parse_int_env( + "WORKFLOW_DURATION_CAP_SECONDS", default=WORKFLOW_DURATION_CAP_S, minimum=1 +) + db: AsyncIOMotorDatabase | None = None cache: "CacheBackend | None" = None executor: "JobExecutor | None" = None diff --git a/src/cfdb/api/main.py b/src/cfdb/api/main.py index 62ca7c6..bf31c8b 100644 --- a/src/cfdb/api/main.py +++ b/src/cfdb/api/main.py @@ -4,6 +4,7 @@ import re from contextlib import asynccontextmanager from pathlib import Path +from typing import AsyncIterator, Optional import wool from wool.runtime.discovery.lan import LanDiscovery @@ -19,12 +20,19 @@ from cfdb.api.routers.index import router as index_router from cfdb.api.routers.jobs import router as jobs_router from cfdb.api.routers.sync import router as sync_router -from cfdb.workflows.cache import LocalFsCache +from cfdb.workflows.cache import ( + CacheBackend, + LocalFsCache, + S3Cache, + check_s3_bucket_or_raise, +) +from cfdb.workflows.discovery import EcsDiscovery from cfdb.workflows.executor import WoolExecutor from cfdb.workflows.models import ACTIVE_STATUSES from cfdb.workflows.processors.bam import BamIndexProcessor from cfdb.workflows.processors.registry import default_registry from cfdb.workflows.processors.tabix import TabixIntervalProcessor +from cfdb.workflows.provisioner import EcsProvisioner logging.basicConfig(level=logging.INFO) @@ -86,6 +94,91 @@ async def _assert_jobs_indexes(db, log: logging.Logger) -> None: ) +async def _build_cache(cache_root: Path) -> CacheBackend: + """Pick the cache backend per env config. + + Returns ``S3Cache`` when ``WORKFLOW_S3_BUCKET`` is set (production + and LocalStack-backed dev) and probes the bucket via + :func:`check_s3_bucket_or_raise` so a typo / missing IAM grant + fails fast at boot rather than as a permanent-cache-miss cascade + once workflows start. Falls back to ``LocalFsCache`` otherwise — + the PoC profile that ships before the deployment substrate is + wired up. ``cache_root`` is only meaningful on the LocalFsCache + branch but is accepted unconditionally so the caller doesn't need + to branch on the env state. + """ + if api.WORKFLOW_S3_BUCKET: + cache = S3Cache( + bucket=api.WORKFLOW_S3_BUCKET, + prefix=api.WORKFLOW_S3_PREFIX, + endpoint_url=api.AWS_ENDPOINT_URL, + region_name=api.AWS_REGION, + ) + # Reuse the freshly-built client so the probe targets the same + # endpoint as subsequent reads/writes. + await check_s3_bucket_or_raise(api.WORKFLOW_S3_BUCKET, client=cache._client) + return cache + return LocalFsCache(cache_root) + + +def _maybe_build_provisioner() -> Optional[EcsProvisioner]: + """Build an ``EcsProvisioner`` when the ECS env is fully configured. + + Returns ``None`` when any required knob is missing so the executor's + PoC profile (manually-started wool workers, no provisioner) is + preserved. The lifespan logs which path is selected so an + incomplete config doesn't silently degrade. + + The completeness check requires all three of cluster, task + definition, and a non-empty subnet list — these are the minimum + ``RunTask`` needs to launch a worker into the awsvpc network. A + partial config (e.g. cluster + task-def but no subnets) is treated + as the PoC profile rather than failing fast because deployment + pipelines may roll out the env vars in stages. + """ + if not (api.ECS_CLUSTER and api.ECS_WORKER_TASK_DEFINITION and api.ECS_WORKER_SUBNETS): + return None + return EcsProvisioner( + cluster=api.ECS_CLUSTER, + task_definition=api.ECS_WORKER_TASK_DEFINITION, + subnets=api.ECS_WORKER_SUBNETS, + security_groups=api.ECS_WORKER_SECURITY_GROUPS, + assign_public_ip=api.ECS_WORKER_ASSIGN_PUBLIC_IP, + endpoint_url=api.AWS_ENDPOINT_URL, + region_name=api.AWS_REGION, + ) + + +@asynccontextmanager +async def _build_discovery() -> AsyncIterator[object]: + """Pick the worker-discovery backend per env config. + + Two paths: + * ECS — ``EcsDiscovery`` polls ``ListTasks`` + ``DescribeTasks`` + on the configured cluster. Async-context-managed: the background + poller starts on ``__aenter__`` and is cancelled on + ``__aexit__``. Selected when ``ECS_CLUSTER`` and the worker + task definition are both set. + * LAN — ``LanDiscovery`` over zeroconf/mDNS. The wool pool + accepts it as-is; the wrapper context manager is a no-op so + the caller's ``async with`` works either way. + + Yields the same shape (something wool's ``WorkerPool`` accepts via + its ``discovery=`` arg) so the lifespan can wire it through without + branching at the call site. + """ + if api.ECS_CLUSTER and api.ECS_WORKER_TASK_FAMILY: + async with EcsDiscovery( + cluster=api.ECS_CLUSTER, + task_definition_family=api.ECS_WORKER_TASK_FAMILY, + endpoint_url=api.AWS_ENDPOINT_URL, + region_name=api.AWS_REGION, + ) as discovery: + yield discovery + return + yield LanDiscovery(api.WORKFLOW_POOL_NAMESPACE) + + def create_mongodb_client() -> AsyncIOMotorClient: """Create MongoDB client with optional TLS authentication.""" log = logging.getLogger(__name__) @@ -146,79 +239,102 @@ async def lifespan(_: FastAPI): # same filesystem (otherwise the kernel raises # ``OSError(EXDEV)``). Verify the precondition at startup so # a multi-volume deployment fails fast with a clear message - # instead of dying mid-pipeline on the first cache.put. - cache_st = os.stat(cache_root) - workdir_st = os.stat(workdir_root) - if cache_st.st_dev != workdir_st.st_dev: - raise RuntimeError( - "SYNC_DATA_DIR subdirectories must share a filesystem " - f"(cache={cache_root!s} st_dev={cache_st.st_dev}, " - f"workdir={workdir_root!s} st_dev={workdir_st.st_dev}). " - "LocalFsCache.put relies on os.replace atomicity; " - "cross-device renames raise OSError(EXDEV). Mount both " - "paths under a single volume or set SYNC_DATA_DIR to a " - "parent that contains both." - ) + # instead of dying mid-pipeline on the first cache.put. Only + # the LocalFsCache branch needs this — S3Cache.put goes + # over the network and has no rename-atomicity requirement. + if not api.WORKFLOW_S3_BUCKET: + cache_st = os.stat(cache_root) + workdir_st = os.stat(workdir_root) + if cache_st.st_dev != workdir_st.st_dev: + raise RuntimeError( + "SYNC_DATA_DIR subdirectories must share a filesystem " + f"(cache={cache_root!s} st_dev={cache_st.st_dev}, " + f"workdir={workdir_root!s} st_dev={workdir_st.st_dev}). " + "LocalFsCache.put relies on os.replace atomicity; " + "cross-device renames raise OSError(EXDEV). Mount both " + "paths under a single volume or set SYNC_DATA_DIR to a " + "parent that contains both." + ) - api.cache = LocalFsCache(cache_root) + api.cache = await _build_cache(cache_root) api.processor_registry = default_registry() api.processor_registry.register(BamIndexProcessor()) api.processor_registry.register(TabixIntervalProcessor()) + provisioner = _maybe_build_provisioner() + # Warn loudly when discovery is ECS-mode but the + # provisioner is unavailable — workers polled-for are not + # workers launched-on-demand, so the dispatch retry budget + # eventually exhausts with ``NoWorkersAvailable``. The + # asymmetry is intentional to support staged env rollouts + # but operators should know they're in a half-configured + # deployment. + if api.ECS_CLUSTER and api.ECS_WORKER_TASK_FAMILY and provisioner is None: + log.warning( + "ECS discovery is active but no provisioner — set " + "ECS_WORKER_TASK_DEFINITION and ECS_WORKER_SUBNETS to " + "complete the ECS profile, or unset ECS_CLUSTER to " + "fall back to the PoC profile." + ) + # Lease workers from the surrounding pool rather than spawning - # them in-process. In production the workers run as separate - # ECS tasks discovered via wool's discovery layer; the API's - # job is to dispatch to whatever capacity exists. Scaling the - # ECS service is out of band (e.g., on ``NoWorkersAvailable`` - # bursts or on a queue-depth metric). + # them in-process. In production the workers run as ECS Fargate + # tasks launched on demand by ``EcsProvisioner`` and discovered + # via ``EcsDiscovery``'s poll-and-diff over ``ListTasks`` / + # ``DescribeTasks``. In the PoC dev profile no AWS env is set + # and the API falls back to ``LanDiscovery`` (zeroconf/mDNS) + # against a manually-started wool pool. # # The explicit ``discovery=`` is required to keep wool out of # its default ephemeral mode — ``WorkerPool(lease=N)`` alone # falls into the ``(spawn=None, discovery=None)`` branch which - # spawns CPU-count workers locally. Pairing ``lease=N`` with a - # shared ``LanDiscovery`` namespace puts the pool in - # discovery-only mode (no spawning); the worker-pool process - # publishes workers via the same namespace. ``LanDiscovery`` - # rides over zeroconf/mDNS, which sidesteps the macOS - # ``watchdog``/FSEvents fork-unsafety we hit with - # ``LocalDiscovery``. - async with wool.WorkerPool( - discovery=LanDiscovery(api.WORKFLOW_POOL_NAMESPACE), - lease=api.WORKFLOW_WORKER_COUNT, - ): - # Snapshot the lifespan task's contextvars after the - # pool's ``__aenter__`` has populated wool's internals. - api.wool_context = contextvars.copy_context() - api.executor = WoolExecutor( - api.db, - api.cache, - cache_root, - api.processor_registry, - workdir_root=workdir_root, - ) - executor_handle = api.executor - log.info( - "Workflow subsystem enabled: cache=%s workdir=%s " - "lease=%d namespace=%s", - cache_root, - workdir_root, - api.WORKFLOW_WORKER_COUNT, - api.WORKFLOW_POOL_NAMESPACE, - ) - try: - yield - finally: - drained = await executor_handle.drain( - timeout=SHUTDOWN_DRAIN_TIMEOUT_SECONDS + # spawns CPU-count workers locally. + async with _build_discovery() as discovery: + async with wool.WorkerPool( + discovery=discovery, + lease=api.WORKFLOW_WORKER_COUNT, + ): + # Snapshot the lifespan task's contextvars after the + # pool's ``__aenter__`` has populated wool's internals. + api.wool_context = contextvars.copy_context() + api.executor = WoolExecutor( + api.db, + api.cache, + cache_root, + api.processor_registry, + workdir_root=workdir_root, + workflow_duration_cap_seconds=api.WORKFLOW_DURATION_CAP_SECONDS, + provisioner=provisioner, + ) + executor_handle = api.executor + log.info( + "Workflow subsystem enabled: cache=%s workdir=%s " + "lease=%d discovery=%s provisioner=%s", + type(api.cache).__name__, + workdir_root, + api.WORKFLOW_WORKER_COUNT, + type(discovery).__name__, + "EcsProvisioner" if provisioner is not None else "none", ) - if drained: - log.info( - "Drained %d workflow task(s) on shutdown", drained + try: + yield + finally: + drained = await executor_handle.drain( + timeout=SHUTDOWN_DRAIN_TIMEOUT_SECONDS ) - api.executor = None - api.cache = None - api.processor_registry = None - api.wool_context = None + if drained: + log.info( + "Drained %d workflow task(s) on shutdown", drained + ) + # Cancel in-flight ``RunTask`` calls so a + # shutdown mid-launch doesn't leave a task + # un-requested-but-billed. + if provisioner is not None: + await provisioner.aclose() + api.executor = None + api.cache = None + api.processor_registry = None + api.wool_context = None else: log.info( "SYNC_DATA_DIR unset — workflow subsystem disabled; " diff --git a/src/cfdb/workflows/__init__.py b/src/cfdb/workflows/__init__.py index 0776c68..5088e82 100644 --- a/src/cfdb/workflows/__init__.py +++ b/src/cfdb/workflows/__init__.py @@ -27,7 +27,10 @@ - ``CFDB_WORKFLOW_DURATION_CAP_S`` — per-workflow wall-clock cap, enforced via ``asyncio.timeout`` on the API side while consuming the routine's - event stream. Default ``1200`` (20 min). + event stream. Default ``14400`` (4 h) — sized for multi-hour + preprocessing runs (e.g., a ``samtools sort`` on a multi-GB BAM + followed by ``samtools index``). Operators running on bounded fixtures + in dev should lower this via env. - ``CFDB_WORKFLOW_DISPATCH_WAIT_S`` — how long ``ensure_workflow`` waits for a wool worker to become available before giving up. Default ``60``. - ``CFDB_WORKFLOW_HEARTBEAT_INTERVAL_S`` — how often the routine emits a @@ -109,7 +112,7 @@ def _positive_int(name: str, value: str, *, minimum: int = 0) -> int: # look like a hard failure). WORKFLOW_DURATION_CAP_S: Final = _positive_int( "CFDB_WORKFLOW_DURATION_CAP_S", - os.getenv("CFDB_WORKFLOW_DURATION_CAP_S", "1200"), + os.getenv("CFDB_WORKFLOW_DURATION_CAP_S", "14400"), minimum=1, ) WORKFLOW_DISPATCH_WAIT_S: Final = _positive_int( diff --git a/src/cfdb/workflows/cache.py b/src/cfdb/workflows/cache.py index 5787038..f675aa6 100644 --- a/src/cfdb/workflows/cache.py +++ b/src/cfdb/workflows/cache.py @@ -1,9 +1,10 @@ """Pluggable cache backend for workflow artifacts. The cache is a byte-range-aware content store keyed by strings produced by -``workflows.keys.cache_key``. ``LocalFsCache`` is the concrete backend used -in local development and for the initial CVH rollout; an S3-backed -implementation with the same interface is planned for production. +``workflows.keys.cache_key``. Two concrete backends ship: ``LocalFsCache`` +for development and unit tests that don't need a network round-trip, and +``S3Cache`` for production (and for LocalStack-backed dev that mirrors +production end-to-end via boto3). Range-aware reads matter because Gosling's client-side fetchers (BAM/tabix families, bbi) issue ``Range: bytes=…`` requests against the artifact URL. @@ -19,7 +20,10 @@ from collections.abc import AsyncIterator from dataclasses import dataclass from pathlib import Path -from typing import Optional +from typing import Any, Optional + +import boto3 +from botocore.exceptions import BotoCoreError, ClientError @dataclass(frozen=True) @@ -74,17 +78,29 @@ async def delete(self, key: str) -> bool: _CHUNK_SIZE = 1 << 16 # 64 KiB -def _safe_key_path(root: Path, key: str) -> Path: - """Resolve a cache key to an absolute path under ``root``. +def _validate_cache_key(key: str) -> None: + """Reject path-traversal segments in a cache key. Rejects empty keys and keys containing path-traversal segments so that malformed input cannot escape the cache root or collapse onto the - root directory itself. + root directory itself. Shared by both backends so the rule (and its + error message) stays in one place. """ if not key or not key.strip("/"): raise ValueError(f"Cache key must be non-empty: {key!r}") if ".." in key.split("/"): raise ValueError(f"Invalid cache key: {key!r}") + if key.startswith("/"): + raise ValueError(f"Cache key must not start with '/': {key!r}") + + +def _safe_key_path(root: Path, key: str) -> Path: + """Resolve a cache key to an absolute path under ``root``. + + Validates the key shape and verifies the resolved path stays under + ``root`` so a malformed key cannot escape the cache root. + """ + _validate_cache_key(key) path = (root / key).resolve() root_resolved = root.resolve() if root_resolved not in path.parents: @@ -95,8 +111,9 @@ def _safe_key_path(root: Path, key: str) -> Path: class LocalFsCache(CacheBackend): """Local-filesystem backend. Keys map directly to relative paths. - ``put`` writes via a temp sibling + atomic rename. ``get`` supports - byte-range reads and chunks by 64 KiB to keep streaming memory-bounded. + ``put`` atomically renames the caller's source file into place via + ``os.replace``. ``get`` supports byte-range reads and chunks by 64 + KiB to keep streaming memory-bounded. """ def __init__(self, root: Path) -> None: @@ -143,6 +160,293 @@ async def delete(self, key: str) -> bool: return False +class S3Cache(CacheBackend): + """Boto3-backed cache for production (and LocalStack-backed dev). + + The same code targets real S3 and LocalStack — the only difference + is the ``endpoint_url`` passed to ``boto3.client("s3")``. Keys are + stored as object keys (optionally under a configurable prefix); + range reads use S3's ``Range`` header verbatim, so the fetcher + semantics are identical to ``LocalFsCache``. + + Args: + bucket: Bucket name. Must already exist (LocalStack and prod + both treat bucket creation as an out-of-band concern). + prefix: Optional key prefix; useful for sharing a single bucket + across multiple environments. Empty string by default. + client: Optional pre-built boto3 ``s3`` client. When omitted, + one is constructed via :func:`_build_s3_client` with the + ``endpoint_url`` / ``region_name`` kwargs threaded through. + Tests inject a moto-backed client through this argument. + endpoint_url: Boto3 ``endpoint_url``. Passed to + :func:`_build_s3_client` when ``client`` is omitted. The + lifespan plumbs :data:`cfdb.api.AWS_ENDPOINT_URL` here so + LocalStack vs production differ only at this seam. + region_name: Boto3 ``region_name``. Plumbed analogously. + chunk_size: Streaming chunk size for ``get`` reads. Defaults to + 64 KiB to match the local filesystem backend. + """ + + def __init__( + self, + bucket: str, + *, + prefix: str = "", + client: Optional[Any] = None, + endpoint_url: Optional[str] = None, + region_name: Optional[str] = None, + chunk_size: int = _CHUNK_SIZE, + ) -> None: + if not bucket: + raise ValueError("S3Cache requires a non-empty bucket name") + self._bucket = bucket + # Normalize so callers can pass either ``"prefix"`` or ``"prefix/"``. + self._prefix = prefix.strip("/") + self._endpoint_url = endpoint_url + self._region_name = region_name + self._client = ( + client + if client is not None + else _build_s3_client(endpoint_url=endpoint_url, region_name=region_name) + ) + self._chunk_size = chunk_size + + def _object_key(self, key: str) -> str: + """Apply the configured prefix to a validated cache key.""" + _validate_cache_key(key) + return f"{self._prefix}/{key}" if self._prefix else key + + def __getstate__(self) -> dict[str, Any]: + """Strip the boto3 client for pickling. + + ``S3Cache`` is dispatched across the cloudpickle boundary into + Wool worker processes; botocore's ``BaseClient`` cannot be + pickled. ``__setstate__`` rebuilds it via ``_build_s3_client`` + during unpickling, threading the originally-supplied + ``endpoint_url`` / ``region_name`` through so the worker + targets the same backend as the API process. + """ + state = self.__dict__.copy() + state["_client"] = None + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + """Restore state and rebuild the boto3 client on the worker.""" + self.__dict__.update(state) + if self._client is None: + self._client = _build_s3_client( + endpoint_url=self._endpoint_url, + region_name=self._region_name, + ) + + async def head(self, key: str) -> Optional[CacheEntry]: + """Return cache metadata for ``key``, or None if the object is absent. + + ``ClientError`` covers HTTP-level S3 errors with structured + response codes; ``BotoCoreError`` covers transport / cred + failures with no ``.response``. ``_is_not_found`` returns + False for the latter family, so transport failures correctly + re-raise rather than masquerade as cache miss. + + Note: S3 ``HEAD`` responses carry no body, so a missing + bucket is indistinguishable from a missing object at this + endpoint — both surface as a bare ``404``. The lifespan + startup probes the bucket separately (see + ``check_s3_bucket_or_raise``) so a typo in + ``WORKFLOW_S3_BUCKET`` fails fast at boot rather than as a + cascade of "permanent cache miss" symptoms. + """ + object_key = self._object_key(key) + try: + response = await asyncio.to_thread( + self._client.head_object, + Bucket=self._bucket, + Key=object_key, + ) + except (ClientError, BotoCoreError) as exc: + if _is_not_found(exc): + return None + raise + return CacheEntry(key=key, size=int(response["ContentLength"])) + + def get( + self, key: str, byte_range: Optional[tuple[int, int]] = None + ) -> AsyncIterator[bytes]: + """Stream cached bytes from S3, optionally restricted to a byte range. + + ``GetObject`` accepts an inclusive ``Range`` header; we forward + the tuple verbatim. A missing object yields an empty iterator, + matching ``LocalFsCache`` semantics so router code can treat + cache misses uniformly across backends. + """ + return _stream_s3_object( + self._client, + self._bucket, + self._object_key(key), + byte_range, + self._chunk_size, + ) + + async def put(self, key: str, source_path: Path) -> CacheEntry: + """Upload ``source_path`` to S3 under the configured key. + + ``upload_file`` is atomic from the reader's perspective — + readers either see the prior object (if any) or the new one, + never a partial write. Boto3 streams the file in a thread so + the event loop isn't blocked. Size is taken from the source + file rather than a follow-up ``HEAD`` to avoid the round-trip; + when the source file has been torn down between upload + completion and the local ``stat`` (workdir cleanup races + finalization), we fall back to a ``head_object`` round-trip + rather than surface ``FileNotFoundError`` for an upload that + is already committed. + """ + object_key = self._object_key(key) + await asyncio.to_thread( + self._client.upload_file, + str(source_path), + self._bucket, + object_key, + ) + try: + size = await asyncio.to_thread(source_path.stat) + except FileNotFoundError: + response = await asyncio.to_thread( + self._client.head_object, + Bucket=self._bucket, + Key=object_key, + ) + return CacheEntry(key=key, size=int(response["ContentLength"])) + return CacheEntry(key=key, size=size.st_size) + + async def delete(self, key: str) -> bool: + """Delete the cache entry. Returns True when the object existed. + + ``DeleteObject`` returns 204 whether or not the key existed, so + we probe with HEAD first to give callers the existence signal. + The HEAD/DELETE pair is non-atomic — a concurrent ``put`` between + them returns ``False`` ("did not exist") yet erases the freshly- + uploaded object. Callers MUST serialize ``put``/``delete`` for the + same key via the workflow mutex; the cache backend itself does + not arbitrate concurrent writers. + """ + object_key = self._object_key(key) + existed = await self.head(key) is not None + await asyncio.to_thread( + self._client.delete_object, + Bucket=self._bucket, + Key=object_key, + ) + return existed + + +def _build_s3_client( + *, endpoint_url: Optional[str] = None, region_name: Optional[str] = None +) -> Any: + """Construct a boto3 ``s3`` client with explicit endpoint/region. + + The caller (typically :class:`S3Cache` or the API lifespan) is the + single source of truth for ``endpoint_url`` and ``region_name`` — + we no longer reach into :mod:`cfdb.api` for fallback values. Pass + :data:`cfdb.api.AWS_ENDPOINT_URL` / :data:`cfdb.api.AWS_REGION` + from the lifespan; leave both ``None`` to let boto3's default + session resolver chain pick them up from the environment. + """ + return boto3.client( + "s3", + endpoint_url=endpoint_url, + region_name=region_name, + ) + + +async def check_s3_bucket_or_raise(bucket: str, *, client: Optional[Any] = None) -> None: + """Verify ``bucket`` is reachable; raise on missing/inaccessible. + + Run from the API lifespan so a typo in ``WORKFLOW_S3_BUCKET`` (or + a missing IAM grant) fails fast at boot rather than masquerading + as a permanent cache-miss cascade once workflows start. ``HEAD`` + on a missing bucket returns a bare ``404`` that ``S3Cache.head`` + cannot distinguish from a missing object — this probe asks + ``head_bucket`` directly, which gets a structured response. + """ + cli = client if client is not None else _build_s3_client() + try: + await asyncio.to_thread(cli.head_bucket, Bucket=bucket) + except (ClientError, BotoCoreError) as exc: + raise RuntimeError( + f"S3 bucket {bucket!r} is not reachable: {type(exc).__name__}: {exc}" + ) from exc + + +#: Object-level "missing" codes only. ``NoSuchBucket`` is deliberately +#: NOT in this set: a missing bucket is a configuration failure (typo +#: in WORKFLOW_S3_BUCKET, missing IAM, region mismatch) that should +#: surface as an exception rather than silently masquerade as a +#: permanent cache miss. +_NOT_FOUND_CODES = frozenset({"404", "NoSuchKey", "NotFound"}) + + +def _is_not_found(exc: BaseException) -> bool: + """Return True when a boto3 exception indicates a missing object. + + ``ClientError`` carries a structured ``response`` dict with both an + ``Error.Code`` (string) and ``ResponseMetadata.HTTPStatusCode`` + (int). We check both: head_object responses use the bare 404 code, + get_object uses the named ``NoSuchKey``. ``BotoCoreError`` (network + / credential failures) has no ``response`` and is correctly + classified as not-a-not-found, so it propagates. + """ + response = getattr(exc, "response", None) + if not isinstance(response, dict): + return False + code = (response.get("Error") or {}).get("Code") + if code in _NOT_FOUND_CODES: + return True + status = (response.get("ResponseMetadata") or {}).get("HTTPStatusCode") + return status == 404 + + +async def _stream_s3_object( + client: Any, + bucket: str, + object_key: str, + byte_range: Optional[tuple[int, int]], + chunk_size: int, +) -> AsyncIterator[bytes]: + """Async generator yielding chunks from an S3 object. + + Diverges from ``_stream_file`` on mid-stream object disappearance: + a deletion between the ``GetObject`` response and the body read + raises a ``ClientError`` to the consumer rather than truncating + silently. Callers serialize put/delete via the workflow mutex + (see ``S3Cache.delete``), so the divergence is benign in practice. + """ + kwargs: dict[str, Any] = {"Bucket": bucket, "Key": object_key} + if byte_range is not None: + start, end = byte_range + kwargs["Range"] = f"bytes={start}-{end}" + try: + response = await asyncio.to_thread(client.get_object, **kwargs) + except (ClientError, BotoCoreError) as exc: + if _is_not_found(exc): + return + raise + body = response["Body"] + try: + while True: + chunk = await asyncio.to_thread(body.read, chunk_size) + if not chunk: + break + yield chunk + finally: + close = getattr(body, "close", None) + if close is not None: + # Shield the close so a CancelledError in this finally + # (consumer disconnect mid-stream, chained cancel) cannot + # skip the underlying urllib3 connection release. + await asyncio.shield(asyncio.to_thread(close)) + + async def _stream_file( path: Path, byte_range: Optional[tuple[int, int]] ) -> AsyncIterator[bytes]: @@ -188,4 +492,7 @@ def _open_and_seek(): remaining -= len(chunk) yield chunk finally: - await asyncio.to_thread(fh.close) + # Shield the close so a CancelledError in this finally (consumer + # disconnect mid-stream, chained cancel) cannot skip the + # underlying file-descriptor release. + await asyncio.shield(asyncio.to_thread(fh.close)) diff --git a/src/cfdb/workflows/constants.py b/src/cfdb/workflows/constants.py new file mode 100644 index 0000000..4d942eb --- /dev/null +++ b/src/cfdb/workflows/constants.py @@ -0,0 +1,17 @@ +"""Constants shared across cfdb.workflows.* modules. + +Lives in its own leaf module so ``discovery`` can read +``DEFAULT_WORKER_PORT`` without importing ``worker_main`` — importing +``worker_main`` would drag aiohttp into ``discovery``'s import graph +and impose a runtime dependency that none of ``discovery``'s actual +consumers (the API lifespan, tests) need. +""" + +from __future__ import annotations + +#: Wool gRPC port the worker container binds and clients dial. +#: ``EcsDiscovery`` uses it to assemble worker addresses; +#: ``worker_main`` uses it as the bind port. The two MUST agree, and +#: the ECS task definition's ``portMappings`` / ``healthCheck`` target +#: MUST match. +DEFAULT_WORKER_PORT = 50051 diff --git a/src/cfdb/workflows/discovery.py b/src/cfdb/workflows/discovery.py new file mode 100644 index 0000000..b160222 --- /dev/null +++ b/src/cfdb/workflows/discovery.py @@ -0,0 +1,523 @@ +"""Worker discovery driven by the ECS control plane. + +ECS already owns the lifecycle of every worker task — registration, IP, +status, health — so we use it as the discovery substrate directly rather +than maintaining a parallel registry. ``EcsDiscovery`` implements Wool's +``DiscoveryLike`` protocol with a poll-and-diff loop: + +1. ``ecs.list_tasks`` enumerates every task in the cluster matching the + worker task family with ``desiredStatus="RUNNING"``. +2. ``ecs.describe_tasks`` (batched 100 ARNs at a time) hydrates each + into its full state, including ``healthStatus`` and ``attachments``. +3. We filter for ``lastStatus == "RUNNING"`` and (where reported) + ``healthStatus == "HEALTHY"`` — tasks whose container definition + omits a ``healthCheck`` are accepted as healthy so a deployment that + hasn't yet wired up health checks still surfaces workers. +4. The poller diffs the resolved set against the previous one and + publishes ``worker-added`` / ``worker-dropped`` events to a Wool + ``DiscoveryPublisherLike``. + +The worker side does nothing for discovery — no heartbeat thread, no +Mongo write, no registration RPC. ECS reports ``healthStatus: HEALTHY`` +once the container's health check passes, which is the "ready to +dispatch" signal the discovery filters on. The worker task definition +MUST declare a ``healthCheck`` against the gRPC port; without one +ECS reports ``healthStatus: UNKNOWN`` indefinitely and the worker is +never advertised. +""" + +from __future__ import annotations + +import asyncio +import logging +import uuid +from collections.abc import AsyncIterator, Iterable +from typing import Any, Optional + +from wool.runtime.discovery.base import ( + Discovery, + DiscoveryEvent, + DiscoveryEventType, + DiscoveryPublisherLike, + DiscoverySubscriberLike, + PredicateFunction, + WorkerMetadata, +) + +from cfdb.workflows.constants import DEFAULT_WORKER_PORT +from cfdb.workflows.provisioner import build_ecs_client + +logger = logging.getLogger(__name__) + +#: ECS ``DescribeTasks`` accepts up to 100 ARNs per call. +_DESCRIBE_BATCH_SIZE = 100 + +#: Default poll cadence. ECS ``ListTasks`` is a Cluster resource read +#: action: burst 100, sustained 20 req/s per region (shared bucket with +#: ``DescribeTasks``). 5s leaves ample headroom for many concurrent +#: clusters. +DEFAULT_POLL_INTERVAL_SECONDS = 5.0 + + +class EcsDiscovery(Discovery): + """Wool-compatible discovery service backed by ECS list/describe. + + Args: + cluster: ECS cluster name or ARN to poll. + task_definition_family: Worker task definition family (sans + revision). Used as the ``family`` filter on ``ListTasks``. + client: Optional pre-built boto3 ``ecs`` client. When omitted, + one is constructed via :func:`build_ecs_client` with the + ``endpoint_url`` / ``region_name`` kwargs threaded through. + endpoint_url: Boto3 ``endpoint_url``. Passed to + :func:`build_ecs_client` when ``client`` is omitted. The + lifespan plumbs :data:`cfdb.api.AWS_ENDPOINT_URL` here. + region_name: Boto3 ``region_name``. Plumbed analogously. + poll_interval: Seconds between successive ``ListTasks`` polls. + worker_port: gRPC port the worker binds — used to construct + the address string passed to Wool. Shares the worker_main + default so a deployment that changes one changes both. + version: Wool worker version string passed through into + ``WorkerMetadata``. Free-form; useful for filtering. + + Lifecycle: enter ``async with EcsDiscovery(...)`` to start the + background poller. Exiting the context cancels the poller and + releases the publisher. + """ + + def __init__( + self, + *, + cluster: str, + task_definition_family: str, + client: Optional[Any] = None, + endpoint_url: Optional[str] = None, + region_name: Optional[str] = None, + poll_interval: float = DEFAULT_POLL_INTERVAL_SECONDS, + worker_port: int = DEFAULT_WORKER_PORT, + version: str = "0", + ) -> None: + if not cluster: + raise ValueError("EcsDiscovery requires a cluster name") + if not task_definition_family: + raise ValueError("EcsDiscovery requires a task_definition_family") + if poll_interval <= 0: + raise ValueError("poll_interval must be positive") + + self._cluster = cluster + self._task_definition_family = task_definition_family + self._client = ( + client + if client is not None + else build_ecs_client(endpoint_url=endpoint_url, region_name=region_name) + ) + self._poll_interval = poll_interval + self._worker_port = worker_port + self._version = version + + self._subscribers: list[_EcsSubscriber] = [] + self._known: dict[str, WorkerMetadata] = {} + self._poll_task: Optional[asyncio.Task[None]] = None + self._lock = asyncio.Lock() + # Separate from ``_lock`` so subscriber registration is not + # blocked on the AWS round-trip held by an in-flight poll. + # Serializes the entire ``poll_once`` body so two concurrent + # callers (test fixture + background loop, mostly) cannot + # interleave their list/describe snapshots into the diff and + # regress ``_known``. + self._poll_lock = asyncio.Lock() + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + async def __aenter__(self) -> "EcsDiscovery": + # Re-entry would orphan the prior poll task and diff against + # stale ``_known`` state. ``__aexit__`` nulls ``_poll_task``; + # the assert pairs with that to make the misuse loud. + assert self._poll_task is None, ( + "EcsDiscovery context entered twice; create a fresh instance" + ) + # Run an immediate scan so subscribers attached right after + # __aenter__ see the initial set of workers without waiting a + # full poll interval. + await self.poll_once() + self._poll_task = asyncio.create_task(self._poll_loop()) + return self + + async def __aexit__(self, *_exc: Any) -> None: + if self._poll_task is not None: + self._poll_task.cancel() + try: + await self._poll_task + except asyncio.CancelledError: + pass + self._poll_task = None + # Wake any consumer parked on ``await self._queue.get()`` so + # ``async for event in subscriber:`` exits cleanly rather than + # blocking forever on a queue nothing will publish to again. + # The sentinel ``None`` traverses each subscriber's queue and + # ``_iter`` breaks its loop on it. + async with self._lock: + for sub in self._subscribers: + await sub._queue.put(None) + # Clear cached state so a stray re-subscribe after exit + # doesn't replay stale workers. + self._known.clear() + self._subscribers.clear() + + # ------------------------------------------------------------------ + # Wool DiscoveryLike protocol + # ------------------------------------------------------------------ + @property + def publisher(self) -> DiscoveryPublisherLike: + """``EcsDiscovery`` is read-only — workers register implicitly via ECS. + + Returns a guard-rail publisher that raises on ``publish``; + production code should never reach for this property. + """ + return _RaisingPublisher() + + @property + def subscriber(self) -> DiscoverySubscriberLike: + """A subscriber with no filter (sees every healthy worker).""" + return self.subscribe() + + def subscribe( + self, filter: Optional[PredicateFunction] = None + ) -> DiscoverySubscriberLike: + """Create a subscriber, optionally filtered by a worker predicate. + + The filter predicate runs inline on the dispatch path while the + registration lock is held — keep it cheap and non-blocking. + Slow predicates stall delivery to every other subscriber and + block concurrent registrations. + """ + return _EcsSubscriber(self, filter) + + # ------------------------------------------------------------------ + # Polling + # ------------------------------------------------------------------ + async def poll_once(self) -> tuple[list[DiscoveryEvent], dict[str, WorkerMetadata]]: + """Run one list/describe cycle and emit diff events. + + Returns the events emitted in this cycle and the resolved set of + currently-known healthy workers (keyed by UUID hex string). + Callers do not normally need to use the return value — it + exists for tests so that polling can be exercised step-by-step + without the background loop. + + ``self._poll_lock`` serializes the entire body so two concurrent + callers (test fixture + background loop) cannot interleave + their list/describe snapshots into the diff and regress + ``_known``. ``self._lock`` separately serializes the + diff/mutate/dispatch tail against subscriber registration; see + ``_register_subscriber`` for the pairing. + + On ``list_tasks`` / ``describe_tasks`` failure the previous + ``_known`` snapshot is retained — workers added during a multi- + cycle ECS outage are invisible until recovery, and workers + terminated during the outage continue to be advertised as + healthy until the next successful poll. Acceptable degradation + next to the cold-start budget; preferable to flapping the whole + fleet on a single transient failure. + """ + async with self._poll_lock: + try: + arns = await asyncio.to_thread(self._list_task_arns) + except Exception: + logger.exception( + "EcsDiscovery: list_tasks failed for cluster=%s", + self._cluster, + ) + async with self._lock: + return [], dict(self._known) + + resolved: dict[str, WorkerMetadata] = {} + if arns: + try: + tasks = await asyncio.to_thread(self._describe_tasks_batched, arns) + except Exception: + logger.exception( + "EcsDiscovery: describe_tasks failed for cluster=%s", + self._cluster, + ) + async with self._lock: + return [], dict(self._known) + for task in tasks: + metadata = self._task_to_metadata(task) + if metadata is not None: + resolved[str(metadata.uid)] = metadata + + # Diff, mutate, and dispatch under ``self._lock`` so a + # concurrent ``_register_subscriber`` cannot replay a + # snapshot that already contains a freshly-added worker AND + # then receive the worker-added event for it from this + # dispatch (duplicate delivery). Either the registration + # runs entirely before this mutation (replay sees the OLD + # known, this dispatch sees the new subscriber) or entirely + # after (replay sees the NEW known, this dispatch's events + # were already delivered). + async with self._lock: + events = list(_diff(self._known, resolved)) + self._known = resolved + if events: + for sub in self._subscribers: + for event in events: + await sub._publish(event) + + return events, dict(resolved) + + async def _poll_loop(self) -> None: + """Background poll loop. Cancellation-safe.""" + while True: + try: + await asyncio.sleep(self._poll_interval) + await self.poll_once() + except asyncio.CancelledError: + raise + except Exception: + logger.exception("EcsDiscovery poll loop iteration failed") + + # ------------------------------------------------------------------ + # ECS API plumbing + # ------------------------------------------------------------------ + def _list_task_arns(self) -> list[str]: + """Page through ``ListTasks`` and return every task ARN.""" + arns: list[str] = [] + paginator_kwargs = { + "cluster": self._cluster, + "family": self._task_definition_family, + "desiredStatus": "RUNNING", + } + next_token: Optional[str] = None + while True: + kwargs = dict(paginator_kwargs) + if next_token: + kwargs["nextToken"] = next_token + response = self._client.list_tasks(**kwargs) + arns.extend(response.get("taskArns") or []) + next_token = response.get("nextToken") + if not next_token: + break + return arns + + def _describe_tasks_batched(self, arns: list[str]) -> list[dict[str, Any]]: + """Call ``DescribeTasks`` in batches of ``_DESCRIBE_BATCH_SIZE``.""" + out: list[dict[str, Any]] = [] + for i in range(0, len(arns), _DESCRIBE_BATCH_SIZE): + batch = arns[i : i + _DESCRIBE_BATCH_SIZE] + response = self._client.describe_tasks( + cluster=self._cluster, + tasks=batch, + ) + out.extend(response.get("tasks") or []) + return out + + def _task_to_metadata(self, task: dict[str, Any]) -> Optional[WorkerMetadata]: + """Convert an ECS task description to a Wool ``WorkerMetadata``. + + Returns None when the task is not RUNNING + HEALTHY or we + cannot extract a usable IP address. Worker task definitions + MUST declare a ``healthCheck``; tasks without one surface as + ``healthStatus: UNKNOWN`` and are filtered out, since their + gRPC readiness is unknowable. Tasks whose ``healthStatus`` is + absent from the describe-tasks response are also filtered — + same reason. + """ + if task.get("lastStatus") != "RUNNING": + return None + if task.get("healthStatus") != "HEALTHY": + return None + + ip = _extract_eni_ip(task) + if ip is None: + return None + + # ECS task ARNs end in ``/`` where ```` is a + # 32-char hex string. We use it as a stable UUID for the + # worker so successive polls produce identical metadata + # (otherwise diff would emit add+drop on every cycle). A + # missing ARN means we cannot identify the task — drop it + # rather than collide every ARN-less task on the same UUID. + task_arn = task.get("taskArn") or "" + if not task_arn: + return None + task_id = task_arn.rsplit("/", 1)[-1] + try: + uid = uuid.UUID(task_id) + except (ValueError, AttributeError): + uid = uuid.uuid5(uuid.NAMESPACE_URL, task_arn) + + return WorkerMetadata( + uid=uid, + address=f"{ip}:{self._worker_port}", + pid=0, + version=self._version, + ) + + async def _register_subscriber(self, sub: "_EcsSubscriber") -> None: + """Append a subscriber and prime its queue with the current snapshot. + + Both steps happen under ``self._lock``, paired with + ``poll_once`` mutating ``_known`` and dispatching under the + same lock: a concurrent poll either runs entirely before the + registration (replay sees the OLD ``_known``, the poll's + events go to the new subscriber via dispatch) or entirely + after (replay sees the NEW ``_known``, no dispatch overlaps). + No interleave produces a missed event or a duplicate ``worker- + added`` for the same UID. + """ + async with self._lock: + self._subscribers.append(sub) + for metadata in self._known.values(): + await sub._publish(DiscoveryEvent("worker-added", metadata=metadata)) + + async def _unregister_subscriber(self, sub: "_EcsSubscriber") -> None: + """Remove ``sub`` from the subscriber list under the lock. + + Silently ignores already-removed subscribers so the iteration + ``finally`` is idempotent. + """ + async with self._lock: + try: + self._subscribers.remove(sub) + except ValueError: + pass + + +class _EcsSubscriber: + """Discovery subscriber backed by an ``asyncio.Queue``. + + The discovery instance pushes events into the queue; consumers + iterate via ``async for``. Subscribers are single-use — a second + iteration after the first has finished raises ``RuntimeError``; + create a fresh subscriber via ``EcsDiscovery.subscribe()`` instead. + """ + + def __init__( + self, + owner: EcsDiscovery, + filter: Optional[PredicateFunction], + ) -> None: + self._owner = owner + self._filter = filter + # The queue carries discovery events plus a single ``None`` + # sentinel pushed by ``EcsDiscovery.__aexit__`` to wake a + # consumer that would otherwise block forever on ``get()`` + # after the discovery has shut down. + self._queue: asyncio.Queue[Optional[DiscoveryEvent]] = asyncio.Queue() + self._exhausted = False + + def __aiter__(self) -> AsyncIterator[DiscoveryEvent]: + return self._iter() + + async def _publish(self, event: DiscoveryEvent) -> None: + """Deliver ``event`` to the subscriber's queue, or drop it. + + Called by ``EcsDiscovery.poll_once`` while holding the + discovery lock. The filter (if any) runs inline; events that + don't pass it are dropped silently — the subscriber asked not + to see this worker. Otherwise the event is enqueued on the + subscriber's unbounded ``asyncio.Queue`` and surfaces via + ``_iter`` on the next consumer ``__anext__``. + """ + if self._filter is not None and not self._filter(event.metadata): + return + await self._queue.put(event) + + async def _iter(self) -> AsyncIterator[DiscoveryEvent]: + """Yield discovery events for the lifetime of this subscriber. + + Single-use: a second call after the iterator has been consumed + raises ``RuntimeError``. ``EcsDiscovery.subscribe()`` returns + a fresh subscriber per call, so multiple independent consumers + are supported by creating multiple subscribers. + + Registers with the owning discovery on the first iteration so + the snapshot of currently-known workers is replayed before + any new poll events arrive. The ``finally`` always + unregisters — ``_unregister_subscriber`` silently ignores + already-removed entries, so a partial-registration failure + does not leak a stale subscriber. + + The ``_exhausted`` flag is flipped synchronously *before* the + first ``await`` so two ``__aiter__()`` calls on the same + subscriber (e.g. a framework that calls ``aiter()`` for + introspection before iterating) cannot both register and + share the underlying queue. CPython's single-threaded asyncio + guarantees the first generator's body runs up to its first + await before any other generator starts; the second + generator's check fires on its first ``__anext__``. + """ + if self._exhausted: + raise RuntimeError( + "EcsDiscovery subscriber already iterated; " + "call EcsDiscovery.subscribe() for a fresh one" + ) + self._exhausted = True + try: + await self._owner._register_subscriber(self) + while True: + event = await self._queue.get() + if event is None: + # Sentinel from ``EcsDiscovery.__aexit__``; exit + # cleanly so the consumer's ``async for`` ends. + break + yield event + finally: + await self._owner._unregister_subscriber(self) + + +class _RaisingPublisher: + """Guard-rail publisher returned by ``EcsDiscovery.publisher``. + + ECS owns worker lifecycle, so nothing in cfdb has cause to call + ``publish``; raising here surfaces the misuse loudly rather than + silently no-op'ing. + """ + + async def publish( + self, type: DiscoveryEventType, metadata: WorkerMetadata + ) -> None: + raise RuntimeError( + "EcsDiscovery is read-only — workers register implicitly via ECS" + ) + + +def _extract_eni_ip(task: dict[str, Any]) -> Optional[str]: + """Return the awsvpc private IPv4 from an ECS task description. + + ECS attaches one ENI per Fargate awsvpc task; its ``details`` list + carries a ``privateIPv4Address`` entry. Older API responses use + ``networkInterfaces`` on each container instead — we honor either. + When both forms are populated the ``attachments`` value wins; for + Fargate awsvpc both report the same IP, so the precedence is + moot in practice. + """ + for attachment in task.get("attachments") or []: + if attachment.get("type") not in (None, "ElasticNetworkInterface"): + continue + for detail in attachment.get("details") or []: + if detail.get("name") == "privateIPv4Address" and detail.get("value"): + return detail["value"] + for container in task.get("containers") or []: + for nic in container.get("networkInterfaces") or []: + ipv4 = nic.get("privateIpv4Address") + if ipv4: + return ipv4 + return None + + +def _diff( + cached: dict[str, WorkerMetadata], + discovered: dict[str, WorkerMetadata], +) -> Iterable[DiscoveryEvent]: + """Yield Wool events describing the cached→discovered transition.""" + for uid, metadata in discovered.items(): + if uid not in cached: + yield DiscoveryEvent("worker-added", metadata=metadata) + elif cached[uid].address != metadata.address: + yield DiscoveryEvent("worker-updated", metadata=metadata) + for uid, metadata in cached.items(): + if uid not in discovered: + yield DiscoveryEvent("worker-dropped", metadata=metadata) diff --git a/src/cfdb/workflows/executor.py b/src/cfdb/workflows/executor.py index 974facb..e4bf907 100644 --- a/src/cfdb/workflows/executor.py +++ b/src/cfdb/workflows/executor.py @@ -42,7 +42,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator from pathlib import Path -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import wool @@ -64,6 +64,10 @@ from cfdb.workflows.models import ArtifactKind, JobRecord, JobStatus from cfdb.workflows.processors.base import Processor from cfdb.workflows.processors.registry import ProcessorRegistry +from cfdb.workflows.provisioner import RetryableProvisionerError + +if TYPE_CHECKING: + from cfdb.workflows.provisioner import EcsProvisioner logger = logging.getLogger(__name__) @@ -72,10 +76,11 @@ #: change to the workflow orchestration itself). PIPELINE_VERSION = 1 -#: Default job-runtime cap. 20 minutes covers a sort+index on a multi-GB -#: BAM comfortably; exceptional files that need longer should be -#: addressed individually. Sourced from the env var so deployments can -#: override without a code change. +#: Default job-runtime cap. 4 hours covers multi-hour preprocessing runs +#: (e.g., a ``samtools sort`` on a multi-GB BAM followed by ``samtools +#: index``); exceptional files that need longer should be addressed +#: individually. Sourced from the env var so deployments can override +#: without a code change — fixture-bound dev setups should lower it. DEFAULT_WORKFLOW_DURATION_CAP_SECONDS = WORKFLOW_DURATION_CAP_S #: Total wall-clock budget waiting for a leased worker to surface during @@ -223,7 +228,34 @@ class instance, ``file_meta`` is a plain dict (B1 strips Mongo's class WoolExecutor(JobExecutor): - """Executor backed by a ``wool.WorkerPool`` and a Mongo jobs collection.""" + """Executor backed by a ``wool.WorkerPool`` and a Mongo jobs collection. + + When configured with an :class:`EcsProvisioner`, the executor also + issues a ``RunTask`` on each fresh claim before opening the routine + stream, so a Fargate worker boots in parallel with the dispatch + retry loop. Without a provisioner (PoC dev profile) the executor + relies on workers already published into the discovery namespace. + + A :class:`RetryableProvisionerError` from the provisioner surfaces + as a terminal ``FAILED`` job with a ``capacity:``-prefixed error + string; any other provisioner failure surfaces with a + ``provisioner:`` prefix. Clients parse the prefix to decide + whether to resubmit. + + Args: + db: Motor database handle holding the ``jobs`` collection. + cache: Cache backend for processor artifacts. + cache_root: Directory used by ``LocalFsCache``; ignored for S3. + registry: Processor registry mapping artifact kinds to runners. + workdir_root: Parent directory under which per-job workdirs land. + pipeline_version: Embedded in workflow keys; bump to invalidate + every in-flight workflow. + workflow_duration_cap_seconds: Per-workflow wall-clock cap. + Defaults to ``DEFAULT_WORKFLOW_DURATION_CAP_SECONDS``. + provisioner: Optional :class:`EcsProvisioner`. When ``None`` + (PoC profile), no ``RunTask`` is issued and dispatch + relies on pre-existing workers. + """ def __init__( self, @@ -235,6 +267,7 @@ def __init__( workdir_root: Path, pipeline_version: int = PIPELINE_VERSION, workflow_duration_cap_seconds: int = DEFAULT_WORKFLOW_DURATION_CAP_SECONDS, + provisioner: Optional[EcsProvisioner] = None, ) -> None: self._db = db self._cache = cache @@ -244,6 +277,12 @@ def __init__( self._workdir_root.mkdir(parents=True, exist_ok=True) self._pipeline_version = pipeline_version self._workflow_duration_cap_seconds = workflow_duration_cap_seconds + #: External worker provisioner. When set, ``_run_workflow`` issues + #: a ``RunTask``-equivalent request to it on a fresh claim so a + #: Fargate worker boots in parallel with the dispatch retry loop. + #: When unset (PoC dev profile), the executor relies entirely on + #: workers already published into the discovery namespace. + self._provisioner = provisioner self._pending_tasks: set[asyncio.Task] = set() #: Finalize tasks (release_workflow + workdir cleanup) created by #: _run_workflow's `finally`. Tracked separately so drain() can @@ -287,7 +326,9 @@ async def ensure_workflow( ) if fresh: - task = asyncio.create_task(self._run_workflow(record, processor, file_meta)) + task = asyncio.create_task( + self._run_workflow(record, processor, file_meta, wf_key) + ) self._pending_tasks.add(task) task.add_done_callback(self._pending_tasks.discard) task.add_done_callback(_log_unexpected_exception) @@ -355,6 +396,7 @@ async def _run_workflow( record: JobRecord, processor: Processor, file_meta: dict[str, Any], + workflow_key: str, ) -> None: """Background coroutine: mark running, consume the routine's event stream, and release a terminal status. @@ -389,6 +431,40 @@ async def _run_workflow( final_error = f"mark_running failed: {exc}" return + # Request a worker via the external provisioner (e.g. ECS + # ``RunTask``) before opening the routine stream. The + # provisioner dedup-keys on the workflow mutex so two + # concurrent fresh claims for the same source file + # share one ``RunTask`` and one worker. A capacity / + # ENI / throttling failure surfaces as + # ``RetryableProvisionerError`` and is routed to a + # terminal ``FAILED`` via the same error path as a + # stream-open failure below. + if self._provisioner is not None: + try: + await self._provisioner.request(dedup_key=workflow_key) + except asyncio.CancelledError: + final_error = "Workflow cancelled (worker shutdown)" + raise + except RetryableProvisionerError as exc: + # Capacity / ENI exhaustion / throttling / boto + # transport transients. ``capacity:`` is the + # stable on-wire prefix clients parse to decide + # whether to resubmit the job. + logger.warning( + "Provisioner reported retryable capacity error for %s: %s", + record.job_id, + exc, + ) + final_error = f"capacity: {exc}" + return + except Exception as exc: + logger.exception( + "Provisioner request failed for %s", record.job_id + ) + final_error = f"provisioner: {exc}" + return + try: stream = await self._open_stream_with_retry( processor, file_meta, workdir diff --git a/src/cfdb/workflows/provisioner.py b/src/cfdb/workflows/provisioner.py new file mode 100644 index 0000000..a71ec58 --- /dev/null +++ b/src/cfdb/workflows/provisioner.py @@ -0,0 +1,387 @@ +"""ECS on-demand worker provisioning via the ``RunTask`` API. + +``EcsProvisioner`` wraps a single boto3 ``RunTask`` call with the cluster, +task-definition family, and awsvpc networking configuration the workflow +subsystem needs. The same code targets LocalStack in development and real +AWS in production — only ``AWS_ENDPOINT_URL`` differs at the boto3 client. + +Concurrent ``request`` calls sharing the same ``dedup_key`` (typically +the workflow key) attach to a single in-flight ``RunTask`` task and +observe the same outcome. ``asyncio.shield`` insulates each caller's +cancellation from the others, so a request abandoned by one caller does +not poison the result for the rest. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Iterable +from typing import Any, Literal, Optional, get_args + +import boto3 +from botocore.exceptions import BotoCoreError, ClientError + +logger = logging.getLogger(__name__) + + +#: Acceptable values for ``RunTask`` ``assignPublicIp`` — ECS rejects +#: anything else. The runtime-validation set is derived from the +#: ``Literal`` so the two definitions can't drift. +AssignPublicIp = Literal["ENABLED", "DISABLED"] +_ASSIGN_PUBLIC_IP_VALUES = frozenset(get_args(AssignPublicIp)) + + +class RetryableProvisionerError(RuntimeError): + """Provisioner failure the caller should resubmit later. + + Covers both ECS-side capacity / ENI exhaustion and transport-level + transients (connection timeouts, endpoint unavailability, + throttling). The executor's response is identical for both — mark + the workflow ``FAILED`` with a retryable error string — so they + share one exception type. + """ + + +class EcsProvisioner: + """Launch ephemeral worker tasks on ECS Fargate. + + Args: + cluster: ECS cluster name or ARN. + task_definition: Task definition family (or family:revision) + for the worker container. + subnets: Awsvpc subnet IDs to place worker ENIs into. + security_groups: Awsvpc security group IDs to attach. + assign_public_ip: ``"ENABLED"`` or ``"DISABLED"`` — whether the + ENI gets a public IP. Production should usually leave this + disabled and rely on VPC endpoints; LocalStack accepts + either value. + client: Optional pre-built boto3 ``ecs`` client. When omitted, + one is constructed via :func:`build_ecs_client` with the + ``endpoint_url`` / ``region_name`` kwargs threaded through. + endpoint_url: Boto3 ``endpoint_url``. Passed to + :func:`build_ecs_client` when ``client`` is omitted. The + lifespan plumbs :data:`cfdb.api.AWS_ENDPOINT_URL` here. + region_name: Boto3 ``region_name``. Plumbed analogously. + max_in_flight: Soft cap on concurrent ``RunTask`` calls. ECS's + ``RunTask`` API is rate-limited to ~20 req/s per account; + this guard keeps us well under it. + """ + + def __init__( + self, + *, + cluster: str, + task_definition: str, + subnets: Iterable[str], + security_groups: Iterable[str] = (), + assign_public_ip: AssignPublicIp = "DISABLED", + client: Optional[Any] = None, + endpoint_url: Optional[str] = None, + region_name: Optional[str] = None, + max_in_flight: int = 16, + ) -> None: + if not cluster: + raise ValueError("EcsProvisioner requires a cluster name") + if not task_definition: + raise ValueError("EcsProvisioner requires a task_definition") + subnet_list = list(subnets) + if not subnet_list: + raise ValueError("EcsProvisioner requires at least one subnet") + if assign_public_ip not in _ASSIGN_PUBLIC_IP_VALUES: + raise ValueError( + f"assign_public_ip must be one of {sorted(_ASSIGN_PUBLIC_IP_VALUES)}; " + f"got {assign_public_ip!r}" + ) + + self._cluster = cluster + self._task_definition = task_definition + self._subnets = subnet_list + self._security_groups = list(security_groups) + self._assign_public_ip = assign_public_ip + self._client = ( + client + if client is not None + else build_ecs_client(endpoint_url=endpoint_url, region_name=region_name) + ) + self._semaphore = asyncio.Semaphore(max_in_flight) + # Concurrent ``request`` calls sharing a key attach to the + # same in-flight Task. ``_run_task_owned``'s ``finally`` block + # always pops the entry on completion (success or failure) so + # a fresh request after the previous one finishes spawns a new + # launch. + self._in_flight: dict[str, asyncio.Task[list[str]]] = {} + self._in_flight_lock = asyncio.Lock() + + async def request(self, *, dedup_key: str) -> list[str]: + """Launch a worker task, returning its ARN(s). + + Concurrent callers sharing ``dedup_key`` share one ``RunTask``: + only the first launches; the rest await the same result. + + Args: + dedup_key: Typically the workflow mutex key. Two callers + holding the same workflow-level mutex should not + independently launch two workers. + + Returns: + List of task ARNs corresponding to the launched worker(s). + + Raises: + RetryableProvisionerError: Capacity / ENI exhaustion, + throttling, or connection-level transients. Callers + should mark the workflow as failed with a retryable + status so the client can resubmit. + """ + if not dedup_key: + raise ValueError("EcsProvisioner.request requires a non-empty dedup_key") + + async with self._in_flight_lock: + existing = self._in_flight.get(dedup_key) + # Self-heal: if the cached task already finished and the + # finally-block release was skipped (CancelledError during + # event-loop teardown could re-raise out of the awaited + # shielded pop before the inner coroutine ran), treat the + # stale entry as absent and launch a fresh task. This + # collapses every "the pop never ran" cancellation race + # into one self-healing read at registration time. + if existing is not None and existing.done(): + existing = None + self._in_flight.pop(dedup_key, None) + if existing is None: + existing = asyncio.create_task( + self._run_task_owned(dedup_key), + name=f"ecs-runtask-{dedup_key}", + ) + # Retrieve the eventual exception/result so a Task + # abandoned by every caller (universal cancellation) + # doesn't trigger asyncio's "Task exception was never + # retrieved" warning when it's garbage-collected. + existing.add_done_callback(_consume_task_outcome) + self._in_flight[dedup_key] = existing + + # Shield protects concurrent callers sharing this dedup_key + # from each other's cancellation: if one caller is cancelled + # mid-await, only that caller sees CancelledError; the + # underlying task continues and other callers still observe + # the real result. + return await asyncio.shield(existing) + + async def _run_task_owned(self, dedup_key: str) -> list[str]: + """Body of the dedup-protected RunTask call. + + Split out from ``request`` so the dedup-registration lock is + not held across the boto3 thread-pool round-trip. The + ``finally`` always pops the in-flight slot so the next + ``request`` for the same key can launch a fresh worker. The + pop is shielded so a CancelledError delivered to the task + itself (event-loop teardown, explicit cancel) cannot skip the + cleanup and leak a stale dedup entry pointing at a cancelled + task. + """ + + async def _release_dedup_slot() -> None: + async with self._in_flight_lock: + self._in_flight.pop(dedup_key, None) + + try: + async with self._semaphore: + return await self._run_task() + finally: + await asyncio.shield(_release_dedup_slot()) + + async def _run_task(self) -> list[str]: + """Single ``RunTask`` invocation translated to a list of task ARNs. + + ``count`` is hardcoded to 1 because the dedup contract is one + worker per workflow key: two concurrent ``request(dedup_key=K)`` + calls share the same task, and once that task finishes a fresh + request for the same key launches a fresh worker. The + failure-vs-success branching below relies on this — any entry + in ``response["failures"]`` means the launch did not happen, so + the call raises rather than returning a mix. + """ + kwargs: dict[str, Any] = { + "cluster": self._cluster, + "taskDefinition": self._task_definition, + "launchType": "FARGATE", + "count": 1, + "networkConfiguration": { + "awsvpcConfiguration": { + "subnets": list(self._subnets), + "securityGroups": list(self._security_groups), + "assignPublicIp": self._assign_public_ip, + } + }, + } + + # ``ClientError`` covers HTTP-level ECS errors with structured + # response codes; ``BotoCoreError`` covers transport / cred + # failures with no ``.response``. Both should be classified + # together so transient connection issues surface as retryable. + try: + response = await asyncio.to_thread(self._client.run_task, **kwargs) + except (ClientError, BotoCoreError) as exc: + if _is_retryable_error(exc): + raise RetryableProvisionerError( + f"{type(exc).__name__}: {exc}" + ) from exc + raise + + failures = response.get("failures") or [] + tasks = response.get("tasks") or [] + arns = [t["taskArn"] for t in tasks if t.get("taskArn")] + + if failures and not arns: + # No ARN to fall back on — the launch did not happen. + reasons = ", ".join(f.get("reason", "?") for f in failures) + if any(_is_retryable_failure(f) for f in failures): + raise RetryableProvisionerError( + f"ECS RunTask failures: {reasons}" + ) + raise RuntimeError(f"ECS RunTask failures: {reasons}") + + if failures and arns: + # Partial-success edge: ECS occasionally surfaces a + # secondary placement warning alongside a successfully + # launched task. Log and keep the ARN rather than + # discarding a worker that's already running (and + # billing) to chase the warning. + reasons = ", ".join(f.get("reason", "?") for f in failures) + logger.warning( + "ECS RunTask succeeded with %d task(s) but reported failures: %s", + len(arns), + reasons, + ) + + if not arns: + # ECS returned neither a launched task nor a failure entry. + # Treat as a retryable transient — the caller has nothing + # to dispatch to and the executor's failed-with-retryable + # path is the right response. + raise RetryableProvisionerError( + "ECS RunTask returned no tasks and no failures" + ) + logger.info( + "ECS RunTask launched %d task(s) on cluster=%s family=%s", + len(arns), + self._cluster, + self._task_definition, + ) + return arns + + async def aclose(self) -> None: + """Cancel every in-flight ``RunTask`` and clear the dedup map. + + Called from the API lifespan's ``finally`` so a shutdown while + a ``RunTask`` round-trip is mid-flight doesn't leave the task + unrequested-but-billed. ``gather(return_exceptions=True)`` + absorbs the ``CancelledError`` and any ``RetryableProvisionerError`` + already in flight; the done-callback handles "Task exception + never retrieved" warnings for the cases we don't await here. + """ + async with self._in_flight_lock: + tasks = list(self._in_flight.values()) + self._in_flight.clear() + for task in tasks: + task.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + +def build_ecs_client( + *, endpoint_url: Optional[str] = None, region_name: Optional[str] = None +) -> Any: + """Construct a boto3 ``ecs`` client with explicit endpoint/region. + + The caller (typically :class:`EcsProvisioner`, :class:`EcsDiscovery`, + or the API lifespan) is the single source of truth for + ``endpoint_url`` and ``region_name`` — we no longer reach into + :mod:`cfdb.api` for fallback values. Pass + :data:`cfdb.api.AWS_ENDPOINT_URL` / :data:`cfdb.api.AWS_REGION` + from the lifespan; leave both ``None`` to let boto3's default + session resolver chain pick them up from the environment. + """ + return boto3.client( + "ecs", + endpoint_url=endpoint_url, + region_name=region_name, + ) + + +# Codes ECS returns as ``ClientError.response["Error"]["Code"]`` for a +# ``RunTask`` failure that a retry can plausibly fix. Capacity-style +# placement reasons (``RESOURCE:ENI`` / ``RESOURCE:CPU`` / +# ``RESOURCE:MEMORY`` / ``AWS.ECS.PlacementError``) are NOT in this +# set because ECS surfaces them via ``failures[].reason``, not as a +# top-level error code; ``_RETRYABLE_REASON_TOKENS`` catches them on +# the failure-payload path. +_RETRYABLE_ERROR_CODES = frozenset( + { + "Capacity", + "CapacityProviderException", + "ClusterCapacityProviderException", + "ThrottlingException", + "Throttling", + "RequestLimitExceeded", + "ServerException", + "ServiceUnavailableException", + "ServiceUnavailable", + } +) +# Substring match (``token in reason``). ``THROTTL`` is deliberately +# truncated so it covers any throttling-derived reason — ``THROTTLED``, +# ``THROTTLING``, ``THROTTLINGEXCEPTION`` — without enumerating every +# variant AWS might emit. +_RETRYABLE_REASON_TOKENS = ("CAPACITY", "RESOURCE:", "THROTTL") + + +def _is_retryable_error(exc: BaseException) -> bool: + """Return True when an ECS exception is a retryable transient. + + Handles both ``ClientError`` (with a structured response dict) and + ``BotoCoreError`` subclasses (transport-level failures without a + response) — connection timeouts, endpoint unavailability, and + other transport transients are exactly the cases the caller should + resubmit. + """ + response = getattr(exc, "response", None) + if isinstance(response, dict): + code = (response.get("Error") or {}).get("Code") + if code in _RETRYABLE_ERROR_CODES: + return True + # BotoCoreError subclasses (EndpointConnectionError, + # ConnectTimeoutError, ReadTimeoutError, HTTPClientError, etc.) + # carry no .response; classify the whole family as retryable. + return isinstance(exc, BotoCoreError) + + +def _is_retryable_failure(failure: dict[str, Any]) -> bool: + """Return True when a ``RunTask`` ``failures`` entry is retryable.""" + reason = (failure.get("reason") or "").upper() + return any(token in reason for token in _RETRYABLE_REASON_TOKENS) + + +def _consume_task_outcome(task: asyncio.Task[Any]) -> None: + """Retrieve a Task's outcome so asyncio doesn't log "never retrieved". + + Attached as a done-callback on the in-flight provisioner Task so + that universal-cancellation (every caller cancels before the Task + completes) doesn't surface as a noisy unretrieved-exception + warning at GC time. ``RetryableProvisionerError`` is a routine + outcome under capacity / throttling pressure, so abandoned ones + are logged at DEBUG; unexpected exception types stay at WARNING. + """ + if task.cancelled(): + return + exc = task.exception() + if exc is None: + return + if isinstance(exc, RetryableProvisionerError): + logger.debug( + "ECS RunTask completed with abandoned retryable error: %r", exc + ) + else: + logger.warning( + "ECS RunTask completed with abandoned exception: %r", exc + ) diff --git a/src/cfdb/workflows/worker_main.py b/src/cfdb/workflows/worker_main.py new file mode 100644 index 0000000..4fb58f2 --- /dev/null +++ b/src/cfdb/workflows/worker_main.py @@ -0,0 +1,292 @@ +"""ECS Fargate worker entrypoint. + +This module is the ``CMD`` for the worker container image. It boots a +``wool.LocalWorker`` on a known port, exposes a tiny HTTP health endpoint +the ECS health check probes, handles SIGTERM cleanly so ``ecs.stop_task`` +cycles drain in-flight work, and self-terminates after a configurable +maximum lifetime so workers don't accumulate when the dispatch rate falls. + +No discovery registration code lives here. ``EcsDiscovery`` polls ECS's +own task state to surface running workers; the worker only needs to +bind its gRPC port and respond ``200 OK`` to the health probe. +""" + +from __future__ import annotations + +import asyncio +import logging +import signal +from collections.abc import Callable +from typing import TYPE_CHECKING, Optional + +import click +import wool + +from cfdb.workflows.constants import DEFAULT_WORKER_PORT + +if TYPE_CHECKING: + from aiohttp import web + +logger = logging.getLogger(__name__) + +__all__ = ["main", "serve"] + +#: Default health probe HTTP port — distinct from the gRPC port so +#: ``healthCheck`` can ``curl`` it without speaking gRPC. +DEFAULT_HEALTH_PORT = 8080 + +#: Default maximum wall-clock lifetime of a worker process. Wool exposes +#: no per-job activity hook today, so the worker can't tell idle from +#: busy; this is a hard ceiling — ECS replaces the task after this long. +#: Sized one hour above :data:`cfdb.workflows.WORKFLOW_DURATION_CAP_S` +#: (default 4 h) so a worker started shortly before a long sort can +#: still outlive the job. Note: max-lifetime expiry exits without the +#: drain-grace window the SIGTERM path provides — operators that want +#: a cleaner handoff should rely on ECS rolling tasks via service +#: updates rather than waiting for max-lifetime to fire. +DEFAULT_MAX_LIFETIME_SECONDS = 5 * 60 * 60 + +#: How long to keep returning ``503`` on ``/health`` after SIGTERM, +#: giving ECS a chance to observe ``unhealthy`` and drain at the load +#: balancer before we tear the gRPC port down. ECS health checks +#: default to ~30 s interval × 3 unhealthy retries = ~90 s worst case +#: to mark the task unhealthy; we hold for 120 s by default to leave +#: real margin even when the task definition uses ECS-default health +#: check cadence. Operators tightening ``healthCheck.interval`` to +#: ~10-15 s can lower this further via ``WORKER_DRAIN_GRACE_SECONDS``. +DEFAULT_DRAIN_GRACE_SECONDS = 120.0 + +#: Cadence of the main loop's wakeup when waiting on ``stop_event``. +#: One second balances precision (max-lifetime checks fire within ~1 s +#: of the target) against CPU churn over a multi-hour worker lifetime +#: (~3600 wakeups/h vs. ~3.6 M at 1 ms). Sub-second granularity is +#: unnecessary because the upstream coarse-grained signals +#: (SIGTERM-on-stop_task, ~hours-scale max-lifetime) don't need it. +_STOP_POLL_INTERVAL_SECONDS = 1.0 + + +async def serve( + *, + worker_port: int = DEFAULT_WORKER_PORT, + health_port: int = DEFAULT_HEALTH_PORT, + max_lifetime_seconds: float = DEFAULT_MAX_LIFETIME_SECONDS, + drain_grace_seconds: float = DEFAULT_DRAIN_GRACE_SECONDS, +) -> int: + """Run the worker until SIGTERM or maximum lifetime elapses. + + Returns ``0`` on clean shutdown (SIGTERM, SIGINT, or max-lifetime). + Bind failures and other early-startup errors raise out — ``main`` + propagates them and the process exits with a Python traceback, + which surfaces the cause in container logs more clearly than a + silent non-zero status. + + A second SIGTERM during drain short-circuits the grace window + (operator impatience or ECS escalating before SIGKILL); the worker + stops immediately. + """ + stop_event = asyncio.Event() + force_stop_event = asyncio.Event() + loop = asyncio.get_running_loop() + started_at = loop.time() + + def _signal_handler() -> None: + if stop_event.is_set(): + logger.info("Second termination signal — exiting drain immediately") + force_stop_event.set() + else: + logger.info("Received termination signal — draining worker") + stop_event.set() + + def _signal_handler_threaded(*_: object) -> None: + # Fallback for platforms (notably Windows in CI) where + # ``loop.add_signal_handler`` raises ``NotImplementedError``. + # Defined once so the closure isn't rebuilt per signal. + loop.call_soon_threadsafe(_signal_handler) + + for sig in (signal.SIGTERM, signal.SIGINT): + try: + loop.add_signal_handler(sig, _signal_handler) + except NotImplementedError: + signal.signal(sig, _signal_handler_threaded) + + health_runner = await _start_health_server( + health_port, lambda: stop_event.is_set() + ) + + worker = wool.LocalWorker(host="0.0.0.0", port=worker_port) + await worker.start() + try: + logger.info( + "Wool worker listening on port %d (health on %d)", + worker_port, + health_port, + ) + while True: + # Check the self-termination path first. Max-lifetime is a + # local hard ceiling — nothing upstream (ECS, LB) is + # waiting on us to drain a multi-second grace window, but + # flipping ``/health`` to 503 here costs nothing and gives + # ``EcsDiscovery`` (and any layered LB) a chance to mark + # the task unhealthy before ``worker.stop()`` actually + # closes the gRPC port in ``finally``. Without this flip + # the health endpoint reports 200 right up to the moment + # the aiohttp runner is torn down. + if ( + max_lifetime_seconds > 0 + and (loop.time() - started_at) >= max_lifetime_seconds + ): + logger.info( + "Max lifetime (%.0fs) reached — exiting", + max_lifetime_seconds, + ) + stop_event.set() + break + if stop_event.is_set(): + # Signal-initiated shutdown. Hold the gRPC port open + # while ECS observes 503 on /health and stops routing + # new dispatches. A second signal flips + # force_stop_event and we exit the wait early. + if drain_grace_seconds > 0: + logger.info( + "Draining for up to %.0fs before exiting", + drain_grace_seconds, + ) + try: + await asyncio.wait_for( + force_stop_event.wait(), + timeout=drain_grace_seconds, + ) + except asyncio.TimeoutError: + pass + break + try: + await asyncio.wait_for( + stop_event.wait(), timeout=_STOP_POLL_INTERVAL_SECONDS + ) + except asyncio.TimeoutError: + continue + return 0 + finally: + try: + await worker.stop() + except Exception: + logger.exception("worker.stop() failed during shutdown") + await _shutdown_health_server(health_runner) + + +async def _start_health_server( + port: int, draining: Callable[[], bool] +) -> "web.AppRunner": + """Start a tiny HTTP server returning ``200 OK`` on ``/health``. + + The container's ECS ``healthCheck`` runs ``curl`` against this + endpoint, so the response shape doesn't matter — only the status + code does. While ``draining`` returns True (i.e. we're shutting + down), the endpoint returns ``503`` so ECS can mark the task + unhealthy and the load balancer / discovery can drain it before + ``ecs.stop_task`` actually kills the gRPC port. + + Cleans up the partial runner on bind failure so the caller's + finally doesn't have a half-initialized object to deal with. + """ + from aiohttp import web + + async def _health(_: web.Request) -> web.Response: + if draining(): + return web.Response(status=503, text="draining") + return web.Response(status=200, text="ok") + + app = web.Application() + app.router.add_get("/health", _health) + runner = web.AppRunner(app) + await runner.setup() + try: + site = web.TCPSite(runner, host="0.0.0.0", port=port) + await site.start() + except Exception: + await runner.cleanup() + raise + return runner + + +async def _shutdown_health_server(runner: Optional["web.AppRunner"]) -> None: + """Tear down the aiohttp ``AppRunner`` started by ``_start_health_server``. + + Tolerates ``runner is None`` so the caller's ``finally`` can run + even when the health server failed to start in the first place. + Cleanup exceptions are logged but swallowed: the worker is already + on its way out, and surfacing a teardown error would mask the + upstream cause (whatever triggered the shutdown). + """ + if runner is None: + return + try: + await runner.cleanup() + except Exception: + logger.exception("health server cleanup failed") + + +@click.command() +@click.option( + "--worker-port", + type=click.IntRange(1, 65535), + envvar="CFDB_WORKER_GRPC_PORT", + default=DEFAULT_WORKER_PORT, + show_default=True, + help="gRPC port the wool worker binds.", +) +@click.option( + "--health-port", + type=click.IntRange(1, 65535), + envvar="CFDB_WORKER_HEALTH_PORT", + default=DEFAULT_HEALTH_PORT, + show_default=True, + help="HTTP port the ECS health-check endpoint binds.", +) +@click.option( + "--max-lifetime-seconds", + type=click.FloatRange(min=0), + envvar="CFDB_WORKER_MAX_LIFETIME_SECONDS", + default=DEFAULT_MAX_LIFETIME_SECONDS, + show_default=True, + help="Hard ceiling on worker uptime in seconds; 0 disables.", +) +@click.option( + "--drain-grace-seconds", + type=click.FloatRange(min=0), + envvar="CFDB_WORKER_DRAIN_GRACE_SECONDS", + default=DEFAULT_DRAIN_GRACE_SECONDS, + show_default=True, + help=( + "Seconds to keep returning 503 on /health after SIGTERM before " + "tearing down the gRPC port. A second SIGTERM short-circuits." + ), +) +def main( + worker_port: int, + health_port: int, + max_lifetime_seconds: float, + drain_grace_seconds: float, +) -> None: + """ECS Fargate worker entrypoint — invoked by the container CMD. + + Boots a wool gRPC worker, exposes /health for ECS to probe, and + self-terminates after the max-lifetime ceiling. SIGTERM begins a + drain grace window during which /health returns 503 so the load + balancer can drop the worker before the gRPC port closes. + """ + logging.basicConfig(level=logging.INFO) + raise SystemExit( + asyncio.run( + serve( + worker_port=worker_port, + health_port=health_port, + max_lifetime_seconds=max_lifetime_seconds, + drain_grace_seconds=drain_grace_seconds, + ) + ) + ) + + +if __name__ == "__main__": # pragma: no cover + main() diff --git a/tests/test_api/__init__.py b/tests/test_api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_api/test_env_parsers.py b/tests/test_api/test_env_parsers.py new file mode 100644 index 0000000..c90017a --- /dev/null +++ b/tests/test_api/test_env_parsers.py @@ -0,0 +1,75 @@ +"""Tests for the env-var parsers in :mod:`cfdb.api`.""" + +from __future__ import annotations + +import pytest + +from cfdb.api import _parse_assign_public_ip + + +class TestParseAssignPublicIp: + def test__parse_assign_public_ip_with_unset_var(self, monkeypatch): + """Test that an unset env var returns the supplied default. + + Given: + No environment variable set for the given name. + When: + ``_parse_assign_public_ip`` is called with a default value. + Then: + It should return the default rather than raising so the PoC + profile (no ECS env) still imports cleanly. + """ + # Arrange + monkeypatch.delenv("CFDB_TEST_ASSIGN_PUBLIC_IP", raising=False) + + # Act + result = _parse_assign_public_ip( + "CFDB_TEST_ASSIGN_PUBLIC_IP", default="DISABLED" + ) + + # Assert + assert result == "DISABLED" + + def test__parse_assign_public_ip_with_valid_value(self, monkeypatch): + """Test that an in-set value is returned as-is. + + Given: + ``CFDB_TEST_ASSIGN_PUBLIC_IP=ENABLED``. + When: + ``_parse_assign_public_ip`` is called. + Then: + It should return ``"ENABLED"`` unchanged so the lifespan can + hand it to ``EcsProvisioner`` verbatim. + """ + # Arrange + monkeypatch.setenv("CFDB_TEST_ASSIGN_PUBLIC_IP", "ENABLED") + + # Act + result = _parse_assign_public_ip( + "CFDB_TEST_ASSIGN_PUBLIC_IP", default="DISABLED" + ) + + # Assert + assert result == "ENABLED" + + def test__parse_assign_public_ip_with_invalid_value(self, monkeypatch): + """Test that an out-of-set value surfaces as an ImportError. + + Given: + ``CFDB_TEST_ASSIGN_PUBLIC_IP`` set to an invalid value + (lowercase, typo, anything outside {ENABLED, DISABLED}). + When: + ``_parse_assign_public_ip`` is called. + Then: + It should raise ``ImportError`` so the misconfiguration + surfaces at module load rather than crashing the lifespan + mid-bootstrap. + """ + # Arrange + monkeypatch.setenv("CFDB_TEST_ASSIGN_PUBLIC_IP", "enabled") + + # Act & assert + with pytest.raises(ImportError, match="must be one of"): + _parse_assign_public_ip( + "CFDB_TEST_ASSIGN_PUBLIC_IP", default="DISABLED" + ) diff --git a/tests/test_workflows/test_cache_s3.py b/tests/test_workflows/test_cache_s3.py new file mode 100644 index 0000000..aff997e --- /dev/null +++ b/tests/test_workflows/test_cache_s3.py @@ -0,0 +1,240 @@ +"""Tests for the moto-backed S3Cache implementation.""" + +from __future__ import annotations + +from pathlib import Path + +import boto3 +import pytest +from moto import mock_aws + +from cfdb.workflows.cache import CacheEntry, S3Cache + + +_BUCKET = "cfdb-test-cache" + + +def _write(path: Path, data: bytes) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(data) + + +async def _collect(stream) -> bytes: + chunks: list[bytes] = [] + async for chunk in stream: + chunks.append(chunk) + return b"".join(chunks) + + +@pytest.fixture() +def s3_client(): + """Return a moto-backed boto3 S3 client with one created bucket.""" + with mock_aws(): + client = boto3.client("s3", region_name="us-east-1") + client.create_bucket(Bucket=_BUCKET) + yield client + + +@pytest.fixture() +def cache(s3_client) -> S3Cache: + """Return an S3Cache wired up to the moto-backed client.""" + return S3Cache(bucket=_BUCKET, client=s3_client) + + +class TestS3Cache: + def test___init___with_empty_bucket_name(self, s3_client): + """Test that S3Cache rejects an empty bucket name. + + Given: + A moto-backed S3 client and an empty bucket name. + When: + S3Cache is constructed. + Then: + It should raise ValueError so misconfigurations fail fast. + """ + # Act & assert + with pytest.raises(ValueError, match="bucket"): + S3Cache(bucket="", client=s3_client) + + @pytest.mark.asyncio + async def test_head_with_absent_key(self, cache): + """Test that head reports a cache miss as None. + + Given: + An empty S3Cache. + When: + head is awaited for an unknown key. + Then: + It should return None so the router can dispatch a workflow. + """ + # Act + entry = await cache.head("encode/x/data/aa-v0") + + # Assert + assert entry is None + + @pytest.mark.asyncio + async def test_put_then_head_with_known_payload(self, cache, tmp_path): + """Test that put commits the artifact and head reports its size. + + Given: + An S3Cache and a source file with known contents. + When: + put commits the artifact and head is then awaited. + Then: + head should return a CacheEntry carrying the exact byte size. + """ + # Arrange + source = tmp_path / "src" + _write(source, b"hello world") + + # Act + await cache.put("encode/x/data/aa-v0", source) + entry = await cache.head("encode/x/data/aa-v0") + + # Assert + assert entry == CacheEntry(key="encode/x/data/aa-v0", size=11) + + @pytest.mark.asyncio + async def test_get_with_full_object(self, cache, tmp_path): + """Test that get streams the full artifact without a byte range. + + Given: + An S3Cache containing a multi-chunk artifact. + When: + get is iterated without a byte_range argument. + Then: + It should yield the complete bytes. + """ + # Arrange + source = tmp_path / "src" + payload = b"0123456789" * 20_000 + _write(source, payload) + await cache.put("encode/x/data/aa-v0", source) + + # Act + collected = await _collect(cache.get("encode/x/data/aa-v0")) + + # Assert + assert collected == payload + + @pytest.mark.asyncio + async def test_get_with_inclusive_byte_range(self, cache, tmp_path): + """Test that get forwards an inclusive byte range to S3. + + Given: + An S3Cache containing a known artifact. + When: + get is iterated with byte_range=(5, 9). + Then: + It should yield exactly bytes 5..9 inclusive (5 bytes total). + """ + # Arrange + source = tmp_path / "src" + _write(source, b"0123456789ABCDEF") + await cache.put("encode/x/data/aa-v0", source) + + # Act + collected = await _collect(cache.get("encode/x/data/aa-v0", (5, 9))) + + # Assert + assert collected == b"56789" + + @pytest.mark.asyncio + async def test_get_with_absent_key(self, cache): + """Test that get yields nothing for an absent key. + + Given: + An empty S3Cache. + When: + get is iterated for a missing key. + Then: + It should yield no chunks rather than raise. + """ + # Act + collected = await _collect(cache.get("encode/x/data/aa-v0")) + + # Assert + assert collected == b"" + + @pytest.mark.asyncio + async def test_delete_with_present_key(self, cache, tmp_path): + """Test that delete reports True and removes the artifact. + + Given: + An S3Cache containing one entry. + When: + delete is awaited for that key. + Then: + It should return True and a subsequent head should miss. + """ + # Arrange + source = tmp_path / "src" + _write(source, b"x") + await cache.put("encode/x/data/aa-v0", source) + + # Act + deleted = await cache.delete("encode/x/data/aa-v0") + + # Assert + assert deleted is True + assert await cache.head("encode/x/data/aa-v0") is None + + @pytest.mark.asyncio + async def test_delete_with_absent_key(self, cache): + """Test that delete on a missing key is idempotent. + + Given: + An empty S3Cache. + When: + delete is awaited for an unknown key. + Then: + It should return False rather than raise. + """ + # Act & assert + assert await cache.delete("encode/x/data/aa-v0") is False + + @pytest.mark.asyncio + async def test_put_with_traversal_segment_key(self, cache, tmp_path): + """Test that put refuses keys containing path-traversal segments. + + Given: + An S3Cache and a key with a ``..`` segment. + When: + put is awaited with that key. + Then: + It should raise ValueError rather than write under a parent prefix. + """ + # Arrange + source = tmp_path / "src" + _write(source, b"x") + + # Act & assert + with pytest.raises(ValueError): + await cache.put("../oops", source) + + @pytest.mark.asyncio + async def test_put_then_head_with_configured_prefix(self, s3_client, tmp_path): + """Test that the configured prefix is applied to every operation. + + Given: + An S3Cache configured with a non-empty prefix. + When: + put writes a key and head reads it back. + Then: + The object should land under ``/`` and head should + report its size correctly. + """ + # Arrange + cache = S3Cache(bucket=_BUCKET, prefix="env/dev", client=s3_client) + source = tmp_path / "src" + _write(source, b"abc") + + # Act + await cache.put("encode/x/data/aa-v0", source) + entry = await cache.head("encode/x/data/aa-v0") + + # Assert — moto stores under the prefixed key, not the raw key + assert entry == CacheEntry(key="encode/x/data/aa-v0", size=3) + listing = s3_client.list_objects_v2(Bucket=_BUCKET).get("Contents", []) + assert {item["Key"] for item in listing} == {"env/dev/encode/x/data/aa-v0"} diff --git a/tests/test_workflows/test_discovery.py b/tests/test_workflows/test_discovery.py new file mode 100644 index 0000000..471dbe9 --- /dev/null +++ b/tests/test_workflows/test_discovery.py @@ -0,0 +1,343 @@ +"""Tests for EcsDiscovery.""" + +from __future__ import annotations + +import asyncio +import uuid +from typing import Any + +import pytest + +from cfdb.workflows.discovery import EcsDiscovery + + +def _task_arn(task_id: str | None = None) -> str: + """Build a stable ECS task ARN for tests.""" + return f"arn:aws:ecs:us-east-1:123:task/cluster/{task_id or uuid.uuid4().hex}" + + +def _running_task( + task_id: str, + *, + ip: str = "10.0.0.5", + health: str = "HEALTHY", + status: str = "RUNNING", +) -> dict[str, Any]: + """Construct a fake ECS DescribeTasks entry.""" + return { + "taskArn": _task_arn(task_id), + "lastStatus": status, + "healthStatus": health, + "attachments": [ + { + "type": "ElasticNetworkInterface", + "details": [ + {"name": "subnetId", "value": "subnet-1"}, + {"name": "privateIPv4Address", "value": ip}, + ], + } + ], + } + + +class _FakeEcsClient: + """Fake ECS client whose responses can be re-set per call.""" + + def __init__(self) -> None: + self.task_arns: list[str] = [] + self.tasks: list[dict[str, Any]] = [] + + def list_tasks(self, **_kwargs): + return {"taskArns": list(self.task_arns)} + + def describe_tasks(self, *, cluster: str, tasks: list[str]): + wanted = set(tasks) + return { + "tasks": [t for t in self.tasks if t["taskArn"] in wanted], + } + + +class TestEcsDiscovery: + def test___init___without_cluster(self): + """Test that EcsDiscovery rejects an empty cluster argument. + + Given: + An empty cluster name. + When: + EcsDiscovery is constructed. + Then: + It should raise ValueError so misconfigurations fail fast. + """ + # Act & assert + with pytest.raises(ValueError, match="cluster"): + EcsDiscovery( + cluster="", + task_definition_family="worker", + client=_FakeEcsClient(), + ) + + def test___init___without_task_definition_family(self): + """Test that EcsDiscovery rejects an empty task_definition_family. + + Given: + A cluster but no task_definition_family. + When: + EcsDiscovery is constructed. + Then: + It should raise ValueError to surface the misconfiguration. + """ + # Act & assert + with pytest.raises(ValueError, match="task_definition_family"): + EcsDiscovery( + cluster="c", + task_definition_family="", + client=_FakeEcsClient(), + ) + + @pytest.mark.asyncio + async def test_poll_once_with_initial_healthy_workers(self): + """Test that the first poll emits worker-added events. + + Given: + A fake ECS client returning one RUNNING + HEALTHY task. + When: + poll_once is awaited. + Then: + It should emit exactly one ``worker-added`` event whose metadata + carries the task's private IP and the configured worker port. + """ + # Arrange + client = _FakeEcsClient() + task_id = uuid.uuid4().hex + client.task_arns = [_task_arn(task_id)] + client.tasks = [_running_task(task_id)] + # The fake's task ARN must align with the seed. + client.tasks[0]["taskArn"] = client.task_arns[0] + discovery = EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + worker_port=4242, + ) + + # Act + events, resolved = await discovery.poll_once() + + # Assert + assert len(events) == 1 + assert events[0].type == "worker-added" + assert events[0].metadata.address == "10.0.0.5:4242" + assert len(resolved) == 1 + + @pytest.mark.asyncio + async def test_poll_once_with_unhealthy_task_filtered(self): + """Test that UNHEALTHY tasks never surface as workers. + + Given: + A fake ECS client whose task is RUNNING but UNHEALTHY. + When: + poll_once is awaited. + Then: + No events should be emitted and resolved should be empty. + """ + # Arrange + client = _FakeEcsClient() + task_id = uuid.uuid4().hex + client.task_arns = [_task_arn(task_id)] + client.tasks = [_running_task(task_id, health="UNHEALTHY")] + client.tasks[0]["taskArn"] = client.task_arns[0] + discovery = EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + ) + + # Act + events, resolved = await discovery.poll_once() + + # Assert + assert events == [] + assert resolved == {} + + @pytest.mark.asyncio + async def test_poll_once_with_dropped_task_after_initial_seen(self): + """Test that a vanished task surfaces as worker-dropped. + + Given: + A fake ECS client that initially reports one task and then + reports an empty cluster on the next poll. + When: + poll_once is awaited twice. + Then: + The second poll should emit one worker-dropped event for the + previously-known task. + """ + # Arrange + client = _FakeEcsClient() + task_id = uuid.uuid4().hex + client.task_arns = [_task_arn(task_id)] + client.tasks = [_running_task(task_id)] + client.tasks[0]["taskArn"] = client.task_arns[0] + discovery = EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + ) + await discovery.poll_once() + + # Act — second poll observes the cluster gone empty + client.task_arns = [] + client.tasks = [] + events, resolved = await discovery.poll_once() + + # Assert + assert len(events) == 1 + assert events[0].type == "worker-dropped" + assert resolved == {} + + @pytest.mark.asyncio + async def test_poll_once_with_idempotent_steady_state(self): + """Test that repeated polls of the same set emit no extra events. + + Given: + A fake ECS client that consistently returns the same task. + When: + poll_once is awaited twice. + Then: + The second poll should emit no events because the diff is empty. + """ + # Arrange + client = _FakeEcsClient() + task_id = uuid.uuid4().hex + client.task_arns = [_task_arn(task_id)] + client.tasks = [_running_task(task_id)] + client.tasks[0]["taskArn"] = client.task_arns[0] + discovery = EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + ) + await discovery.poll_once() + + # Act + events, _ = await discovery.poll_once() + + # Assert + assert events == [] + + @pytest.mark.asyncio + async def test_subscribe_with_filter_excluding_worker(self): + """Test that a subscriber's filter suppresses non-matching events. + + Given: + A discovery instance with one healthy worker and a subscriber + whose filter rejects every metadata. + When: + poll_once is awaited. + Then: + The subscriber's queue should remain empty. + """ + # Arrange + client = _FakeEcsClient() + task_id = uuid.uuid4().hex + client.task_arns = [_task_arn(task_id)] + client.tasks = [_running_task(task_id)] + client.tasks[0]["taskArn"] = client.task_arns[0] + discovery = EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + ) + sub = discovery.subscribe(filter=lambda _meta: False) + + async def _drain(): + async for event in sub: # pragma: no cover — should never iterate + return event + return None + + # Act — start subscriber, then poll + consumer = asyncio.create_task(_drain()) + await asyncio.sleep(0) # let consumer register + await discovery.poll_once() + await asyncio.sleep(0.05) + + # Assert + consumer.cancel() + try: + await consumer + except asyncio.CancelledError: + pass + + @pytest.mark.asyncio + async def test_subscribe_with_two_iter_calls_raises_on_second(self): + """Test that double-iteration of one subscriber is refused. + + Given: + A subscriber whose async iterator has already been opened. + When: + A second async-for loop tries to drive the same subscriber. + Then: + It should raise RuntimeError on the second __anext__ rather + than silently double-register and duplicate event delivery. + """ + # Arrange + client = _FakeEcsClient() + discovery = EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + ) + sub = discovery.subscribe() + iter_one = sub.__aiter__() + iter_two = sub.__aiter__() + # Start iter_one so it flips _exhausted before iter_two runs. + first_consumer = asyncio.create_task(iter_one.__anext__()) + await asyncio.sleep(0) + + # Act & assert + with pytest.raises(RuntimeError, match="already iterated"): + await iter_two.__anext__() + + # Cleanup + first_consumer.cancel() + try: + await first_consumer + except asyncio.CancelledError: + pass + + @pytest.mark.asyncio + async def test_aexit_wakes_parked_consumer(self): + """Test that __aexit__ unblocks consumers parked on queue.get(). + + Given: + A discovery context with a subscriber consuming events. + When: + The discovery context exits while the consumer is parked. + Then: + The consumer's async-for loop ends cleanly within a short + timeout rather than blocking on a queue that nothing will + publish to again. + """ + # Arrange + client = _FakeEcsClient() + events_seen: list[Any] = [] + + async def _consume(sub): + async for event in sub: + events_seen.append(event) + + # Act + async with EcsDiscovery( + cluster="c", + task_definition_family="worker", + client=client, + ) as discovery: + consumer = asyncio.create_task(_consume(discovery.subscribe())) + # Let the consumer register and park on queue.get(). + await asyncio.sleep(0.05) + # Exiting the context should push the sentinel; the consumer + # task ends shortly after. + await asyncio.wait_for(consumer, timeout=1.0) + + # Assert — the consumer ended cleanly without exception. + assert consumer.done() and consumer.exception() is None diff --git a/tests/test_workflows/test_executor.py b/tests/test_workflows/test_executor.py index ab9e346..9c3e355 100644 --- a/tests/test_workflows/test_executor.py +++ b/tests/test_workflows/test_executor.py @@ -11,6 +11,7 @@ import wool from cfdb.workflows import executor as executor_module +from cfdb.workflows import keys as key_utils from cfdb.workflows.executor import ( PIPELINE_VERSION, ExecutorDraining, @@ -22,6 +23,7 @@ from cfdb.workflows.models import ACTIVE_STATUSES, ArtifactKind, JobStatus from cfdb.workflows.processors.base import Processor from cfdb.workflows.processors.registry import ProcessorRegistry +from cfdb.workflows.provisioner import EcsProvisioner, RetryableProvisionerError from tests.test_workflows import FIXTURE_MD5 #: Canonical artifact keys the stubs emit. Built from FIXTURE_MD5 so the @@ -1241,3 +1243,193 @@ async def test_ensure_workflow_should_persist_file_meta_snapshot_on_insert( final = await get_job(mock_db, record.job_id) assert final is not None assert final.file_meta_snapshot == meta + + +class TestWoolExecutorWithProvisioner: + @pytest.mark.asyncio + async def test_ensure_workflow_should_request_worker_when_provisioner_set( + self, mock_db, tmp_cache, tmp_workdir, no_wool_dispatch, mocker + ): + """Test that a fresh claim dispatches a provisioner request. + + Given: + A WoolExecutor wired with a stub EcsProvisioner. + When: + ``ensure_workflow`` is awaited on a previously-unseen file. + Then: + The provisioner's ``request`` should be awaited exactly once + with ``dedup_key`` set to the workflow mutex key for the file. + """ + # Arrange + _install_jobs_index(mock_db) + registry = ProcessorRegistry() + registry.register(_StubProcessor()) + provisioner = mocker.AsyncMock(spec=EcsProvisioner) + provisioner.request.return_value = ["arn:fake:task/abc"] + executor = WoolExecutor( + mock_db, + tmp_cache, + tmp_cache.root, + registry, + workdir_root=tmp_workdir, + provisioner=provisioner, + ) + meta = _file_meta() + dcc, local_id, md5 = extract_identity(meta) + expected_key = key_utils.workflow_key( + dcc=dcc, local_id=local_id, md5=md5, pipeline_version=PIPELINE_VERSION + ) + + # Act + record, _ = await executor.ensure_workflow(meta) + await _wait_for_terminal(mock_db, record.job_id) + + # Assert + provisioner.request.assert_awaited_once_with(dedup_key=expected_key) + + @pytest.mark.asyncio + async def test_ensure_workflow_should_skip_provisioner_when_attaching_to_existing_job( + self, mock_db, tmp_cache, tmp_workdir, no_wool_dispatch, mocker + ): + """Test that the provisioner is only invoked on fresh claims. + + Given: + An executor with a stub provisioner and a blocking processor + so the first claim stays active when the second call arrives. + When: + ``ensure_workflow`` is awaited twice for the same file_meta. + Then: + ``provisioner.request`` should be awaited exactly once — the + attach path must not double-spend a Fargate worker against + an already-claimed workflow. + """ + # Arrange + _install_jobs_index(mock_db) + release = asyncio.Event() + + class _BlockingProcessor(_StubProcessor): + async def run(self, file_meta, workdir, cache_root): + self.run_calls += 1 + await release.wait() + for kind, key in self.artifacts.items(): + yield {"event": "stage_complete", "kind": kind, "key": key} + yield {"event": "complete", "artifacts": dict(self.artifacts)} + + registry = ProcessorRegistry() + registry.register(_BlockingProcessor()) + provisioner = mocker.AsyncMock(spec=EcsProvisioner) + provisioner.request.return_value = ["arn:fake:task/abc"] + executor = WoolExecutor( + mock_db, + tmp_cache, + tmp_cache.root, + registry, + workdir_root=tmp_workdir, + provisioner=provisioner, + ) + + # Act + record_a, fresh_a = await executor.ensure_workflow(_file_meta()) + record_b, fresh_b = await executor.ensure_workflow(_file_meta()) + release.set() + await _wait_for_terminal(mock_db, record_a.job_id) + + # Assert + assert fresh_a is True + assert fresh_b is False + assert record_a.job_id == record_b.job_id + provisioner.request.assert_awaited_once() + + @pytest.mark.asyncio + async def test_ensure_workflow_should_record_capacity_failure_when_provisioner_raises( + self, mock_db, tmp_cache, tmp_workdir, no_wool_dispatch, mocker + ): + """Test that a provisioner failure routes to a clean FAILED status. + + Given: + A provisioner whose ``request`` raises + ``RetryableProvisionerError`` (capacity / ENI exhaustion / + throttling). + When: + ``ensure_workflow`` is awaited. + Then: + The job should reach ``FAILED`` with the provisioner error + string captured on the record so the next client retry can + see why the workflow didn't run. + """ + # Arrange + _install_jobs_index(mock_db) + registry = ProcessorRegistry() + registry.register(_StubProcessor()) + provisioner = mocker.AsyncMock(spec=EcsProvisioner) + provisioner.request.side_effect = RetryableProvisionerError( + "ClientError: capacity-unavailable" + ) + executor = WoolExecutor( + mock_db, + tmp_cache, + tmp_cache.root, + registry, + workdir_root=tmp_workdir, + provisioner=provisioner, + ) + + # Act + record, _ = await executor.ensure_workflow(_file_meta()) + await _wait_for_terminal(mock_db, record.job_id) + + # Assert + final = await get_job(mock_db, record.job_id) + assert final is not None + assert final.status == JobStatus.FAILED + assert final.error is not None + assert final.error.startswith("capacity: "), ( + f"expected 'capacity: ' prefix, got {final.error!r}" + ) + assert "capacity-unavailable" in final.error + + @pytest.mark.asyncio + async def test_ensure_workflow_should_record_provisioner_prefix_for_generic_exception( + self, mock_db, tmp_cache, tmp_workdir, no_wool_dispatch, mocker + ): + """Test that non-retryable provisioner failures use the 'provisioner: ' prefix. + + Given: + A provisioner whose ``request`` raises a generic + ``RuntimeError`` (a misconfiguration or unexpected boto + failure, not a retryable transient). + When: + ``ensure_workflow`` is awaited. + Then: + The job should reach ``FAILED`` with an error string that + starts with ``"provisioner: "`` so clients can distinguish + "retry me later" from "the provisioner crashed". + """ + # Arrange + _install_jobs_index(mock_db) + registry = ProcessorRegistry() + registry.register(_StubProcessor()) + provisioner = mocker.AsyncMock(spec=EcsProvisioner) + provisioner.request.side_effect = RuntimeError("misconfigured cluster") + executor = WoolExecutor( + mock_db, + tmp_cache, + tmp_cache.root, + registry, + workdir_root=tmp_workdir, + provisioner=provisioner, + ) + + # Act + record, _ = await executor.ensure_workflow(_file_meta()) + await _wait_for_terminal(mock_db, record.job_id) + + # Assert + final = await get_job(mock_db, record.job_id) + assert final is not None + assert final.status == JobStatus.FAILED + assert final.error is not None + assert final.error.startswith("provisioner: "), ( + f"expected 'provisioner: ' prefix, got {final.error!r}" + ) + assert "misconfigured cluster" in final.error diff --git a/tests/test_workflows/test_provisioner.py b/tests/test_workflows/test_provisioner.py new file mode 100644 index 0000000..a44c37b --- /dev/null +++ b/tests/test_workflows/test_provisioner.py @@ -0,0 +1,398 @@ +"""Tests for EcsProvisioner.""" + +from __future__ import annotations + +import asyncio + +import pytest + +from cfdb.workflows.provisioner import RetryableProvisionerError, EcsProvisioner + + +class _FakeEcsClient: + """In-memory ECS client recording RunTask calls for assertions.""" + + def __init__(self, *, response: dict | None = None, raise_on_call: Exception | None = None) -> None: + self.calls: list[dict] = [] + self._response = response or { + "tasks": [{"taskArn": "arn:aws:ecs:::task/cluster/abc"}], + "failures": [], + } + self._raise = raise_on_call + + def run_task(self, **kwargs): + self.calls.append(kwargs) + if self._raise is not None: + raise self._raise + return self._response + + +class _SimpleGatedClient(_FakeEcsClient): + """run_task blocks on a threading.Event until released by the test.""" + + def __init__(self) -> None: + super().__init__() + import threading + + self._gate = threading.Event() + self._call_started = threading.Event() + + def run_task(self, **kwargs): + self.calls.append(kwargs) + self._call_started.set() + # Block the worker thread until the test releases the gate. + self._gate.wait() + return self._response + + def release(self) -> None: + self._gate.set() + + +def _client_error(code: str) -> Exception: + """Construct a botocore ClientError with a structured response dict.""" + from botocore.exceptions import ClientError + + return ClientError( + {"Error": {"Code": code, "Message": "simulated"}, "ResponseMetadata": {"HTTPStatusCode": 500}}, + "RunTask", + ) + + +class TestEcsProvisioner: + def test___init___without_cluster(self): + """Test that constructing EcsProvisioner without a cluster fails fast. + + Given: + No cluster name. + When: + EcsProvisioner is constructed. + Then: + It should raise ValueError so misconfiguration is caught at boot. + """ + # Act & assert + with pytest.raises(ValueError, match="cluster"): + EcsProvisioner( + cluster="", + task_definition="worker", + subnets=["subnet-1"], + client=_FakeEcsClient(), + ) + + def test___init___without_task_definition(self): + """Test that omitting task_definition raises ValueError. + + Given: + A cluster but no task definition. + When: + EcsProvisioner is constructed. + Then: + It should raise ValueError to surface the misconfiguration. + """ + # Act & assert + with pytest.raises(ValueError, match="task_definition"): + EcsProvisioner( + cluster="c", + task_definition="", + subnets=["subnet-1"], + client=_FakeEcsClient(), + ) + + def test___init___without_subnets(self): + """Test that constructing without subnets raises ValueError. + + Given: + A cluster and task definition but no subnets. + When: + EcsProvisioner is constructed. + Then: + It should raise ValueError because awsvpc requires at least one. + """ + # Act & assert + with pytest.raises(ValueError, match="subnet"): + EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=[], + client=_FakeEcsClient(), + ) + + @pytest.mark.asyncio + async def test_request_with_single_caller(self): + """Test that request returns the launched task ARNs. + + Given: + A provisioner backed by a fake client returning one task ARN. + When: + request is awaited once. + Then: + It should return the list of ARNs and have invoked RunTask once + with the configured cluster, family, and awsvpc network config. + """ + # Arrange + client = _FakeEcsClient() + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + security_groups=["sg-1"], + client=client, + ) + + # Act + arns = await provisioner.request(dedup_key="wf-1") + + # Assert + assert arns == ["arn:aws:ecs:::task/cluster/abc"] + assert len(client.calls) == 1 + call = client.calls[0] + assert call["cluster"] == "c" + assert call["taskDefinition"] == "worker" + assert call["launchType"] == "FARGATE" + awsvpc = call["networkConfiguration"]["awsvpcConfiguration"] + assert awsvpc["subnets"] == ["subnet-1"] + assert awsvpc["securityGroups"] == ["sg-1"] + + @pytest.mark.asyncio + async def test_request_with_concurrent_dedup_key_collisions(self): + """Test that concurrent calls with the same dedup_key share one RunTask. + + Given: + A provisioner whose underlying RunTask call is gated on a + threading.Event so the first caller stays in-flight. + When: + Two coroutines call request concurrently with the same dedup_key. + Then: + Only one RunTask invocation should be issued, and both callers + should observe identical ARNs. + """ + # Arrange + client = _SimpleGatedClient() + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + + # Act + first = asyncio.create_task(provisioner.request(dedup_key="wf-1")) + # Wait until the first call is mid-flight before starting the + # second so the dedup map is populated. + await asyncio.to_thread(client._call_started.wait) + second = asyncio.create_task(provisioner.request(dedup_key="wf-1")) + # Release the gated client so both callers can return. + client.release() + first_arns, second_arns = await asyncio.gather(first, second) + + # Assert + assert first_arns == second_arns == ["arn:aws:ecs:::task/cluster/abc"] + assert len(client.calls) == 1 + + @pytest.mark.asyncio + async def test_request_with_distinct_dedup_keys(self): + """Test that distinct dedup keys produce independent RunTask calls. + + Given: + A provisioner backed by a fake client. + When: + request is awaited twice with different dedup_keys. + Then: + Two RunTask invocations should be issued. + """ + # Arrange + client = _FakeEcsClient() + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + + # Act + await provisioner.request(dedup_key="wf-a") + await provisioner.request(dedup_key="wf-b") + + # Assert + assert len(client.calls) == 2 + + @pytest.mark.asyncio + async def test_request_with_capacity_client_error(self): + """Test that capacity ClientErrors map to RetryableProvisionerError. + + Given: + A provisioner whose underlying client raises a Capacity error. + When: + request is awaited. + Then: + It should raise RetryableProvisionerError so callers can surface it as a + retryable terminal failure. + """ + # Arrange + client = _FakeEcsClient(raise_on_call=_client_error("CapacityProviderException")) + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + + # Act & assert + with pytest.raises(RetryableProvisionerError): + await provisioner.request(dedup_key="wf-1") + + @pytest.mark.asyncio + async def test_request_with_capacity_failure_in_response(self): + """Test that capacity failures inside the RunTask response also raise. + + Given: + A provisioner whose RunTask response carries a failures entry + whose reason includes RESOURCE: tokens. + When: + request is awaited. + Then: + It should raise RetryableProvisionerError rather than treating the call + as a success with zero ARNs. + """ + # Arrange + client = _FakeEcsClient( + response={ + "tasks": [], + "failures": [{"reason": "RESOURCE:CPU"}], + } + ) + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + + # Act & assert + with pytest.raises(RetryableProvisionerError): + await provisioner.request(dedup_key="wf-1") + + @pytest.mark.asyncio + async def test_request_with_partial_failure_preserves_arn(self): + """Test that a launched ARN is preserved even when failures[] is non-empty. + + Given: + A provisioner whose RunTask response carries both a launched + taskArn and a non-retryable failures entry. + When: + request is awaited. + Then: + It should return the launched ARN and log the failure rather + than discarding the worker that is already running. + """ + # Arrange + client = _FakeEcsClient( + response={ + "tasks": [{"taskArn": "arn:aws:ecs:::task/cluster/abc"}], + "failures": [{"reason": "secondary placement warning"}], + } + ) + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + + # Act + arns = await provisioner.request(dedup_key="wf-1") + + # Assert + assert arns == ["arn:aws:ecs:::task/cluster/abc"] + + @pytest.mark.asyncio + async def test_request_with_done_cached_task_launches_fresh_run(self): + """Test that a stale completed in-flight entry is replaced on next request. + + Given: + A provisioner where the first request completed and the dedup + slot still holds the done task (simulating a release-slot race + where the post-completion pop was skipped). + When: + A second request with the same dedup_key is awaited. + Then: + A fresh RunTask is issued rather than the second caller + attaching to the already-finished task. + """ + # Arrange + client = _FakeEcsClient() + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + first = await provisioner.request(dedup_key="wf-shared") + # Simulate the post-completion pop having been skipped (e.g. by a + # CancelledError that re-raised out of the shielded release): put + # the done task back into the dedup map. + done_task = asyncio.create_task(_resolved(first)) + await done_task + provisioner._in_flight["wf-shared"] = done_task + + # Act + second = await provisioner.request(dedup_key="wf-shared") + + # Assert + assert second == first + assert len(client.calls) == 2 + + @pytest.mark.asyncio + async def test_aclose_cancels_in_flight_and_clears_map(self): + """Test that aclose cancels in-flight tasks and clears the dedup map. + + Given: + A provisioner with one in-flight RunTask whose underlying + boto3 call is mid-flight on the ``asyncio.to_thread`` worker. + When: + aclose is awaited. + Then: + The asyncio task is cancelled, the dedup map is empty, and + the original caller observes CancelledError rather than a hang. + """ + # Arrange — use a client that simulates a slow round-trip via a + # short ``time.sleep`` in the worker thread. ``asyncio.to_thread`` + # cannot cancel the running thread, but the wrapping task can be + # cancelled — which is what ``aclose`` does. The short sleep + # makes the thread exit promptly so pytest's executor join at + # session teardown does not hit a 300 s timeout. + import time + + class _SlowClient(_FakeEcsClient): + def run_task(self, **kwargs): + self.calls.append(kwargs) + time.sleep(0.5) + return self._response + + client = _SlowClient() + provisioner = EcsProvisioner( + cluster="c", + task_definition="worker", + subnets=["subnet-1"], + client=client, + ) + caller = asyncio.create_task(provisioner.request(dedup_key="wf-1")) + # Yield so the dedup-registration task is scheduled and the + # boto thread call has begun. + await asyncio.sleep(0.05) + + # Act + await provisioner.aclose() + + # Assert + assert provisioner._in_flight == {} + with pytest.raises(asyncio.CancelledError): + await caller + + +async def _resolved(value): + """Helper coroutine that immediately returns ``value``. + + Used by the dedup-self-heal test to construct a real done task that + holds the same return value as the first request. + """ + return value diff --git a/tests/test_workflows/test_worker_main.py b/tests/test_workflows/test_worker_main.py new file mode 100644 index 0000000..e70c3e8 --- /dev/null +++ b/tests/test_workflows/test_worker_main.py @@ -0,0 +1,133 @@ +"""Tests for the ECS worker entrypoint argument parsing.""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest +from click.testing import CliRunner + +from cfdb.workflows import worker_main + + +def _invoke(args: list[str]) -> tuple[int, dict[str, object]]: + """Run ``worker_main.main`` with ``args``, capturing the ``serve`` kwargs. + + Returns ``(exit_code, captured_kwargs)``. The real ``serve`` is patched + out — the entrypoint always calls ``raise SystemExit(asyncio.run(serve(...)))``, + so swapping ``asyncio.run`` for ``lambda coro: coro.close() or 0`` plus + stubbing ``serve`` keeps the test from actually booting a worker. + """ + captured: dict[str, object] = {} + + async def _fake_serve(**kwargs: object) -> int: + captured.update(kwargs) + return 0 + + runner = CliRunner() + with patch.object(worker_main, "serve", _fake_serve): + result = runner.invoke(worker_main.main, args, standalone_mode=True) + return result.exit_code, captured + + +class TestMainCli: + def test_main_uses_documented_defaults_when_no_args_or_env(self, monkeypatch): + """Test that bare invocation surfaces the documented defaults. + + Given: + No CLI arguments and no overriding environment variables. + When: + ``worker_main.main`` is invoked. + Then: + ``serve`` should be called with the documented worker port, + health port, max lifetime, and drain-grace defaults so a + bare container ``CMD`` works. + """ + # Arrange — clear any env overrides so defaults apply + for var in ( + "CFDB_WORKER_GRPC_PORT", + "CFDB_WORKER_HEALTH_PORT", + "CFDB_WORKER_MAX_LIFETIME_SECONDS", + "CFDB_WORKER_DRAIN_GRACE_SECONDS", + ): + monkeypatch.delenv(var, raising=False) + + # Act + exit_code, captured = _invoke([]) + + # Assert + assert exit_code == 0 + assert captured["worker_port"] == worker_main.DEFAULT_WORKER_PORT + assert captured["health_port"] == worker_main.DEFAULT_HEALTH_PORT + assert captured["max_lifetime_seconds"] == worker_main.DEFAULT_MAX_LIFETIME_SECONDS + assert captured["drain_grace_seconds"] == worker_main.DEFAULT_DRAIN_GRACE_SECONDS + + def test_main_with_env_overrides(self, monkeypatch): + """Test that environment variables override the defaults. + + Given: + ``CFDB_WORKER_GRPC_PORT`` and friends set to non-default values. + When: + ``worker_main.main`` is invoked with no CLI flags. + Then: + ``serve`` should receive the env-driven values. + """ + # Arrange + monkeypatch.setenv("CFDB_WORKER_GRPC_PORT", "60001") + monkeypatch.setenv("CFDB_WORKER_HEALTH_PORT", "9001") + monkeypatch.setenv("CFDB_WORKER_MAX_LIFETIME_SECONDS", "1800") + monkeypatch.setenv("CFDB_WORKER_DRAIN_GRACE_SECONDS", "10") + + # Act + exit_code, captured = _invoke([]) + + # Assert + assert exit_code == 0 + assert captured["worker_port"] == 60001 + assert captured["health_port"] == 9001 + assert captured["max_lifetime_seconds"] == 1800.0 + assert captured["drain_grace_seconds"] == 10.0 + + def test_main_cli_flags_override_env_vars(self, monkeypatch): + """Test that CLI flags win over environment variables. + + Given: + ``CFDB_WORKER_GRPC_PORT`` set in the environment. + When: + ``worker_main.main`` is invoked with an explicit + ``--worker-port`` CLI flag. + Then: + ``serve`` should receive the CLI-supplied value. + """ + # Arrange + monkeypatch.setenv("CFDB_WORKER_GRPC_PORT", "60001") + + # Act + exit_code, captured = _invoke(["--worker-port", "55555"]) + + # Assert + assert exit_code == 0 + assert captured["worker_port"] == 55555 + + def test_main_rejects_out_of_range_worker_port(self, monkeypatch): + """Test that a port outside [1, 65535] is rejected at parse time. + + Given: + A ``--worker-port`` value above 65535. + When: + ``worker_main.main`` is invoked. + Then: + Click should exit non-zero before reaching ``serve``, so the + failure surfaces clearly rather than as an opaque bind error + later. + """ + # Arrange + for var in ("CFDB_WORKER_GRPC_PORT",): + monkeypatch.delenv(var, raising=False) + + # Act + exit_code, captured = _invoke(["--worker-port", "99999"]) + + # Assert + assert exit_code != 0 + assert "worker_port" not in captured