-
Notifications
You must be signed in to change notification settings - Fork 1
Enable and Improve blocks and their cancellation #89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
505da9e
c622813
2c86646
d0c0e07
9e5d568
de74ec9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -107,6 +107,11 @@ def __init__( | |||||||||||||||||||||||||
| self._dependency_count = {} | ||||||||||||||||||||||||||
| self._component_change_event = asyncio.Event() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Block task registry: uid -> asyncio.Task for running execute_block coroutines | ||||||||||||||||||||||||||
| self._block_asyncio_tasks: dict[str, asyncio.Task] = {} | ||||||||||||||||||||||||||
| # Block member registry: block_uid -> set of component UIDs registered within it | ||||||||||||||||||||||||||
| self._block_members: dict[str, set[str]] = {} | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.task_states_map = self.backend.get_task_states_map() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Define decorators | ||||||||||||||||||||||||||
|
|
@@ -606,6 +611,10 @@ def _handle_flow_component_registration( | |||||||||||||||||||||||||
| def wrapper(*args, **kwargs): | ||||||||||||||||||||||||||
| # Create async future - we only support async | ||||||||||||||||||||||||||
| comp_fut = asyncio.Future() | ||||||||||||||||||||||||||
| comp_fut.state = "PENDING" | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Extract call-time workflow_id before storing kwargs or calling the function | ||||||||||||||||||||||||||
| explicit_workflow_id = kwargs.pop("workflow_id", None) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| comp_desc = { | ||||||||||||||||||||||||||
| "args": args, | ||||||||||||||||||||||||||
|
|
@@ -615,6 +624,7 @@ def wrapper(*args, **kwargs): | |||||||||||||||||||||||||
| "task_backend_specific_kwargs": task_backend_specific_kwargs or {}, | ||||||||||||||||||||||||||
| "target_backend": target_backend, | ||||||||||||||||||||||||||
| "capture_stdio": capture_stdio, | ||||||||||||||||||||||||||
| "_explicit_workflow_id": explicit_workflow_id, | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Only handle async functions | ||||||||||||||||||||||||||
|
|
@@ -703,7 +713,8 @@ def _register_component( | |||||||||||||||||||||||||
| # make sure not to specify both func and executable at the same time | ||||||||||||||||||||||||||
| comp_desc["name"] = comp_desc["function"].__name__ | ||||||||||||||||||||||||||
| comp_desc["uid"] = self._assign_uid(prefix=comp_type) | ||||||||||||||||||||||||||
| comp_desc["workflow_id"] = self._workflow_id_ctx.get() | ||||||||||||||||||||||||||
| # call-time workflow_id takes precedence over the ContextVar | ||||||||||||||||||||||||||
| comp_desc["workflow_id"] = comp_desc.pop("_explicit_workflow_id", None) or self._workflow_id_ctx.get() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if task_type == EXECUTABLE: | ||||||||||||||||||||||||||
| comp_desc[FUNCTION] = None # Clear function since we're using executable | ||||||||||||||||||||||||||
|
|
@@ -760,8 +771,25 @@ def _register_component( | |||||||||||||||||||||||||
| self._update_dependency_tracking(comp_desc["uid"]) | ||||||||||||||||||||||||||
| self._component_change_event.set() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Setup cancel hook | ||||||||||||||||||||||||||
| comp_fut.cancel = self._setup_future_cancel_hook(comp_fut, comp_desc["uid"]) | ||||||||||||||||||||||||||
| # Track block membership: if this component is registered from within a block's | ||||||||||||||||||||||||||
| # execution context, record it so it gets cancelled when the block is cancelled. | ||||||||||||||||||||||||||
| # Read ContextVar directly — not comp_desc["workflow_id"] which may be overridden | ||||||||||||||||||||||||||
| # by an explicit call-time kwarg (a telemetry label, not a membership signal). | ||||||||||||||||||||||||||
| parent_block_uid = self._workflow_id_ctx.get() | ||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||
| parent_block_uid | ||||||||||||||||||||||||||
| and parent_block_uid in self.components | ||||||||||||||||||||||||||
| and self.components[parent_block_uid]["type"] == BLOCK | ||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||
| self._block_members.setdefault(parent_block_uid, set()).add(comp_desc["uid"]) | ||||||||||||||||||||||||||
| comp_fut.add_done_callback( | ||||||||||||||||||||||||||
| lambda _, buid=parent_block_uid, tuid=comp_desc["uid"]: self._remove_member(buid, tuid) | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Patch cancel for tasks only; blocks are cancelled via a done_callback | ||||||||||||||||||||||||||
| # installed in _submit_blocks that cancels the underlying asyncio.Task directly. | ||||||||||||||||||||||||||
| if comp_type != BLOCK: | ||||||||||||||||||||||||||
| comp_fut.cancel = self._setup_future_cancel_hook(comp_fut, comp_desc["uid"]) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self._emit( | ||||||||||||||||||||||||||
| "TaskCreated", | ||||||||||||||||||||||||||
|
|
@@ -827,7 +855,7 @@ def patched_cancel(*args, **kwargs): | |||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| # Task is pending -> cancel locally | ||||||||||||||||||||||||||
| logger.info(f"Cancellation requested for {uid} (pending) locally") | ||||||||||||||||||||||||||
| return fut.original_cancel | ||||||||||||||||||||||||||
| return fut.original_cancel() | ||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are two issues here:
Suggested change
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. addressed |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return patched_cancel | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -926,6 +954,7 @@ def _clear_internal_records(self): | |||||||||||||||||||||||||
| self._ready_queue.clear() | ||||||||||||||||||||||||||
| self._dependents_map.clear() | ||||||||||||||||||||||||||
| self._dependency_count.clear() | ||||||||||||||||||||||||||
| self._block_members.clear() | ||||||||||||||||||||||||||
|
Comment on lines
963
to
+966
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. addressed |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| reset_uid_counter() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -954,6 +983,14 @@ def _notify_dependents(self, comp_uid: str): | |||||||||||||||||||||||||
| if comp_uid in self._dependency_count: | ||||||||||||||||||||||||||
| del self._dependency_count[comp_uid] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _remove_member(self, block_uid: str, task_uid: str) -> None: | ||||||||||||||||||||||||||
| """Remove a completed task from its parent block's member set.""" | ||||||||||||||||||||||||||
| members = self._block_members.get(block_uid) | ||||||||||||||||||||||||||
| if members is not None: | ||||||||||||||||||||||||||
| members.discard(task_uid) | ||||||||||||||||||||||||||
| if not members: | ||||||||||||||||||||||||||
| del self._block_members[block_uid] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _create_dependency_failure_exception(self, comp_desc: dict, failed_deps: list): | ||||||||||||||||||||||||||
| """Create a DependencyFailureError exception that shows both the immediate | ||||||||||||||||||||||||||
| failure and the root cause from failed dependencies. | ||||||||||||||||||||||||||
|
|
@@ -1355,15 +1392,28 @@ async def _submit_blocks(self, blocks: list): | |||||||||||||||||||||||||
| - Relies on `execute_block` to handle the actual function call and future | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| for block in blocks: | ||||||||||||||||||||||||||
| args = block["args"] | ||||||||||||||||||||||||||
| kwargs = block["kwargs"] | ||||||||||||||||||||||||||
| func = block["function"] | ||||||||||||||||||||||||||
| block_fut = self.components[block["uid"]]["future"] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Execute the block function as a coroutine | ||||||||||||||||||||||||||
| asyncio.create_task( | ||||||||||||||||||||||||||
| self.execute_block(block_fut, func, *args, **kwargs), name=block["uid"] | ||||||||||||||||||||||||||
| block_uid = block["uid"] | ||||||||||||||||||||||||||
| block_fut = self.components[block_uid]["future"] | ||||||||||||||||||||||||||
| t = asyncio.create_task( | ||||||||||||||||||||||||||
| self.execute_block(block_fut, block["function"], *block["args"], **block["kwargs"]), | ||||||||||||||||||||||||||
| name=block_uid, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| self._block_asyncio_tasks[block_uid] = t | ||||||||||||||||||||||||||
| # Remove from registry when the asyncio.Task finishes (any outcome) | ||||||||||||||||||||||||||
| t.add_done_callback(lambda _, uid=block_uid: self._block_asyncio_tasks.pop(uid, None)) | ||||||||||||||||||||||||||
| # Wire cancellation: if block_fut is cancelled externally after submission, | ||||||||||||||||||||||||||
| # propagate to the asyncio.Task. | ||||||||||||||||||||||||||
| def _on_block_fut_done(f, task=t, buid=block_uid): | ||||||||||||||||||||||||||
| if f.cancelled(): | ||||||||||||||||||||||||||
| task.cancel() | ||||||||||||||||||||||||||
| f.state = "CANCELLED" | ||||||||||||||||||||||||||
| members = self._block_members.pop(buid, None) | ||||||||||||||||||||||||||
| if members: | ||||||||||||||||||||||||||
| for member_uid in members: | ||||||||||||||||||||||||||
| comp = self.components.get(member_uid) | ||||||||||||||||||||||||||
| if comp and not comp["future"].done(): | ||||||||||||||||||||||||||
| comp["future"].cancel() | ||||||||||||||||||||||||||
| block_fut.add_done_callback(_on_block_fut_done) | ||||||||||||||||||||||||||
|
Comment on lines
+1418
to
+1437
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are two issues with block state tracking and cleanup:
self._block_asyncio_tasks[block_uid] = t
block_fut.state = 'RUNNING'
# Remove from registry when the asyncio.Task finishes (any outcome)
t.add_done_callback(lambda _, uid=block_uid: self._block_asyncio_tasks.pop(uid, None))
# Wire cancellation and state transitions: if block_fut is done,
# propagate cancellation and update state.
def _on_block_fut_done(f, task=t, buid=block_uid):
if f.cancelled():
task.cancel()
f.state = 'CANCELLED'
members = self._block_members.pop(buid, None)
if members:
for member_uid in members:
comp = self.components.get(member_uid)
if comp and not comp['future'].done():
comp['future'].cancel()
elif f.exception() is not None:
f.state = 'FAILED'
self._block_members.pop(buid, None)
else:
f.state = 'DONE'
self._block_members.pop(buid, None)
block_fut.add_done_callback(_on_block_fut_done)
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. addressed |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| async def execute_block( | ||||||||||||||||||||||||||
| self, block_fut: asyncio.Future, func: Callable, *args: Any, **kwargs: Any | ||||||||||||||||||||||||||
|
|
@@ -1423,6 +1473,7 @@ def handle_task_success(self, task: dict, task_fut: asyncio.Future) -> None: | |||||||||||||||||||||||||
| task_fut.set_result(task["return_value"]) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| task_fut.set_result(task["stdout"]) | ||||||||||||||||||||||||||
| task_fut.state = "DONE" | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||
| f'Attempted to handle an already finished task "{task["uid"]}"' | ||||||||||||||||||||||||||
|
|
@@ -1482,18 +1533,18 @@ def handle_task_failure( | |||||||||||||||||||||||||
| exception = RuntimeError(str(original_error)) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| task_fut.set_exception(exception) | ||||||||||||||||||||||||||
| task_fut.state = "FAILED" | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def handle_task_cancellation(self, task: dict, task_fut: asyncio.Future): | ||||||||||||||||||||||||||
| """Handle task cancellation.""" | ||||||||||||||||||||||||||
| if task_fut.done(): | ||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||
| f'Attempted to handle an already cancelled task "{task["uid"]}"' | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||
| return # already resolved — idempotent, nothing to do | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Restore original cancel method | ||||||||||||||||||||||||||
| task_fut.cancel = task_fut.original_cancel | ||||||||||||||||||||||||||
| return task_fut.cancel() | ||||||||||||||||||||||||||
| result = task_fut.cancel() | ||||||||||||||||||||||||||
| task_fut.state = "CANCELLED" | ||||||||||||||||||||||||||
| return result | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @typeguard.typechecked | ||||||||||||||||||||||||||
| def task_callbacks( | ||||||||||||||||||||||||||
|
|
@@ -1596,6 +1647,7 @@ def wait_and_set(): | |||||||||||||||||||||||||
| # implicit: when a coroutine that awaits the future | ||||||||||||||||||||||||||
| # is scheduled and started by the event loop, that's | ||||||||||||||||||||||||||
| # when the "work" is running. | ||||||||||||||||||||||||||
| task_fut.state = "RUNNING" | ||||||||||||||||||||||||||
| if self._telemetry is not None: | ||||||||||||||||||||||||||
| now = time.time() | ||||||||||||||||||||||||||
| self._task_start_times[uid] = now | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,199 @@ | ||
| """Tests for block submission, task registry, and member cancellation propagation.""" | ||
|
|
||
| import asyncio | ||
| from concurrent.futures import ThreadPoolExecutor | ||
|
|
||
| import pytest | ||
|
|
||
| from radical.asyncflow import NoopExecutionBackend, WorkflowEngine | ||
| from radical.asyncflow.backends import LocalExecutionBackend | ||
|
|
||
|
|
||
| async def _make_engine(): | ||
| return await WorkflowEngine.create(backend=NoopExecutionBackend()) | ||
|
|
||
|
|
||
| async def _make_local_engine(): | ||
| backend = await LocalExecutionBackend(ThreadPoolExecutor(max_workers=8)) | ||
| return await WorkflowEngine.create(backend=backend) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # _block_asyncio_tasks registry | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_block_asyncio_tasks_registry_populated(): | ||
| """_block_asyncio_tasks must hold an entry for a block while it is running.""" | ||
| engine = await _make_engine() | ||
| block_started = asyncio.Event() | ||
| block_release = asyncio.Event() | ||
|
|
||
| @engine.block | ||
| async def my_block(): | ||
| block_started.set() | ||
| await block_release.wait() | ||
|
|
||
| my_block() | ||
| await asyncio.sleep(0.05) | ||
| await block_started.wait() | ||
|
|
||
| block_uid = next(iter(engine.components)) | ||
| assert block_uid in engine._block_asyncio_tasks | ||
|
|
||
| block_release.set() | ||
| await asyncio.sleep(0.05) | ||
|
|
||
| assert block_uid not in engine._block_asyncio_tasks | ||
|
|
||
| await engine.shutdown() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_block_asyncio_tasks_registry_cleared_on_completion(): | ||
| """Registry must be empty after a block completes normally.""" | ||
| engine = await _make_engine() | ||
|
|
||
| @engine.block | ||
| async def quick_block(): | ||
| return 42 | ||
|
|
||
| quick_block() | ||
| await asyncio.sleep(0.1) | ||
|
|
||
| assert len(engine._block_asyncio_tasks) == 0 | ||
|
|
||
| await engine.shutdown() | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # _block_members — member registration and cleanup | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_block_members_cleared_on_normal_completion(): | ||
| """_block_members must be empty after a block and its tasks complete normally.""" | ||
| engine = await _make_engine() | ||
|
|
||
| @engine.block | ||
| async def normal_block(): | ||
| @engine.function_task | ||
| async def inner(): | ||
| return "ok" | ||
|
|
||
| await inner() | ||
|
|
||
| block_fut = normal_block() | ||
| await asyncio.sleep(0.2) | ||
|
|
||
| assert block_fut.done() and not block_fut.cancelled() | ||
| assert len(engine._block_members) == 0 | ||
|
|
||
| await engine.shutdown() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_completed_tasks_removed_from_block_members(): | ||
| """Tasks that finish before block cancellation are cleaned up by their done_callback.""" | ||
| engine = await _make_engine() | ||
| block_gate = asyncio.Event() | ||
|
|
||
| @engine.block | ||
| async def my_block(): | ||
| @engine.function_task | ||
| async def quick_task(): | ||
| return "done" | ||
|
|
||
| await quick_task() # NoopBackend resolves immediately | ||
| await block_gate.wait() # keep block alive for inspection | ||
|
|
||
| my_block() | ||
| await asyncio.sleep(0.15) | ||
|
|
||
| block_uid = next(uid for uid, c in engine.components.items() if c["type"] == "block") | ||
| assert len(engine._block_members.get(block_uid, set())) == 0 | ||
|
|
||
| block_gate.set() | ||
| await asyncio.sleep(0.05) | ||
| await engine.shutdown() | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Cancellation propagation to block members | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_block_cancellation_propagates_to_member_tasks(): | ||
| """All tasks registered inside a block before cancellation must be cancelled.""" | ||
| engine = await _make_local_engine() | ||
|
|
||
| @engine.block | ||
| async def my_block(): | ||
| @engine.function_task | ||
| async def t1(): | ||
| await asyncio.sleep(1) | ||
| return "t1" | ||
|
|
||
| @engine.function_task | ||
| async def t2(dep): | ||
| await asyncio.sleep(1) | ||
| return "t2" | ||
|
|
||
| @engine.function_task | ||
| async def t3(dep1, dep2): | ||
| await asyncio.sleep(1) | ||
| return "t3" | ||
|
|
||
| f1 = t1() | ||
| f2 = t2(f1) | ||
| f3 = t3(f1, f2) | ||
| await f3 | ||
|
|
||
| block_fut = my_block() | ||
| await asyncio.sleep(0.4) | ||
|
|
||
| block_fut.cancel() | ||
| await asyncio.sleep(0.3) | ||
|
|
||
| assert block_fut.cancelled() | ||
| task_comps = [c for c in engine.components.values() if c["type"] == "task"] | ||
| assert len(task_comps) == 3 | ||
| assert all(c["future"].cancelled() for c in task_comps) | ||
|
|
||
| await engine.shutdown() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_nested_block_cancellation_propagates(): | ||
| """Cancelling an outer block propagates recursively to inner blocks and their tasks.""" | ||
| engine = await _make_local_engine() | ||
|
|
||
| @engine.block | ||
| async def outer_block(): | ||
| @engine.block | ||
| async def inner_block(): | ||
| @engine.function_task | ||
| async def deep_task(): | ||
| await asyncio.sleep(1) | ||
| return "never" | ||
|
|
||
| f = deep_task() | ||
| await f | ||
|
|
||
| ib = inner_block() | ||
| await ib | ||
|
|
||
| outer_fut = outer_block() | ||
| await asyncio.sleep(0.5) | ||
|
|
||
| outer_fut.cancel() | ||
| await asyncio.sleep(0.3) | ||
|
|
||
| assert outer_fut.cancelled() | ||
| deep_comp = next(c for c in engine.components.values() if c["type"] == "task") | ||
| assert deep_comp["future"].cancelled() | ||
|
|
||
| await engine.shutdown() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In standard CPython,
asyncio.Futureis implemented in C and does not have a__dict__(or uses__slots__in its pure Python fallback). Attempting to set arbitrary attributes likestate,id, orcanceldirectly on anasyncio.Futureinstance will raise anAttributeErroron standard runtimes.To make this robust and fully compatible with standard Python runtimes, we should subclass
asyncio.Future(e.g.,class FlowFuture(asyncio.Future): pass) and instantiate that instead. Subclasses automatically get a__dict__in Python, allowing arbitrary attribute assignment.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not releated