Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 28 additions & 24 deletions modules/python/clients/aks_machine_client.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,16 @@
"""AKS Machine API client.

Subclasses ``AKSClient`` to inherit auth, ContainerServiceClient,
KubernetesClient, ``get_cluster_name``, ``get_cluster_data``, and the existing
node-pool CRUD methods. Adds raw REST methods for the Machine API which is not
yet exposed in the Azure SDK.

Public methods (``create_machine_agentpool``, ``scale_machine``) wrap their work
in ``OperationContext``: the context is opened inside the client method,
metadata is enriched with ``op.add_metadata`` along the way, success returns
None, failures are logged and re-raised so ``OperationContext`` records them.
``MachineCRUD`` therefore stays a thin try/except wrapper.

The non-batch scale path PUTs machines one at a time. The batch dispatch
(``use_batch_api=True``) uses the private ``BatchPutMachine`` header contract.
Extends ``AKSClient`` with raw Machine API REST methods. Public methods wrap
work in ``OperationContext``; ``MachineCRUD`` stays a thin try/except wrapper.
The batch scale path uses the private ``BatchPutMachine`` header contract.
"""
import json
import logging
import math
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Set

Expand Down Expand Up @@ -469,7 +460,7 @@ def scale_machine(
"""Scale a Machine-mode agentpool by creating ``scale_machine_count`` machines.

Opens an ``OperationContext`` here, enriches with successful machine
names, per-percentile readiness envelope, and the cluster snapshot.
names, batch timings, readiness metadata, and the cluster snapshot.
Returns on success; re-raises on failure so the context records the
trace.

Expand Down Expand Up @@ -511,11 +502,9 @@ def scale_machine(
cluster_name=cluster_name,
resource_group=self.resource_group,
vm_size=vm_size,
scale_machine_count=scale_machine_count,
use_batch_api=use_batch_api,
machine_workers=machine_workers,
timeout=timeout,
readiness_wait_timeout=readiness_wait_timeout,
batch_command_execution_times={},
)
with self._get_operation_context()(
"scale_machine", "azure", metadata, result_dir=self.result_dir
Expand Down Expand Up @@ -548,6 +537,11 @@ def scale_machine(
else:
successful = self._scale_machine_individually(request, names)
op.add_metadata("command_execution_time", time.time() - command_t0)
if use_batch_api:
op.add_metadata(
"batch_command_execution_times",
dict(request.batch_command_execution_times),
)
op.add_metadata("successful_machines", len(successful))

# Fail fast on partial landing BEFORE waiting on the agentpool
Expand Down Expand Up @@ -795,9 +789,8 @@ def _scale_machine_batch(
``_BATCH_MAX_MACHINES_PER_REQUEST``. ``ValueError`` is raised otherwise
so the failure surfaces before any partial batch submission reaches ARM.

Per-worker exceptions are caught and logged; the worker's slice is
excluded from the returned ``successful`` list. The input ``request``
is not mutated.
Per-worker exceptions are caught and logged; failed slices are excluded
from ``successful``. Batch timings are recorded on ``request`` (which is
therefore mutated), but only for chunks whose PUT completes successfully.
"""
n = len(names)
workers = request.machine_workers
Expand Down Expand Up @@ -905,15 +898,26 @@ def _create_batch_machines(
# but it does not bound total wall time for the chunk because retry
# and backoff handling can still extend the overall elapsed time.
put_timeout = min(request.timeout, _PER_REQUEST_TIMEOUT_CAP)
start_time = datetime.now(timezone.utc)
self._make_batch_request(
"PUT",
url,
body,
put_timeout,
"PUT", url, body, put_timeout,
batch_header_value=batch_header_value,
chunk_idx=chunk_idx,
first_machine_name=first_machine_name,
)
end_time = datetime.now(timezone.utc)
execution_time_seconds = (end_time - start_time).total_seconds()
request.batch_command_execution_times[first_machine_name] = {
"start_time": start_time.strftime("%Y-%m-%dT%H:%M:%SZ"),
"end_time": end_time.strftime("%Y-%m-%dT%H:%M:%SZ"),
"execution_time_seconds": execution_time_seconds,
"total_machines_in_batch": len(chunk),
}
Comment on lines +909 to +915

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should record batch timing even when the BatchPutMachine request fails.

Right now the metric is written only after _make_batch_request() returns successfully. If it raises after retries / timeout / non-2xx, _scale_machine_batch() catches that worker failure and returns [], so the operation only reports landed X/Y machines and the failed batch has no timing entry. That makes the new metric biased toward successful batches and hides the slow/failing batch we most need for debugging.

Could we wrap the request in try/finally so every batch records elapsed time, and include a small status/error field?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The BatchPutMachine command execution time is the cx-observed latency at the frontend. If this step fails, we are more interested in the error than in the latency, and we will look into the internal logs manually to find it.

logger.info(
f"chunk {chunk_idx}: BatchPutMachine PUT completed for {len(chunk)} "
f"machines in {execution_time_seconds:.3f} seconds "
f"(target={first_machine_name})"
)
return list(chunk)

def _make_batch_request(
Expand Down
55 changes: 29 additions & 26 deletions modules/python/tests/test_aks_machine_client.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
#!/usr/bin/env python3
"""Unit tests for AKSMachineClient (subclass of AKSClient).

All public methods open an ``OperationContext`` (patched here) and write the
result file via ``Operation.save_to_file`` on context exit. Tests verify:
- success path completes without raising and enriches ``op.add_metadata`` with
the right keys
- failure path raises (so the OperationContext records ``success=False``)

Tests also cover both scale paths and the batch-path helpers.
"""
"""Unit tests for AKSMachineClient."""
# pylint: disable=protected-access
# Tests intentionally exercise private helpers directly; the leading underscore
# is conventional rather than semantic privacy.
import itertools
import tempfile
import unittest
Expand Down Expand Up @@ -46,14 +35,10 @@ def setUp(self):

self.mock_k8s = self.mock_k8s_class.return_value
self.mock_operation = mock.MagicMock()
self.mock_operation_context.return_value.__enter__.return_value = (
self.mock_operation
)
self.mock_operation_context.return_value.__enter__.return_value = self.mock_operation
self.mock_operation_context.return_value.__exit__.return_value = None

# Hermetic per-test temp dir avoids cross-platform /tmp assumptions
# and parallel-run collisions. ``with`` doesn't fit the setUp/tearDown
# lifecycle, so we explicitly cleanup() in tearDown.
# A ``with`` block does not fit the setUp/tearDown lifecycle, so tearDown calls cleanup() explicitly.
self._tmp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
self.test_result_dir = self._tmp_dir.name

Expand All @@ -67,9 +52,7 @@ def setUp(self):

# Stub inherited helpers that the Machine methods enrich metadata with.
self.client.get_cluster_name = mock.MagicMock(return_value="fake-cluster")
self.client.get_cluster_data = mock.MagicMock(
return_value={"name": "fake-cluster"}
)
self.client.get_cluster_data = mock.MagicMock(return_value={"name": "fake-cluster"})
fake_pool = mock.MagicMock()
fake_pool.as_dict.return_value = {"name": "fake-pool"}
self.client.get_node_pool = mock.MagicMock(return_value=fake_pool)
Expand Down Expand Up @@ -294,18 +277,27 @@ def test_scale_machine_consumes_timeout_budgets(self):

request = mock_individual.call_args.args[0]
self.assertEqual(request.timeout, 900)
self.assertEqual(request.readiness_wait_timeout, 900)
self.assertEqual(mock_wait_ap.call_args.args[1], 900)
self.assertEqual(mock_wait_ready.call_args.kwargs["timeout"], 900)

def test_scale_machine_batch_path_dispatches_to_batch(self):
"""The use_batch_api=True branch calls _scale_machine_batch (not _individually)
and records command_execution_time in operation metadata."""
names = [f"scale2-machine-{i+1}" for i in range(2)]
and records batch command timings in operation metadata."""
batch_metric = {
"start_time": "2026-06-23T00:00:00Z",
"end_time": "2026-06-23T00:00:01Z",
"execution_time_seconds": 1.0,
"total_machines_in_batch": 2,
}

def fake_batch(request, batch_names):
request.batch_command_execution_times["scale2-machine-1"] = batch_metric
return batch_names

with mock.patch.object(
AKSMachineClient,
"_scale_machine_batch",
return_value=names,
side_effect=fake_batch,
) as mock_batch, mock.patch.object(
AKSMachineClient, "_scale_machine_individually"
) as mock_individual, mock.patch.object(
Expand Down Expand Up @@ -333,11 +325,14 @@ def test_scale_machine_batch_path_dispatches_to_batch(self):
)
mock_batch.assert_called_once()
mock_individual.assert_not_called()
# command_execution_time must be added to operation metadata.
# Command timings must be added to operation metadata.
added_keys = {
call.args[0] for call in self.mock_operation.add_metadata.call_args_list
}
self.assertIn("command_execution_time", added_keys)
self.mock_operation.add_metadata.assert_any_call(
"batch_command_execution_times", {"scale2-machine-1": batch_metric}
)

# ---- _create_single_machine ----

Expand Down Expand Up @@ -859,6 +854,7 @@ def test_create_batch_machines_header_shape(self):
resource_group="fake-rg",
vm_size="Standard_D2_v3",
timeout=60,
batch_command_execution_times={},
)
with mock.patch.object(
AKSMachineClient, "_make_batch_request"
Expand All @@ -868,6 +864,13 @@ def test_create_batch_machines_header_shape(self):
)
self.assertEqual(result, ["m-1", "m-2", "m-3"])
mock_make_batch.assert_called_once()
self.assertEqual(set(request.batch_command_execution_times), {"m-1"})
metric = request.batch_command_execution_times["m-1"]
self.assertRegex(metric["start_time"], r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z$")
self.assertRegex(metric["end_time"], r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z$")
self.assertIsInstance(metric["execution_time_seconds"], float)
self.assertGreaterEqual(metric["execution_time_seconds"], 0.0)
self.assertEqual(metric["total_machines_in_batch"], 3)
# Inspect the batch_header_value kwarg.
import json as _json # pylint: disable=import-outside-toplevel
kwargs = mock_make_batch.call_args.kwargs
Expand Down
Loading