Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions temporalio/client/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,8 @@ def on_start_error(
input = StartWorkflowUpdateWithStartInput(
start_workflow_input=start_workflow_operation._start_workflow_input,
update_workflow_input=update_input,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
_on_start=on_start,
_on_start_error=on_start_error,
)
Expand Down
15 changes: 13 additions & 2 deletions temporalio/client/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,11 @@ def on_start(

try:
return await self._start_workflow_update_with_start(
input.start_workflow_input, input.update_workflow_input, on_start
input.start_workflow_input,
input.update_workflow_input,
input.rpc_metadata,
input.rpc_timeout,
on_start,
)
except asyncio.CancelledError as _err:
err = _err
Expand Down Expand Up @@ -914,6 +918,8 @@ async def _start_workflow_update_with_start(
self,
start_input: UpdateWithStartStartWorkflowInput,
update_input: UpdateWithStartUpdateWorkflowInput,
rpc_metadata: Mapping[str, str | bytes],
rpc_timeout: timedelta | None,
on_start: Callable[
[temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None
],
Expand Down Expand Up @@ -941,7 +947,12 @@ async def _start_workflow_update_with_start(
# Repeatedly try to invoke ExecuteMultiOperation until the update is durable
while True:
multiop_response = (
await self._client.workflow_service.execute_multi_operation(multiop_req)
await self._client.workflow_service.execute_multi_operation(
multiop_req,
retry=True,
metadata=rpc_metadata,
timeout=rpc_timeout,
)
)
start_response = multiop_response.responses[0].start_workflow
update_response = multiop_response.responses[1].update_workflow
Expand Down
12 changes: 11 additions & 1 deletion temporalio/client/_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,20 @@ class UpdateWithStartStartWorkflowInput:

@dataclass
class StartWorkflowUpdateWithStartInput:
"""Input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`."""
"""Input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`.

The top-level ``rpc_metadata`` and ``rpc_timeout`` fields are authoritative
for the ``execute_multi_operation`` gRPC call. The sub-inputs
(``start_workflow_input`` and ``update_workflow_input``) also carry their own
``rpc_metadata`` / ``rpc_timeout`` for interceptor introspection, but those
values are **not** forwarded to the gRPC call. Interceptors that wish to set
RPC metadata should modify :py:attr:`rpc_metadata` on this object.
"""

start_workflow_input: UpdateWithStartStartWorkflowInput
update_workflow_input: UpdateWithStartUpdateWorkflowInput
rpc_metadata: Mapping[str, str | bytes]
rpc_timeout: timedelta | None
_on_start: Callable[
[temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None
]
Expand Down
81 changes: 81 additions & 0 deletions tests/worker/test_update_with_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,3 +1104,84 @@ async def _do_update() -> Any:
elif id_reuse_policy == WorkflowIDReusePolicy.REJECT_DUPLICATE:
with pytest.raises(WorkflowAlreadyStartedError):
await _do_update()


class MetadataCapturingInterceptor(Interceptor):
"""Interceptor that sets rpc_metadata on update-with-start calls."""

def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor:
return MetadataCapturingOutboundInterceptor(super().intercept_client(next))


class MetadataCapturingOutboundInterceptor(OutboundInterceptor):
def __init__(self, next: OutboundInterceptor) -> None:
super().__init__(next)

async def start_update_with_start_workflow(
self, input: StartWorkflowUpdateWithStartInput
) -> WorkflowUpdateHandle[Any]:
input.rpc_metadata = {
**input.rpc_metadata,
"test-header-key": "test-header-value",
}
return await super().start_update_with_start_workflow(input)


# Verify fix for https://github.com/temporalio/sdk-python/issues/1582
async def test_update_with_start_rpc_metadata_and_timeout_forwarded(client: Client):
"""Test that rpc_metadata and rpc_timeout on StartWorkflowUpdateWithStartInput
are forwarded to the execute_multi_operation gRPC call."""
captured_metadata: dict[str, str | bytes] = {}
captured_timeout: list[timedelta | None] = []

class execute_multi_operation:
err = RPCError("intentional", RPCStatusCode.INTERNAL, b"")
err._grpc_status = temporalio.api.common.v1.GrpcStatus(details=[])

def __init__(self) -> None: # type: ignore[reportMissingSuperCall]
pass

async def __call__(
self,
req: temporalio.api.workflowservice.v1.ExecuteMultiOperationRequest,
*,
retry: bool = False,
metadata: Mapping[str, str | bytes] = {},
timeout: timedelta | None = None,
) -> temporalio.api.workflowservice.v1.ExecuteMultiOperationResponse:
captured_metadata.update(metadata)
captured_timeout.append(timeout)
raise self.err

interceptor = MetadataCapturingInterceptor()
intercepted_client = Client(
**{**client.config(), "interceptors": [interceptor]} # type: ignore
)

with patch.object(
intercepted_client.workflow_service,
"execute_multi_operation",
execute_multi_operation(),
):
start_workflow_operation = WithStartWorkflowOperation(
UpdateWithStartInterceptorWorkflow.run,
"wf-arg",
id=f"wf-{uuid.uuid4()}",
task_queue="tq",
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)
with pytest.raises(RPCError):
await intercepted_client.start_update_with_start_workflow(
UpdateWithStartInterceptorWorkflow.my_update,
"update-arg",
start_workflow_operation=start_workflow_operation,
wait_for_stage=WorkflowUpdateStage.ACCEPTED,
rpc_metadata={"original-key": "original-value"},
rpc_timeout=timedelta(seconds=42),
)

# The interceptor should have added its metadata on top of the caller's
assert captured_metadata.get("test-header-key") == "test-header-value"
assert captured_metadata.get("original-key") == "original-value"
# The caller's timeout should have been forwarded
assert captured_timeout == [timedelta(seconds=42)]