Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions temporalio/client/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,17 @@ class ActivityExecutionDescription(ActivityExecution):
long_poll_token: bytes | None
"""Token for follow-on long-poll requests. None if the activity is complete."""

raw_callbacks: Sequence[temporalio.api.activity.v1.CallbackInfo]
"""Underlying protobuf callbacks"""

@classmethod
async def _from_execution_info(
cls,
info: temporalio.api.activity.v1.ActivityExecutionInfo,
long_poll_token: bytes | None,
namespace: str,
data_converter: temporalio.converter.DataConverter,
callbacks: Sequence[temporalio.api.activity.v1.CallbackInfo],
) -> Self:
"""Create from raw proto activity execution info."""
# Decode heartbeat details if present
Expand Down Expand Up @@ -409,6 +413,7 @@ async def _from_execution_info(
typed_search_attributes=temporalio.converter.decode_typed_search_attributes(
info.search_attributes
),
raw_callbacks=callbacks,
)


Expand Down Expand Up @@ -691,6 +696,8 @@ def __init__(
*,
run_id: str | None = None,
result_type: type | None = None,
start_activity_response: None
| temporalio.api.workflowservice.v1.StartActivityExecutionResponse = None,
) -> None:
"""Create activity handle."""
self._client = client
Expand All @@ -700,6 +707,7 @@ def __init__(
self._known_outcome: (
temporalio.api.activity.v1.ActivityExecutionOutcome | None
) = None
self._start_activity_response = start_activity_response

@functools.cached_property
def _data_converter(self) -> temporalio.converter.DataConverter:
Expand Down
9 changes: 9 additions & 0 deletions temporalio/client/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,6 +1484,12 @@ async def start_activity(
start_delay: timedelta | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
# The following options should not be considered part of the public API. They
# are deliberately not exposed in overloads, and are not subject to any
# backwards compatibility guarantees.
callbacks: Sequence[Callback] = [],
links: Sequence[temporalio.api.common.v1.Link] = [],
request_id: str | None = None,
) -> ActivityHandle[ReturnType]:
"""Start an activity and return its handle.

Expand Down Expand Up @@ -1542,6 +1548,9 @@ async def start_activity(
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
priority=priority,
callbacks=callbacks,
links=links,
request_id=request_id,
)
)

Expand Down
23 changes: 22 additions & 1 deletion temporalio/client/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ async def _build_start_workflow_execution_request(
# Links are duplicated on request for compatibility with older server versions.
req.links.extend(links)

if temporalio.nexus._operation_context._in_nexus_backing_workflow_start_context():
if temporalio.nexus._operation_context._in_nexus_backing_start_context():
req.on_conflict_options.attach_request_id = True
req.on_conflict_options.attach_completion_callbacks = True
req.on_conflict_options.attach_links = True
Expand Down Expand Up @@ -567,6 +567,7 @@ async def start_activity(self, input: StartActivityInput) -> ActivityHandle[Any]
input.id,
run_id=resp.run_id,
result_type=input.result_type,
start_activity_response=resp,
)

async def _build_start_activity_execution_request(
Expand Down Expand Up @@ -610,6 +611,8 @@ async def _build_start_activity_execution_request(
),
)

if input.request_id:
req.request_id = input.request_id
if input.schedule_to_close_timeout is not None:
req.schedule_to_close_timeout.FromTimedelta(input.schedule_to_close_timeout)
if input.start_to_close_timeout is not None:
Expand Down Expand Up @@ -645,6 +648,23 @@ async def _build_start_activity_execution_request(
# Set priority
req.priority.CopyFrom(input.priority._to_proto())

req.completion_callbacks.extend(
temporalio.api.common.v1.Callback(
nexus=temporalio.api.common.v1.Callback.Nexus(
url=callback.url,
header=callback.headers,
),
links=input.links,
)
for callback in input.callbacks
)
req.links.extend(input.links)

if temporalio.nexus._operation_context._in_nexus_backing_start_context():
req.on_conflict_options.attach_request_id = True
req.on_conflict_options.attach_completion_callbacks = True
req.on_conflict_options.attach_links = True

return req

async def cancel_activity(self, input: CancelActivityInput) -> None:
Expand Down Expand Up @@ -703,6 +723,7 @@ async def describe_activity(
workflow_id=input.activity_id, # Using activity_id as workflow_id for activities not started by a workflow
)
),
callbacks=resp.callbacks,
)

def list_activities(
Expand Down
4 changes: 4 additions & 0 deletions temporalio/client/_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ class StartActivityInput:
headers: Mapping[str, temporalio.api.common.v1.Payload]
rpc_metadata: Mapping[str, str | bytes]
rpc_timeout: timedelta | None
# The following options are experimental and unstable.
callbacks: Sequence[Callback]
links: Sequence[temporalio.api.common.v1.Link]
request_id: str | None


@dataclass
Expand Down
2 changes: 2 additions & 0 deletions temporalio/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
wait_for_worker_shutdown_sync,
)
from ._operation_handlers import (
CancelActivityOptions,
CancelWorkflowRunOptions,
TemporalOperationHandler,
)
Expand All @@ -33,6 +34,7 @@

__all__ = (
"workflow_run_operation",
"CancelActivityOptions",
"CancelWorkflowRunOptions",
"Info",
"LoggerAdapter",
Expand Down
82 changes: 54 additions & 28 deletions temporalio/nexus/_link_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
logger = logging.getLogger(__name__)

_NEXUS_OPERATION_LINK_URL_PATH_REGEX = re.compile(
r"^/namespaces/(?P<namespace>[^/]+)/nexus-operations/(?P<operation_id>[^/]+)$"
r"^/namespaces/(?P<namespace>[^/]+)/nexus-operations/(?P<operation_id>[^/]+)/(?P<run_id>[^/]+)/details$"
)

_ACTIVITY_LINK_URL_PATH_REGEX = re.compile(
r"^/namespaces/(?P<namespace>[^/]+)/activities/(?P<activity_id>[^/]+)/(?P<run_id>[^/]+)/details$"
)

_WORFKLOW_LINK_URL_PATH_REGEX = re.compile(
Expand All @@ -31,13 +35,13 @@
class _LinkType(str, Enum):
WORKFLOW = temporalio.api.common.v1.Link.WorkflowEvent.DESCRIPTOR.full_name
NEXUS_OPERATION = temporalio.api.common.v1.Link.NexusOperation.DESCRIPTOR.full_name
ACTIVITY = temporalio.api.common.v1.Link.Activity.DESCRIPTOR.full_name


LINK_EVENT_ID_PARAM_NAME = "eventID"
LINK_EVENT_TYPE_PARAM_NAME = "eventType"
LINK_REQUEST_ID_PARAM_NAME = "requestID"
LINK_REFERENCE_TYPE_PARAM_NAME = "referenceType"
LINK_RUN_ID_PARAM_NAME = "runID"

EVENT_REFERENCE_TYPE = "EventReference"
REQUEST_ID_REFERENCE_TYPE = "RequestIdReference"
Expand Down Expand Up @@ -84,6 +88,9 @@ def nexus_link_to_temporal_link(
case _LinkType.NEXUS_OPERATION:
return nexus_link_to_nexus_operation_link(nexus_link)

case _LinkType.ACTIVITY:
return nexus_link_to_activity_link(nexus_link)


def temporal_link_to_nexus_link(
temporal_link: temporalio.api.common.v1.Link,
Expand All @@ -99,8 +106,11 @@ def temporal_link_to_nexus_link(
case "nexus_operation":
return nexus_operation_to_nexus_link(temporal_link.nexus_operation)

case "activity" | "batch_job":
raise NotImplementedError("only workflow links are supported")
case "activity":
return activity_link_to_nexus_link(temporal_link.activity)

case "batch_job":
raise NotImplementedError("batch_job links are not supported")

case None:
logger.warning("Invalid Temporal link: missing variant")
Expand Down Expand Up @@ -149,22 +159,30 @@ def nexus_operation_to_nexus_link(
scheme = "temporal"
namespace = urllib.parse.quote(op_link.namespace, safe="")
operation_id = urllib.parse.quote(op_link.operation_id, safe="")
path = f"/namespaces/{namespace}/nexus-operations/{operation_id}"

query_params = ""
if op_link.run_id:
query_params = urllib.parse.urlencode(
{
LINK_RUN_ID_PARAM_NAME: op_link.run_id,
},
)
run_id = urllib.parse.quote(op_link.run_id, safe="")
path = f"/namespaces/{namespace}/nexus-operations/{operation_id}/{run_id}/details"

# urllib will omit '//' from the url if netloc is empty so we add the scheme manually
url = f"{scheme}://{urllib.parse.urlunparse(('', '', path, '', query_params, ''))}"
url = f"{scheme}://{urllib.parse.urlunparse(('', '', path, '', '', ''))}"

return nexusrpc.Link(url=url, type=_LinkType.NEXUS_OPERATION.value)


def activity_link_to_nexus_link(
activity: temporalio.api.common.v1.Link.Activity,
) -> nexusrpc.Link:
"""Convert an Activity link into a nexusrpc link."""
scheme = "temporal"
namespace = urllib.parse.quote(activity.namespace, safe="")
activity_id = urllib.parse.quote(activity.activity_id, safe="")
run_id = urllib.parse.quote(activity.run_id, safe="")
path = f"/namespaces/{namespace}/activities/{activity_id}/{run_id}/details"

url = f"{scheme}://{urllib.parse.urlunparse(('', '', path, '', '', ''))}"

return nexusrpc.Link(url=url, type=_LinkType.ACTIVITY.value)


def nexus_link_to_workflow_event_link(
link: nexusrpc.Link,
) -> temporalio.api.common.v1.Link | None:
Expand Down Expand Up @@ -228,28 +246,36 @@ def nexus_link_to_nexus_operation_link(
)
return None

query_params = urllib.parse.parse_qs(url.query)

match query_params.get(LINK_RUN_ID_PARAM_NAME):
case [run_id_param]:
run_id = run_id_param
case [] | None:
run_id = ""
case _:
logger.warning(
f"Invalid Nexus link: {nexus_link}. Expected {LINK_RUN_ID_PARAM_NAME} to have at most 1 value"
)
return None

groups = match.groupdict()
nexus_op_link = temporalio.api.common.v1.Link.NexusOperation(
namespace=urllib.parse.unquote(groups["namespace"]),
operation_id=urllib.parse.unquote(groups["operation_id"]),
run_id=run_id,
run_id=urllib.parse.unquote(groups["run_id"]),
)
return temporalio.api.common.v1.Link(nexus_operation=nexus_op_link)


def nexus_link_to_activity_link(
nexus_link: nexusrpc.Link,
) -> temporalio.api.common.v1.Link | None:
"""Convert a nexus link into a Temporal Activity link."""
url = urllib.parse.urlparse(nexus_link.url)
match = _ACTIVITY_LINK_URL_PATH_REGEX.match(url.path)
if not match:
logger.warning(
f"Invalid Nexus link: {nexus_link}. Expected path to match {_ACTIVITY_LINK_URL_PATH_REGEX.pattern}"
)
return None

groups = match.groupdict()
activity_link = temporalio.api.common.v1.Link.Activity(
namespace=urllib.parse.unquote(groups["namespace"]),
activity_id=urllib.parse.unquote(groups["activity_id"]),
run_id=urllib.parse.unquote(groups["run_id"]),
)
return temporalio.api.common.v1.Link(activity=activity_link)


def _event_reference_to_query_params(
event_ref: temporalio.api.common.v1.Link.WorkflowEvent.EventReference,
) -> str:
Expand Down
Loading
Loading