Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion modules/python/clients/aks_machine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Set

Expand Down Expand Up @@ -469,7 +470,8 @@ def scale_machine(
"""Scale a Machine-mode agentpool by creating ``scale_machine_count`` machines.

Opens an ``OperationContext`` here, enriches with successful machine
names, per-percentile readiness envelope, and the cluster snapshot.
names, per-batch command execution times, per-percentile readiness
envelope, and the cluster snapshot.
Returns on success; re-raises on failure so the context records the
trace.

Expand Down Expand Up @@ -516,6 +518,8 @@ def scale_machine(
machine_workers=machine_workers,
timeout=timeout,
readiness_wait_timeout=readiness_wait_timeout,
batch_command_execution_times={},
batch_command_execution_times_lock=threading.Lock(),
)
with self._get_operation_context()(
"scale_machine", "azure", metadata, result_dir=self.result_dir
Expand Down Expand Up @@ -548,6 +552,11 @@ def scale_machine(
else:
successful = self._scale_machine_individually(request, names)
op.add_metadata("command_execution_time", time.time() - command_t0)
if use_batch_api:
op.add_metadata(
"batch_command_execution_times",
dict(request.batch_command_execution_times),
)
op.add_metadata("successful_machines", len(successful))

# Fail fast on partial landing BEFORE waiting on the agentpool
Expand Down Expand Up @@ -905,6 +914,7 @@ def _create_batch_machines(
# but it does not bound total wall time for the chunk because retry
# and backoff handling can still extend the overall elapsed time.
put_timeout = min(request.timeout, _PER_REQUEST_TIMEOUT_CAP)
start_time = datetime.now(timezone.utc)
self._make_batch_request(
"PUT",
url,
Expand All @@ -914,6 +924,31 @@ def _create_batch_machines(
chunk_idx=chunk_idx,
first_machine_name=first_machine_name,
)
end_time = datetime.now(timezone.utc)
execution_time_seconds = (end_time - start_time).total_seconds()
batch_command_execution_times = getattr(
request, "batch_command_execution_times", None
)
if batch_command_execution_times is None:
batch_command_execution_times = {}
request.batch_command_execution_times = batch_command_execution_times
metric = {
"start_time": start_time.isoformat().replace("+00:00", "Z"),
"end_time": end_time.isoformat().replace("+00:00", "Z"),
"execution_time_seconds": execution_time_seconds,
"total_machines_in_batch": len(chunk),
}
Comment thread
karenychen marked this conversation as resolved.
Outdated
Comment on lines +909 to +915

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should record batch timing even when the BatchPutMachine request fails.

Right now the metric is written only after _make_batch_request() returns successfully. If it raises (after exhausting retries, on timeout, or on a non-2xx response), _scale_machine_batch() catches that worker failure and returns [], so the operation only reports "landed X/Y machines" and the failed batch has no timing entry. That biases the new metric toward successful batches and hides the slow or failing batches we most need for debugging.

Could we wrap the request in try/finally so every batch records elapsed time, and include a small status/error field?

metric_lock = getattr(request, "batch_command_execution_times_lock", None)
if metric_lock is None:
batch_command_execution_times[first_machine_name] = metric
else:
with metric_lock:
batch_command_execution_times[first_machine_name] = metric
logger.info(
f"chunk {chunk_idx}: BatchPutMachine PUT completed for {len(chunk)} "
f"machines in {execution_time_seconds:.3f} seconds "
f"(target={first_machine_name})"
)
return list(chunk)

def _make_batch_request(
Expand Down
29 changes: 26 additions & 3 deletions modules/python/tests/test_aks_machine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,25 @@ def test_scale_machine_consumes_timeout_budgets(self):

def test_scale_machine_batch_path_dispatches_to_batch(self):
"""The use_batch_api=True branch calls _scale_machine_batch (not _individually)
and records command_execution_time in operation metadata."""
and records batch command timings in operation metadata."""
names = [f"scale2-machine-{i+1}" for i in range(2)]
batch_metrics = {
"scale2-machine-1": {
"start_time": "2026-06-23T00:00:00Z",
"end_time": "2026-06-23T00:00:01Z",
"execution_time_seconds": 1.0,
"total_machines_in_batch": 2,
}
}

def fake_batch(request, batch_names):
request.batch_command_execution_times.update(batch_metrics)
return batch_names

with mock.patch.object(
AKSMachineClient,
"_scale_machine_batch",
return_value=names,
side_effect=fake_batch,
) as mock_batch, mock.patch.object(
AKSMachineClient, "_scale_machine_individually"
) as mock_individual, mock.patch.object(
Expand Down Expand Up @@ -333,11 +346,14 @@ def test_scale_machine_batch_path_dispatches_to_batch(self):
)
mock_batch.assert_called_once()
mock_individual.assert_not_called()
# command_execution_time must be added to operation metadata.
# Command timings must be added to operation metadata.
added_keys = {
call.args[0] for call in self.mock_operation.add_metadata.call_args_list
}
self.assertIn("command_execution_time", added_keys)
self.mock_operation.add_metadata.assert_any_call(
"batch_command_execution_times", batch_metrics
)

# ---- _create_single_machine ----

Expand Down Expand Up @@ -868,6 +884,13 @@ def test_create_batch_machines_header_shape(self):
)
self.assertEqual(result, ["m-1", "m-2", "m-3"])
mock_make_batch.assert_called_once()
self.assertEqual(set(request.batch_command_execution_times), {"m-1"})
metric = request.batch_command_execution_times["m-1"]
self.assertIn("start_time", metric)
self.assertIn("end_time", metric)
self.assertIsInstance(metric["execution_time_seconds"], float)
self.assertGreaterEqual(metric["execution_time_seconds"], 0.0)
self.assertEqual(metric["total_machines_in_batch"], 3)
# Inspect the batch_header_value kwarg.
import json as _json # pylint: disable=import-outside-toplevel
kwargs = mock_make_batch.call_args.kwargs
Expand Down
Loading