Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions temporalio/contrib/workflow_streams/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
WorkflowHandle,
WorkflowUpdateFailedError,
WorkflowUpdateRPCTimeoutOrCancelledError,
WorkflowUpdateStage,
)
from temporalio.converter import DataConverter, PayloadConverter
from temporalio.service import RPCError, RPCStatusCode
Expand Down Expand Up @@ -127,6 +128,10 @@ def __init__(
self._pending_seq: int = 0
self._pending_since: float | None = None
self._topic_types: dict[str, type[Any]] = {}
# Run id the most recent poll's update was admitted to. Captured before
# waiting for the outcome so a mid-poll continue-as-new can be detected by
# describing that specific run. None until the first poll is admitted.
self._polled_run_id: str | None = None

@classmethod
def create(
Expand Down Expand Up @@ -528,11 +533,18 @@ async def subscribe(
offset = from_offset
while True:
try:
result: PollResult = await self._handle.execute_update(
# Wait only for ACCEPTED so the handle (and the run id it was
# admitted to) is available before we block on the outcome; if
# the run continues-as-new mid-poll, result() fails but we still
# know which run to inspect.
handle = await self._handle.start_update(
"__temporal_workflow_stream_poll",
PollInput(topics=topic_filter, from_offset=offset),
wait_for_stage=WorkflowUpdateStage.ACCEPTED,
result_type=PollResult,
)
self._polled_run_id = handle.workflow_run_id
result: PollResult = await handle.result()
except asyncio.CancelledError:
return
except WorkflowUpdateFailedError as e:
Expand All @@ -544,6 +556,14 @@ async def subscribe(
# base_offset).
offset = 0
continue
if cause_type == "StreamDraining":
# Workflow is detaching for continue-as-new. Back off and
# retry; the poll lands on the successor run once the
# rollover completes.
cooldown_secs = poll_cooldown.total_seconds()
if cooldown_secs > 0:
await asyncio.sleep(cooldown_secs)
continue
if cause_type == "AcceptedUpdateCompletedWorkflow":
# Workflow returned (or continued-as-new) before
# this poll's update completed. Either follow the
Expand Down Expand Up @@ -586,15 +606,32 @@ async def subscribe(
if not result.more_ready and cooldown_secs > 0:
await asyncio.sleep(cooldown_secs)

async def _describe_polled_run(self):
"""Describe the specific run the most recent poll was admitted to.

Describing that run (rather than the latest) is what lets a
continue-as-new be detected: a rolled-over run is closed with status
CONTINUED_AS_NEW, whereas the latest run would report RUNNING. Falls
back to the latest run when no run id has been captured yet, or when no
client is available to target a specific run.
"""
if self._client is not None:
return await self._client.get_workflow_handle(
self._workflow_id, run_id=self._polled_run_id
).describe()
return await self._handle.describe()

async def _follow_continue_as_new(self) -> bool:
"""Check if the workflow continued-as-new and re-target the handle.
"""Check if the polled run continued-as-new and re-target the handle.

Returns True if the handle was updated (caller should retry).
Returns True if the handle was updated (caller should retry). The
successor run id is not needed — re-targeting to an unpinned handle
makes the next poll address the latest (successor) run.
"""
if self._client is None:
return False
try:
desc = await self._handle.describe()
desc = await self._describe_polled_run()
except Exception:
return False
if desc.status == WorkflowExecutionStatus.CONTINUED_AS_NEW:
Expand All @@ -603,14 +640,14 @@ async def _follow_continue_as_new(self) -> bool:
return False

async def _workflow_in_terminal_state(self) -> bool:
"""Return True if the workflow has reached a terminal state.
"""Return True if the polled run has reached a terminal state.

Used by ``subscribe()`` to distinguish "workflow finished —
stream is done" from "wrong workflow id" when a poll RPC
returns NOT_FOUND.
"""
try:
desc = await self._handle.describe()
desc = await self._describe_polled_run()
except Exception:
return False
return desc.status in (
Expand Down
13 changes: 11 additions & 2 deletions temporalio/contrib/workflow_streams/_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,18 @@ async def _on_poll(self, payload: PollInput) -> PollResult:
)

def _validate_poll(self, _payload: PollInput) -> None:
"""Reject new polls when pollers are detached for continue-as-new."""
"""Reject new polls when pollers are detached for continue-as-new.

Uses the well-known ``StreamDraining`` type so a subscriber recognizes
the rollover-in-progress and retries until its poll lands on the
successor run, rather than surfacing the rejection as an error.
"""
if self._detaching:
raise RuntimeError("Workflow pollers are detached for continue-as-new")
raise ApplicationError(
"Workflow pollers are detached for continue-as-new",
type="StreamDraining",
non_retryable=True,
)

def _on_offset(self) -> int:
"""Return the current global offset (base_offset + log length)."""
Expand Down
170 changes: 170 additions & 0 deletions tests/contrib/workflow_streams/test_workflow_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from temporalio import activity, nexus, workflow
from temporalio.client import (
Client,
WorkflowExecutionStatus,
WorkflowHandle,
WorkflowUpdateFailedError,
WorkflowUpdateStage,
Expand Down Expand Up @@ -2008,6 +2009,175 @@ async def test_continue_as_new_helper(client: Client) -> None:
await new_handle.signal(ContinueAsNewHelperWorkflow.close)


@pytest.mark.asyncio
async def test_follow_continue_as_new_describes_polled_run(client: Client) -> None:
"""Regression test for continue-as-new detection.

``_follow_continue_as_new`` must describe the *specific run the poll was
admitted to* — a rolled-over run is closed with status CONTINUED_AS_NEW,
whereas the latest (successor) run reports RUNNING. The previous
implementation described the latest run, so the check never fired and a poll
failure during a rollover stopped the subscription instead of following it.

Driving the exact poll-failure race deterministically is impractical (the
workflow drains in-flight polls before continuing-as-new), so this asserts
the helper's decision directly against a real post-rollover run.
"""
async with new_worker(client, ContinueAsNewHelperWorkflow) as worker:
handle = await client.start_workflow(
ContinueAsNewHelperWorkflow.run,
CANWorkflowInputTyped(),
id=f"workflow-stream-can-follow-{uuid.uuid4()}",
task_queue=worker.task_queue,
)
await handle.signal(
"__temporal_workflow_stream_publish",
PublishInput(
items=[PublishEntry(topic="events", data=_wire_bytes(b"item-0"))],
publisher_id="pub",
sequence=1,
),
)
old_run_id = handle.result_run_id

await handle.signal(ContinueAsNewHelperWorkflow.trigger_continue)
new_handle = client.get_workflow_handle(handle.id)
await assert_eq_eventually(True, lambda: _is_different_run(handle, new_handle))

# The fix's premise: the polled (old) run reports CONTINUED_AS_NEW; the
# latest run reports RUNNING.
old_desc = await client.get_workflow_handle(
handle.id, run_id=old_run_id
).describe()
assert old_desc.status == WorkflowExecutionStatus.CONTINUED_AS_NEW
latest_desc = await client.get_workflow_handle(handle.id).describe()
assert latest_desc.status == WorkflowExecutionStatus.RUNNING

# The client follows the rollover when it describes the polled run, but
# the previous latest-run behavior (polled_run_id unset) would not.
following = WorkflowStreamClient.create(client, handle.id)
following._polled_run_id = old_run_id
assert await following._follow_continue_as_new() is True

latest_only = WorkflowStreamClient.create(client, handle.id)
latest_only._polled_run_id = None # describes the latest run, as the bug did
assert await latest_only._follow_continue_as_new() is False

await new_handle.signal(ContinueAsNewHelperWorkflow.close)


@workflow.defn
class DrainingGateWorkflow:
"""CAN workflow that detaches pollers and then *holds* in the draining state
until released, so a subscriber deterministically hits the draining poll
rejection before the rollover completes."""

@workflow.init
def __init__(self, input: CANWorkflowInputTyped) -> None:
self.stream = WorkflowStream(prior_state=input.stream_state)
self._should_continue = False
self._release = False
self._closed = False

@workflow.signal
def close(self) -> None:
self._closed = True

@workflow.signal
def trigger_continue(self) -> None:
self._should_continue = True

@workflow.signal
def release(self) -> None:
self._release = True

@workflow.run
async def run(self, _input: CANWorkflowInputTyped) -> None:
del _input
await workflow.wait_condition(lambda: self._should_continue or self._closed)
if self._closed:
return
# Detach but stay open until released, so new polls are rejected with
# StreamDraining for a deterministic window.
self.stream.detach_pollers()
await workflow.wait_condition(lambda: self._release)
await workflow.wait_condition(workflow.all_handlers_finished)
workflow.continue_as_new(
args=[CANWorkflowInputTyped(stream_state=self.stream.get_state())]
)


@pytest.mark.asyncio
async def test_subscribe_retries_while_draining(client: Client) -> None:
"""A poll rejected because the stream is draining for continue-as-new must
be retried, not surfaced as an error: the subscription stays alive through
the rollover and resumes on the successor run."""
async with new_worker(client, DrainingGateWorkflow) as worker:
handle = await client.start_workflow(
DrainingGateWorkflow.run,
CANWorkflowInputTyped(),
id=f"workflow-stream-draining-{uuid.uuid4()}",
task_queue=worker.task_queue,
)
await handle.signal(
"__temporal_workflow_stream_publish",
PublishInput(
items=[PublishEntry(topic="events", data=_wire_bytes(b"item-0"))],
publisher_id="pub",
sequence=1,
),
)

stream = WorkflowStreamClient.create(client, handle.id)
received: list[WorkflowStreamItem] = []

async def consume() -> None:
async for item in stream.subscribe(
from_offset=0, poll_cooldown=timedelta(0), result_type=bytes
):
received.append(item)

async def received_count() -> int:
return len(received)

task = asyncio.create_task(consume())
new_handle = client.get_workflow_handle(handle.id)
try:
await assert_eq_eventually(1, received_count)

# Detach; the subscriber's polls are now rejected with StreamDraining.
await handle.signal(DrainingGateWorkflow.trigger_continue)
# The subscription must keep retrying, not error out.
await asyncio.sleep(1.0)
assert not task.done(), "draining rejection must not end the subscription"

# Release: the workflow continues-as-new; the subscription resumes on
# the successor run and receives an item published there.
await handle.signal(DrainingGateWorkflow.release)
await assert_eq_eventually(
True, lambda: _is_different_run(handle, new_handle)
)
await new_handle.signal(
"__temporal_workflow_stream_publish",
PublishInput(
items=[PublishEntry(topic="events", data=_wire_bytes(b"item-1"))],
publisher_id="pub",
sequence=2,
),
)

await assert_eq_eventually(2, received_count)
assert [i.data for i in received] == [b"item-0", b"item-1"]
assert [i.offset for i in received] == [0, 1]
finally:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
await new_handle.signal(DrainingGateWorkflow.close)


# ---------------------------------------------------------------------------
# Cross-workflow workflow stream (Scenario 1)
# ---------------------------------------------------------------------------
Expand Down
Loading