diff --git a/temporalio/contrib/workflow_streams/_client.py b/temporalio/contrib/workflow_streams/_client.py index e28437e69..472a61b13 100644 --- a/temporalio/contrib/workflow_streams/_client.py +++ b/temporalio/contrib/workflow_streams/_client.py @@ -35,6 +35,7 @@ WorkflowHandle, WorkflowUpdateFailedError, WorkflowUpdateRPCTimeoutOrCancelledError, + WorkflowUpdateStage, ) from temporalio.converter import DataConverter, PayloadConverter from temporalio.service import RPCError, RPCStatusCode @@ -127,6 +128,10 @@ def __init__( self._pending_seq: int = 0 self._pending_since: float | None = None self._topic_types: dict[str, type[Any]] = {} + # Run id the most recent poll's update was admitted to. Captured before + # waiting for the outcome so a mid-poll continue-as-new can be detected by + # describing that specific run. None until the first poll is admitted. + self._polled_run_id: str | None = None @classmethod def create( @@ -528,11 +533,18 @@ async def subscribe( offset = from_offset while True: try: - result: PollResult = await self._handle.execute_update( + # Wait only for ACCEPTED so the handle (and the run id it was + # admitted to) is available before we block on the outcome; if + # the run continues-as-new mid-poll, result() fails but we still + # know which run to inspect. + handle = await self._handle.start_update( "__temporal_workflow_stream_poll", PollInput(topics=topic_filter, from_offset=offset), + wait_for_stage=WorkflowUpdateStage.ACCEPTED, result_type=PollResult, ) + self._polled_run_id = handle.workflow_run_id + result: PollResult = await handle.result() except asyncio.CancelledError: return except WorkflowUpdateFailedError as e: @@ -544,6 +556,14 @@ async def subscribe( # base_offset). offset = 0 continue + if cause_type == "StreamDraining": + # Workflow is detaching for continue-as-new. Back off and + # retry; the poll lands on the successor run once the + # rollover completes. + cooldown_secs = poll_cooldown.total_seconds() + if cooldown_secs > 0: + await asyncio.sleep(cooldown_secs) + continue if cause_type == "AcceptedUpdateCompletedWorkflow": # Workflow returned (or continued-as-new) before # this poll's update completed. Either follow the @@ -586,15 +606,32 @@ async def subscribe( if not result.more_ready and cooldown_secs > 0: await asyncio.sleep(cooldown_secs) + async def _describe_polled_run(self): + """Describe the specific run the most recent poll was admitted to. + + Describing that run (rather than the latest) is what lets a + continue-as-new be detected: a rolled-over run is closed with status + CONTINUED_AS_NEW, whereas the latest run would report RUNNING. Falls + back to the latest run when no run id has been captured yet, or when no + client is available to target a specific run. + """ + if self._client is not None: + return await self._client.get_workflow_handle( + self._workflow_id, run_id=self._polled_run_id + ).describe() + return await self._handle.describe() + async def _follow_continue_as_new(self) -> bool: - """Check if the workflow continued-as-new and re-target the handle. + """Check if the polled run continued-as-new and re-target the handle. - Returns True if the handle was updated (caller should retry). + Returns True if the handle was updated (caller should retry). The + successor run id is not needed — re-targeting to an unpinned handle + makes the next poll address the latest (successor) run. """ if self._client is None: return False try: - desc = await self._handle.describe() + desc = await self._describe_polled_run() except Exception: return False if desc.status == WorkflowExecutionStatus.CONTINUED_AS_NEW: @@ -603,14 +640,14 @@ async def _follow_continue_as_new(self) -> bool: return False async def _workflow_in_terminal_state(self) -> bool: - """Return True if the workflow has reached a terminal state. + """Return True if the polled run has reached a terminal state. Used by ``subscribe()`` to distinguish "workflow finished — stream is done" from "wrong workflow id" when a poll RPC returns NOT_FOUND. """ try: - desc = await self._handle.describe() + desc = await self._describe_polled_run() except Exception: return False return desc.status in ( diff --git a/temporalio/contrib/workflow_streams/_stream.py b/temporalio/contrib/workflow_streams/_stream.py index 2753f04c2..2d78f1fb5 100644 --- a/temporalio/contrib/workflow_streams/_stream.py +++ b/temporalio/contrib/workflow_streams/_stream.py @@ -460,9 +460,18 @@ async def _on_poll(self, payload: PollInput) -> PollResult: ) def _validate_poll(self, _payload: PollInput) -> None: - """Reject new polls when pollers are detached for continue-as-new.""" + """Reject new polls when pollers are detached for continue-as-new. + + Uses the well-known ``StreamDraining`` type so a subscriber recognizes + the rollover-in-progress and retries until its poll lands on the + successor run, rather than surfacing the rejection as an error. + """ if self._detaching: - raise RuntimeError("Workflow pollers are detached for continue-as-new") + raise ApplicationError( + "Workflow pollers are detached for continue-as-new", + type="StreamDraining", + non_retryable=True, + ) def _on_offset(self) -> int: """Return the current global offset (base_offset + log length).""" diff --git a/tests/contrib/workflow_streams/test_workflow_streams.py b/tests/contrib/workflow_streams/test_workflow_streams.py index 7353cbdd5..198fa214d 100644 --- a/tests/contrib/workflow_streams/test_workflow_streams.py +++ b/tests/contrib/workflow_streams/test_workflow_streams.py @@ -28,6 +28,7 @@ from temporalio import activity, nexus, workflow from temporalio.client import ( Client, + WorkflowExecutionStatus, WorkflowHandle, WorkflowUpdateFailedError, WorkflowUpdateStage, @@ -2008,6 +2009,175 @@ async def test_continue_as_new_helper(client: Client) -> None: await new_handle.signal(ContinueAsNewHelperWorkflow.close) +@pytest.mark.asyncio +async def test_follow_continue_as_new_describes_polled_run(client: Client) -> None: + """Regression test for continue-as-new detection. + + ``_follow_continue_as_new`` must describe the *specific run the poll was + admitted to* — a rolled-over run is closed with status CONTINUED_AS_NEW, + whereas the latest (successor) run reports RUNNING. The previous + implementation described the latest run, so the check never fired and a poll + failure during a rollover stopped the subscription instead of following it. + + Driving the exact poll-failure race deterministically is impractical (the + workflow drains in-flight polls before continuing-as-new), so this asserts + the helper's decision directly against a real post-rollover run. + """ + async with new_worker(client, ContinueAsNewHelperWorkflow) as worker: + handle = await client.start_workflow( + ContinueAsNewHelperWorkflow.run, + CANWorkflowInputTyped(), + id=f"workflow-stream-can-follow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + await handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"item-0"))], + publisher_id="pub", + sequence=1, + ), + ) + old_run_id = handle.result_run_id + + await handle.signal(ContinueAsNewHelperWorkflow.trigger_continue) + new_handle = client.get_workflow_handle(handle.id) + await assert_eq_eventually(True, lambda: _is_different_run(handle, new_handle)) + + # The fix's premise: the polled (old) run reports CONTINUED_AS_NEW; the + # latest run reports RUNNING. + old_desc = await client.get_workflow_handle( + handle.id, run_id=old_run_id + ).describe() + assert old_desc.status == WorkflowExecutionStatus.CONTINUED_AS_NEW + latest_desc = await client.get_workflow_handle(handle.id).describe() + assert latest_desc.status == WorkflowExecutionStatus.RUNNING + + # The client follows the rollover when it describes the polled run, but + # the previous latest-run behavior (polled_run_id unset) would not. + following = WorkflowStreamClient.create(client, handle.id) + following._polled_run_id = old_run_id + assert await following._follow_continue_as_new() is True + + latest_only = WorkflowStreamClient.create(client, handle.id) + latest_only._polled_run_id = None # describes the latest run, as the bug did + assert await latest_only._follow_continue_as_new() is False + + await new_handle.signal(ContinueAsNewHelperWorkflow.close) + + +@workflow.defn +class DrainingGateWorkflow: + """CAN workflow that detaches pollers and then *holds* in the draining state + until released, so a subscriber deterministically hits the draining poll + rejection before the rollover completes.""" + + @workflow.init + def __init__(self, input: CANWorkflowInputTyped) -> None: + self.stream = WorkflowStream(prior_state=input.stream_state) + self._should_continue = False + self._release = False + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.signal + def trigger_continue(self) -> None: + self._should_continue = True + + @workflow.signal + def release(self) -> None: + self._release = True + + @workflow.run + async def run(self, _input: CANWorkflowInputTyped) -> None: + del _input + await workflow.wait_condition(lambda: self._should_continue or self._closed) + if self._closed: + return + # Detach but stay open until released, so new polls are rejected with + # StreamDraining for a deterministic window. + self.stream.detach_pollers() + await workflow.wait_condition(lambda: self._release) + await workflow.wait_condition(workflow.all_handlers_finished) + workflow.continue_as_new( + args=[CANWorkflowInputTyped(stream_state=self.stream.get_state())] + ) + + +@pytest.mark.asyncio +async def test_subscribe_retries_while_draining(client: Client) -> None: + """A poll rejected because the stream is draining for continue-as-new must + be retried, not surfaced as an error: the subscription stays alive through + the rollover and resumes on the successor run.""" + async with new_worker(client, DrainingGateWorkflow) as worker: + handle = await client.start_workflow( + DrainingGateWorkflow.run, + CANWorkflowInputTyped(), + id=f"workflow-stream-draining-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + await handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"item-0"))], + publisher_id="pub", + sequence=1, + ), + ) + + stream = WorkflowStreamClient.create(client, handle.id) + received: list[WorkflowStreamItem] = [] + + async def consume() -> None: + async for item in stream.subscribe( + from_offset=0, poll_cooldown=timedelta(0), result_type=bytes + ): + received.append(item) + + async def received_count() -> int: + return len(received) + + task = asyncio.create_task(consume()) + new_handle = client.get_workflow_handle(handle.id) + try: + await assert_eq_eventually(1, received_count) + + # Detach; the subscriber's polls are now rejected with StreamDraining. + await handle.signal(DrainingGateWorkflow.trigger_continue) + # The subscription must keep retrying, not error out. + await asyncio.sleep(1.0) + assert not task.done(), "draining rejection must not end the subscription" + + # Release: the workflow continues-as-new; the subscription resumes on + # the successor run and receives an item published there. + await handle.signal(DrainingGateWorkflow.release) + await assert_eq_eventually( + True, lambda: _is_different_run(handle, new_handle) + ) + await new_handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"item-1"))], + publisher_id="pub", + sequence=2, + ), + ) + + await assert_eq_eventually(2, received_count) + assert [i.data for i in received] == [b"item-0", b"item-1"] + assert [i.offset for i in received] == [0, 1] + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + await new_handle.signal(DrainingGateWorkflow.close) + + # --------------------------------------------------------------------------- # Cross-workflow workflow stream (Scenario 1) # ---------------------------------------------------------------------------