Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 69 additions & 17 deletions src/radical/asyncflow/workflow_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def __init__(
self._dependency_count = {}
self._component_change_event = asyncio.Event()

# Block task registry: uid -> asyncio.Task for running execute_block coroutines
self._block_asyncio_tasks: dict[str, asyncio.Task] = {}
# Block member registry: block_uid -> set of component UIDs registered within it
self._block_members: dict[str, set[str]] = {}

self.task_states_map = self.backend.get_task_states_map()

# Define decorators
Expand Down Expand Up @@ -606,6 +611,10 @@ def _handle_flow_component_registration(
def wrapper(*args, **kwargs):
# Create async future - we only support async
comp_fut = asyncio.Future()
comp_fut.state = "PENDING"
Comment on lines 612 to +614

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In standard CPython, asyncio.Future is implemented in C and does not have a __dict__ (or uses __slots__ in its pure Python fallback). Attempting to set arbitrary attributes like state, id, or cancel directly on an asyncio.Future instance will raise an AttributeError on standard runtimes.

To make this robust and fully compatible with standard Python runtimes, we should subclass asyncio.Future (e.g., class FlowFuture(asyncio.Future): pass) and instantiate that instead. Subclasses automatically get a __dict__ in Python, allowing arbitrary attribute assignment.

Suggested change
# Create async future - we only support async
comp_fut = asyncio.Future()
comp_fut.state = "PENDING"
# Create async future - we only support async
class FlowFuture(asyncio.Future):
pass
comp_fut = FlowFuture()
comp_fut.state = 'PENDING'

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related


# Extract call-time workflow_id before storing kwargs or calling the function
explicit_workflow_id = kwargs.pop("workflow_id", None)

comp_desc = {
"args": args,
Expand All @@ -615,6 +624,7 @@ def wrapper(*args, **kwargs):
"task_backend_specific_kwargs": task_backend_specific_kwargs or {},
"target_backend": target_backend,
"capture_stdio": capture_stdio,
"_explicit_workflow_id": explicit_workflow_id,
}

# Only handle async functions
Expand Down Expand Up @@ -703,7 +713,8 @@ def _register_component(
# make sure not to specify both func and executable at the same time
comp_desc["name"] = comp_desc["function"].__name__
comp_desc["uid"] = self._assign_uid(prefix=comp_type)
comp_desc["workflow_id"] = self._workflow_id_ctx.get()
# call-time workflow_id takes precedence over the ContextVar
comp_desc["workflow_id"] = comp_desc.pop("_explicit_workflow_id", None) or self._workflow_id_ctx.get()

if task_type == EXECUTABLE:
comp_desc[FUNCTION] = None # Clear function since we're using executable
Expand Down Expand Up @@ -760,8 +771,25 @@ def _register_component(
self._update_dependency_tracking(comp_desc["uid"])
self._component_change_event.set()

# Setup cancel hook
comp_fut.cancel = self._setup_future_cancel_hook(comp_fut, comp_desc["uid"])
# Track block membership: if this component is registered from within a block's
# execution context, record it so it gets cancelled when the block is cancelled.
# Read ContextVar directly — not comp_desc["workflow_id"] which may be overridden
# by an explicit call-time kwarg (a telemetry label, not a membership signal).
parent_block_uid = self._workflow_id_ctx.get()
if (
parent_block_uid
and parent_block_uid in self.components
and self.components[parent_block_uid]["type"] == BLOCK
):
self._block_members.setdefault(parent_block_uid, set()).add(comp_desc["uid"])
comp_fut.add_done_callback(
lambda _, buid=parent_block_uid, tuid=comp_desc["uid"]: self._remove_member(buid, tuid)
)

# Patch cancel for tasks only; blocks are cancelled via a done_callback
# installed in _submit_blocks that cancels the underlying asyncio.Task directly.
if comp_type != BLOCK:
comp_fut.cancel = self._setup_future_cancel_hook(comp_fut, comp_desc["uid"])

self._emit(
"TaskCreated",
Expand Down Expand Up @@ -827,7 +855,7 @@ def patched_cancel(*args, **kwargs):
else:
# Task is pending -> cancel locally
logger.info(f"Cancellation requested for {uid} (pending) locally")
return fut.original_cancel
return fut.original_cancel()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There are two issues here:

  1. fut.original_cancel is called without any arguments. In modern Python asyncio, cancel() can accept an optional msg argument (e.g., fut.cancel("reason")). We should forward *args and **kwargs to original_cancel to avoid dropping these arguments.
  2. When a pending task is cancelled locally, its state attribute is never updated to "CANCELLED". We should set fut.state = "CANCELLED" if the cancellation is successful.
Suggested change
else:
# Task is pending -> cancel locally
logger.info(f"Cancellation requested for {uid} (pending) locally")
return fut.original_cancel
return fut.original_cancel()
else:
# Task is pending -> cancel locally
logger.info(f'Cancellation requested for {uid} (pending) locally')
cancelled = fut.original_cancel(*args, **kwargs)
if cancelled:
fut.state = 'CANCELLED'
return cancelled

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addressed


return patched_cancel

Expand Down Expand Up @@ -926,6 +954,7 @@ def _clear_internal_records(self):
self._ready_queue.clear()
self._dependents_map.clear()
self._dependency_count.clear()
self._block_members.clear()
Comment on lines 963 to +966

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _block_asyncio_tasks registry is not cleared in _clear_internal_records. To prevent reference leaks and ensure a clean state when resetting the engine, we should clear it here.

Suggested change
self._ready_queue.clear()
self._dependents_map.clear()
self._dependency_count.clear()
self._block_members.clear()
self._ready_queue.clear()
self._dependents_map.clear()
self._dependency_count.clear()
self._block_members.clear()
self._block_asyncio_tasks.clear()

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addressed


reset_uid_counter()

Expand Down Expand Up @@ -954,6 +983,14 @@ def _notify_dependents(self, comp_uid: str):
if comp_uid in self._dependency_count:
del self._dependency_count[comp_uid]

def _remove_member(self, block_uid: str, task_uid: str) -> None:
    """Forget ``task_uid`` as a member of the block ``block_uid``.

    Args:
        block_uid: Key into ``self._block_members`` identifying the block.
        task_uid: UID to drop from that block's member set.

    Unknown blocks and unknown task UIDs are ignored. When the last member
    is removed, the block's entry is deleted from ``self._block_members``.
    """
    remaining = self._block_members.get(block_uid)
    if remaining is None:
        # Block has no tracked members (never had any, or already cleaned up).
        return

    remaining.discard(task_uid)
    if remaining:
        return

    # Last member is gone: drop the now-empty set so the registry does not grow.
    del self._block_members[block_uid]

def _create_dependency_failure_exception(self, comp_desc: dict, failed_deps: list):
"""Create a DependencyFailureError exception that shows both the immediate
failure and the root cause from failed dependencies.
Expand Down Expand Up @@ -1355,15 +1392,28 @@ async def _submit_blocks(self, blocks: list):
- Relies on `execute_block` to handle the actual function call and future
"""
for block in blocks:
args = block["args"]
kwargs = block["kwargs"]
func = block["function"]
block_fut = self.components[block["uid"]]["future"]

# Execute the block function as a coroutine
asyncio.create_task(
self.execute_block(block_fut, func, *args, **kwargs), name=block["uid"]
block_uid = block["uid"]
block_fut = self.components[block_uid]["future"]
t = asyncio.create_task(
self.execute_block(block_fut, block["function"], *block["args"], **block["kwargs"]),
name=block_uid,
)
self._block_asyncio_tasks[block_uid] = t
# Remove from registry when the asyncio.Task finishes (any outcome)
t.add_done_callback(lambda _, uid=block_uid: self._block_asyncio_tasks.pop(uid, None))
# Wire cancellation: if block_fut is cancelled externally after submission,
# propagate to the asyncio.Task.
def _on_block_fut_done(f, task=t, buid=block_uid):
if f.cancelled():
task.cancel()
f.state = "CANCELLED"
members = self._block_members.pop(buid, None)
if members:
for member_uid in members:
comp = self.components.get(member_uid)
if comp and not comp["future"].done():
comp["future"].cancel()
block_fut.add_done_callback(_on_block_fut_done)
Comment on lines +1418 to +1437

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There are two issues with block state tracking and cleanup:

  1. The block future state is never updated to "RUNNING", "DONE", or "FAILED" during its execution lifecycle (it remains "PENDING" forever unless cancelled). We should update the state when the block starts running and when it completes or fails.
  2. If a block completes normally or fails, its entry in self._block_members is never cleaned up, which leads to a memory leak. We should pop the block from self._block_members when the block future is done.
            self._block_asyncio_tasks[block_uid] = t
            block_fut.state = 'RUNNING'
            # Remove from registry when the asyncio.Task finishes (any outcome)
            t.add_done_callback(lambda _, uid=block_uid: self._block_asyncio_tasks.pop(uid, None))
            # Wire cancellation and state transitions: if block_fut is done,
            # propagate cancellation and update state.
            def _on_block_fut_done(f, task=t, buid=block_uid):
                if f.cancelled():
                    task.cancel()
                    f.state = 'CANCELLED'
                    members = self._block_members.pop(buid, None)
                    if members:
                        for member_uid in members:
                            comp = self.components.get(member_uid)
                            if comp and not comp['future'].done():
                                comp['future'].cancel()
                elif f.exception() is not None:
                    f.state = 'FAILED'
                    self._block_members.pop(buid, None)
                else:
                    f.state = 'DONE'
                    self._block_members.pop(buid, None)
            block_fut.add_done_callback(_on_block_fut_done)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addressed


async def execute_block(
self, block_fut: asyncio.Future, func: Callable, *args: Any, **kwargs: Any
Expand Down Expand Up @@ -1423,6 +1473,7 @@ def handle_task_success(self, task: dict, task_fut: asyncio.Future) -> None:
task_fut.set_result(task["return_value"])
else:
task_fut.set_result(task["stdout"])
task_fut.state = "DONE"
else:
logger.warning(
f'Attempted to handle an already finished task "{task["uid"]}"'
Expand Down Expand Up @@ -1482,18 +1533,18 @@ def handle_task_failure(
exception = RuntimeError(str(original_error))

task_fut.set_exception(exception)
task_fut.state = "FAILED"

def handle_task_cancellation(self, task: dict, task_fut: asyncio.Future):
    """Cancel ``task_fut`` locally and record the ``"CANCELLED"`` state on it.

    Args:
        task: Task description dict; not read here, kept for the callback
            signature.
        task_fut: Future representing the task. It is expected to carry an
            ``original_cancel`` attribute holding the unpatched ``cancel``.

    Returns:
        The result of the future's original ``cancel()`` call, or ``None``
        when the future was already done (the call is then a no-op).
    """
    if task_fut.done():
        return  # already resolved — idempotent, nothing to do

    # Restore the original cancel method so this call cancels the future
    # itself instead of going through the patched cancel hook again.
    task_fut.cancel = task_fut.original_cancel
    result = task_fut.cancel()
    task_fut.state = "CANCELLED"
    return result

@typeguard.typechecked
def task_callbacks(
Expand Down Expand Up @@ -1596,6 +1647,7 @@ def wait_and_set():
# implicit: when a coroutine that awaits the future
# is scheduled and started by the event loop, that's
# when the "work" is running.
task_fut.state = "RUNNING"
if self._telemetry is not None:
now = time.time()
self._task_start_times[uid] = now
Expand Down
199 changes: 199 additions & 0 deletions tests/unit/test_block_execution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
"""Tests for block submission, task registry, and member cancellation propagation."""

import asyncio
from concurrent.futures import ThreadPoolExecutor

import pytest

from radical.asyncflow import NoopExecutionBackend, WorkflowEngine
from radical.asyncflow.backends import LocalExecutionBackend


async def _make_engine():
    """Create a ``WorkflowEngine`` running on a ``NoopExecutionBackend``."""
    noop_backend = NoopExecutionBackend()
    engine = await WorkflowEngine.create(backend=noop_backend)
    return engine


async def _make_local_engine():
    """Create a ``WorkflowEngine`` on a thread-pool ``LocalExecutionBackend``.

    The backend is given a ``ThreadPoolExecutor`` with eight workers.
    """
    pool = ThreadPoolExecutor(max_workers=8)
    local_backend = await LocalExecutionBackend(pool)
    return await WorkflowEngine.create(backend=local_backend)


# ---------------------------------------------------------------------------
# _block_asyncio_tasks registry
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_block_asyncio_tasks_registry_populated():
    """_block_asyncio_tasks must hold an entry for a block while it is running.

    The block parks on ``block_release`` so the registry can be inspected
    mid-execution; after release the entry must disappear again.
    """
    engine = await _make_engine()
    block_started = asyncio.Event()
    block_release = asyncio.Event()

    @engine.block
    async def my_block():
        block_started.set()
        await block_release.wait()

    try:
        my_block()
        # Bounded wait: if the block is never scheduled, fail fast instead of
        # hanging the whole test session on an event that is never set.
        await asyncio.wait_for(block_started.wait(), timeout=5)

        block_uid = next(iter(engine.components))
        assert block_uid in engine._block_asyncio_tasks

        block_release.set()
        # Give the finished asyncio.Task a moment to run its done callbacks.
        await asyncio.sleep(0.05)

        assert block_uid not in engine._block_asyncio_tasks
    finally:
        # Always unblock the block and shut the engine down, even when an
        # assertion above failed, so nothing leaks into later tests.
        block_release.set()
        await engine.shutdown()


@pytest.mark.asyncio
async def test_block_asyncio_tasks_registry_cleared_on_completion():
    """Registry must be empty after a block completes normally."""
    engine = await _make_engine()

    @engine.block
    async def quick_block():
        return 42

    try:
        quick_block()
        # Let the block run to completion and its done callbacks fire.
        await asyncio.sleep(0.1)

        assert not engine._block_asyncio_tasks
    finally:
        # Shut down even when the assertion fails so the engine does not leak.
        await engine.shutdown()


# ---------------------------------------------------------------------------
# _block_members — member registration and cleanup
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_block_members_cleared_on_normal_completion():
    """_block_members must be empty after a block and its tasks complete normally."""
    engine = await _make_engine()

    @engine.block
    async def normal_block():
        @engine.function_task
        async def inner():
            return "ok"

        await inner()

    try:
        block_fut = normal_block()
        await asyncio.sleep(0.2)

        # Separate asserts so a failure names the condition that broke.
        assert block_fut.done()
        assert not block_fut.cancelled()
        assert not engine._block_members
    finally:
        # Shut down even when an assertion fails so the engine does not leak.
        await engine.shutdown()


@pytest.mark.asyncio
async def test_completed_tasks_removed_from_block_members():
    """Tasks that finish before block cancellation are cleaned up by their done_callback."""
    engine = await _make_engine()
    block_gate = asyncio.Event()

    @engine.block
    async def my_block():
        @engine.function_task
        async def quick_task():
            return "done"

        await quick_task()  # NoopBackend resolves immediately
        await block_gate.wait()  # keep block alive for inspection

    try:
        my_block()
        await asyncio.sleep(0.15)

        # Use a default: a bare next() raising StopIteration inside a coroutine
        # surfaces as an opaque RuntimeError instead of a clear test failure.
        block_uid = next(
            (uid for uid, comp in engine.components.items() if comp["type"] == "block"),
            None,
        )
        assert block_uid is not None, "block component was never registered"
        assert not engine._block_members.get(block_uid)
    finally:
        # Release the parked block and shut down regardless of the outcome,
        # so a failed assertion does not leave the block or engine running.
        block_gate.set()
        await asyncio.sleep(0.05)
        await engine.shutdown()


# ---------------------------------------------------------------------------
# Cancellation propagation to block members
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_block_cancellation_propagates_to_member_tasks():
    """All tasks registered inside a block before cancellation must be cancelled."""
    engine = await _make_local_engine()

    @engine.block
    async def my_block():
        @engine.function_task
        async def t1():
            await asyncio.sleep(1)
            return "t1"

        @engine.function_task
        async def t2(dep):
            await asyncio.sleep(1)
            return "t2"

        @engine.function_task
        async def t3(dep1, dep2):
            await asyncio.sleep(1)
            return "t3"

        f1 = t1()
        f2 = t2(f1)
        f3 = t3(f1, f2)
        await f3

    try:
        block_fut = my_block()
        # Let the block start and register its three tasks.
        await asyncio.sleep(0.4)

        block_fut.cancel()
        # Let the cancellation callbacks propagate to the member tasks.
        await asyncio.sleep(0.3)

        assert block_fut.cancelled()
        task_futs = {
            uid: comp["future"]
            for uid, comp in engine.components.items()
            if comp["type"] == "task"
        }
        assert len(task_futs) == 3
        # Name the survivors so a failure says which task escaped cancellation.
        survivors = sorted(uid for uid, fut in task_futs.items() if not fut.cancelled())
        assert not survivors, f"member tasks not cancelled: {survivors}"
    finally:
        # Shut down even on failure so the engine and its thread pool do not leak.
        await engine.shutdown()


@pytest.mark.asyncio
async def test_nested_block_cancellation_propagates():
    """Cancelling an outer block propagates recursively to inner blocks and their tasks."""
    engine = await _make_local_engine()

    @engine.block
    async def outer_block():
        @engine.block
        async def inner_block():
            @engine.function_task
            async def deep_task():
                await asyncio.sleep(1)
                return "never"

            f = deep_task()
            await f

        ib = inner_block()
        await ib

    try:
        outer_fut = outer_block()
        # Let the outer block, the inner block, and the deep task all register.
        await asyncio.sleep(0.5)

        outer_fut.cancel()
        # Let the cancellation propagate down the nesting.
        await asyncio.sleep(0.3)

        assert outer_fut.cancelled()
        # Use a default: a bare next() raising StopIteration inside a coroutine
        # surfaces as an opaque RuntimeError instead of a clear test failure.
        deep_comp = next(
            (comp for comp in engine.components.values() if comp["type"] == "task"),
            None,
        )
        assert deep_comp is not None, "deep task was never registered"
        assert deep_comp["future"].cancelled()
    finally:
        # Shut down even on failure so the engine and its thread pool do not leak.
        await engine.shutdown()
Loading
Loading