Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 90 additions & 2 deletions modules/python/clients/aks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,12 @@ def _begin_update_with_retry(
retries: int = 10,
retry_wait: int = 30,
poll_interval: int = 30,
timeout: int = 1200,
timeout: int = 1800,
) -> None:
"""
Call begin_create_or_update with retry on OperationNotAllowed/EtagMismatch,
polling every poll_interval seconds and raising TimeoutError after timeout seconds.
timeout defaults to 1800s (30 min) for slow GPU node provisioning (A100 MIG).
"""
for attempt in range(retries):
try:
Expand Down Expand Up @@ -363,6 +364,70 @@ def add_managed_gpu_node_pool(
)
logger.info(f"az aks nodepool add succeeded for '{node_pool_name}'")

@staticmethod
def _gpu_mode_metadata(
    gpu_node_pool: bool,
    enable_managed_gpu: bool,
    gpu_instance_profile: Optional[str] = None,
    gpu_mig_strategy: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Build normalized GPU-mode metadata distinguishing managed vs fully-managed
    GPU and MIG single vs mixed.

    Derived from the operation INPUT flags rather than the AKS read-back: the
    stable SDK does not model gpuProfile.nvidia.managementMode, so a
    fully-managed pool's mode is dropped from nodepool_info. Flag combinations
    are normalized for consistency: enable_managed_gpu / MIG only apply to a
    GPU pool, and MIG only to fully-managed pools (dropped otherwise).

    Args:
        gpu_node_pool: Whether the pool is a GPU pool at all.
        enable_managed_gpu: Whether the fully-managed GPU experience was requested.
        gpu_instance_profile: MIG instance profile; None or "" means not set.
        gpu_mig_strategy: MIG strategy, matched case-insensitively; None or ""
            means not set.

    Returns:
        A dict with gpu_mode ("none"|"managed"|"fully_managed"),
        enable_managed_gpu, mig_enabled, gpu_instance_profile, gpu_mig_strategy.
        Unset or dropped MIG fields are always reported as None, never "".

    Raises:
        ValueError: if gpu_mig_strategy is not None / "" / "single" / "mixed".
            The check runs before normalization, so an invalid strategy is
            rejected even for pools where it would be dropped.
    """
    # Falsy input means "no strategy requested"; anything else is lower-cased.
    strategy = str(gpu_mig_strategy).lower() if gpu_mig_strategy else None
    if strategy not in (None, "single", "mixed"):
        raise ValueError(
            f"invalid gpu_mig_strategy {gpu_mig_strategy!r} (want single/mixed/None)"
        )

    is_gpu = bool(gpu_node_pool)
    fully_managed = is_gpu and bool(enable_managed_gpu)

    if not is_gpu:
        gpu_mode = "none"
    elif fully_managed:
        gpu_mode = "fully_managed"
    else:
        gpu_mode = "managed"

    # MIG only applies to fully-managed pools; drop it otherwise. An empty
    # profile is collapsed to None so it agrees with mig_enabled and with the
    # way an empty strategy is treated above.
    profile = (gpu_instance_profile or None) if fully_managed else None
    strategy = strategy if fully_managed else None

    return {
        "gpu_mode": gpu_mode,
        "enable_managed_gpu": fully_managed,
        "mig_enabled": bool(profile or strategy),
        "gpu_instance_profile": profile,
        "gpu_mig_strategy": strategy,
    }

@staticmethod
def _log_gpu_mode(metadata: Dict[str, Any]) -> None:
    """
    Write one INFO line summarizing the GPU-mode fields of ``metadata``.

    Nothing is logged when ``gpu_mode`` is missing, None, or "none", so
    non-GPU operations stay quiet. Absent keys are rendered as None.
    """
    mode = metadata.get("gpu_mode")
    if mode is None or mode == "none":
        return

    # Remaining fields, in the order the format string expects them.
    reported_keys = (
        "enable_managed_gpu",
        "mig_enabled",
        "gpu_instance_profile",
        "gpu_mig_strategy",
    )
    values = [metadata.get(key) for key in reported_keys]
    logger.info(
        "GPU pool metadata: gpu_mode=%s enable_managed_gpu=%s mig_enabled=%s "
        "gpu_instance_profile=%s gpu_mig_strategy=%s",
        mode,
        *values,
    )

def create_node_pool(
self,
node_pool_name: str,
Expand Down Expand Up @@ -409,8 +474,14 @@ def create_node_pool(
"vm_size": vm_size,
"node_count": node_count,
"gpu_node_pool": gpu_node_pool,
"enable_managed_gpu": enable_managed_gpu,
**self._gpu_mode_metadata(
gpu_node_pool,
enable_managed_gpu,
gpu_instance_profile,
gpu_mig_strategy,
),
}
self._log_gpu_mode(metadata)

# Create operation context to track the operation
with self._get_operation_context()(
Expand Down Expand Up @@ -522,6 +593,7 @@ def scale_node_pool(
progressive: bool = False,
scale_step_size: int = 1,
gpu_instance_profile: Optional[str] = None,
gpu_mig_strategy: Optional[str] = None,
) -> Any:
"""
Scale a node pool to the specified node count.
Expand Down Expand Up @@ -555,7 +627,14 @@ def scale_node_pool(
"gpu_node_pool": gpu_node_pool,
"progressive_scaling": progressive,
"scale_step_size": scale_step_size,
**self._gpu_mode_metadata(
gpu_node_pool,
enable_managed_gpu,
gpu_instance_profile,
gpu_mig_strategy,
),
}
self._log_gpu_mode(metadata)
node_pool = self.get_node_pool(node_pool_name, cluster_name)

current_count = node_pool.count
Expand Down Expand Up @@ -583,6 +662,7 @@ def scale_node_pool(
enable_managed_gpu=enable_managed_gpu,
node_pool=node_pool,
gpu_instance_profile=gpu_instance_profile,
gpu_mig_strategy=gpu_mig_strategy,
)

# Create operation context to track the operation
Expand Down Expand Up @@ -751,6 +831,7 @@ def _progressive_scale(
enable_managed_gpu: bool = False,
node_pool: Optional[Any] = None,
gpu_instance_profile: Optional[str] = None,
gpu_mig_strategy: Optional[str] = None,
) -> Any:
"""
Scale a node pool progressively with specified step size
Expand Down Expand Up @@ -816,7 +897,14 @@ def _progressive_scale(
"scale_step_size": scale_step_size,
"cluster_name": cluster_name or self.get_cluster_name(),
"gpu_node_pool": gpu_node_pool,
**self._gpu_mode_metadata(
gpu_node_pool,
enable_managed_gpu,
gpu_instance_profile,
gpu_mig_strategy,
),
}
self._log_gpu_mode(step_metadata)

# Create operation context for this specific step
with self._get_operation_context()(
Expand Down
25 changes: 18 additions & 7 deletions modules/python/clients/kubernetes_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,18 +683,29 @@ def verify_nvidia_smi_on_node(self, nodes, namespace="default"):
logger.info(f"Verifying NVIDIA drivers on node {node_name}")
node = self.describe_node(node_name)

# Check if the node has GPUs allocated values (whole GPU or MIG slices)
# Wait for the node to advertise a POSITIVE GPU/MIG count. The device
# plugin can register nvidia.com/gpu with value "0" before MIG instances
# are published, so a MIG-single node briefly looks GPU-less. Waiting on
# key presence (rather than a positive count) would race in during that
# window and skip the node; wait on the count instead.
start_time = time.time()
gpu_count = 0
while time.time() < start_time + 600:
allocatable = node.status.allocatable or {}
if "nvidia.com/gpu" in allocatable or any(k.startswith("nvidia.com/mig-") for k in allocatable):
gpu_count = int(allocatable.get("nvidia.com/gpu", "0"))
mig_count = sum(
int(v) for k, v in allocatable.items()
if k.startswith("nvidia.com/mig-")
)
if gpu_count > 0 or mig_count > 0:
break
node = self.describe_node(node_name)
logger.info(f"Node allocatable resources: {node.status.allocatable}")
logger.info(f"Waiting for GPUs to be allocated on node {node_name}...")
logger.info(
f"Waiting for GPUs to be allocated on node {node_name}... "
f"(allocatable: {allocatable})"
)
time.sleep(1)
gpu_count = int(node.status.allocatable.get("nvidia.com/gpu", "0"))
has_mig = any(k.startswith("nvidia.com/mig-") for k in node.status.allocatable)
node = self.describe_node(node_name)
has_mig = any(k.startswith("nvidia.com/mig-") for k in (node.status.allocatable or {}))

logger.info(f"Node {node_name} has {gpu_count} GPUs, requesting all for validation")

Expand Down
5 changes: 5 additions & 0 deletions modules/python/crud/azure/node_pool_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def scale_node_pool(
gpu_node_pool=False,
enable_managed_gpu=False,
gpu_instance_profile=None,
gpu_mig_strategy=None,
):
"""
Scale a node pool to specified count
Expand Down Expand Up @@ -145,6 +146,7 @@ def scale_node_pool(
progressive=progressive,
scale_step_size=scale_step_size,
gpu_instance_profile=gpu_instance_profile,
gpu_mig_strategy=gpu_mig_strategy,
)

if result is not None:
Expand Down Expand Up @@ -253,6 +255,7 @@ def all(
gpu_node_pool=gpu_node_pool,
enable_managed_gpu=enable_managed_gpu,
gpu_instance_profile=gpu_instance_profile,
gpu_mig_strategy=gpu_mig_strategy,
)
results["scale_up"] = scale_up_result

Expand All @@ -276,6 +279,8 @@ def all(
scale_step_size=scale_step_size,
gpu_node_pool=gpu_node_pool,
enable_managed_gpu=enable_managed_gpu,
gpu_instance_profile=gpu_instance_profile,
gpu_mig_strategy=gpu_mig_strategy,
)
results["scale_down"] = scale_down_result

Expand Down
16 changes: 13 additions & 3 deletions modules/python/crud/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ def handle_node_pool_operation(node_pool_crud, args):
command = args.command
result = None

# gpu_instance_profile / gpu_mig_strategy are Azure-only MIG inputs. The AWS
# CRUD does not accept these kwargs (and has no **kwargs), so passing them for
# --cloud aws would raise TypeError. Only forward them on Azure.
azure_gpu_kwargs = {}
if args.cloud == "azure":
azure_gpu_kwargs = {
"gpu_instance_profile": args.gpu_instance_profile,
"gpu_mig_strategy": args.gpu_mig_strategy,
}

try:
if command == "create":
# Prepare create arguments
Expand All @@ -116,8 +126,7 @@ def handle_node_pool_operation(node_pool_crud, args):
"node_count": args.node_count,
"gpu_node_pool": args.gpu_node_pool,
"enable_managed_gpu": args.enable_managed_gpu,
"gpu_instance_profile": args.gpu_instance_profile,
"gpu_mig_strategy": args.gpu_mig_strategy,
**azure_gpu_kwargs,
}

result = node_pool_crud.create_node_pool(**create_kwargs)
Expand All @@ -131,7 +140,7 @@ def handle_node_pool_operation(node_pool_crud, args):
"scale_step_size": args.scale_step_size,
"gpu_node_pool": args.gpu_node_pool,
"enable_managed_gpu": args.enable_managed_gpu,
"gpu_instance_profile": args.gpu_instance_profile,
**azure_gpu_kwargs,
}
Comment thread
xuexu6666 marked this conversation as resolved.

result = node_pool_crud.scale_node_pool(**scale_kwargs)
Expand All @@ -151,6 +160,7 @@ def handle_node_pool_operation(node_pool_crud, args):
"gpu_node_pool": args.gpu_node_pool,
"enable_managed_gpu": args.enable_managed_gpu,
"step_wait_time": args.step_wait_time,
**azure_gpu_kwargs,
}

result = node_pool_crud.all(**all_kwargs)
Expand Down
103 changes: 103 additions & 0 deletions modules/python/tests/clients/test_aks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,109 @@ def test_scale_gpu_node_pool_down_no_verification(self, mock_time):
# Check that NVIDIA verification was NOT performed for scale-down
self.mock_k8s.verify_nvidia_smi_on_node.assert_not_called()

def test_gpu_mode_metadata_variants(self):
    """Walk every normalization path of _gpu_mode_metadata, including rejection."""
    build = AKSClient._gpu_mode_metadata  # pylint: disable=protected-access

    # CPU-only pool: every GPU field is empty.
    expected_cpu_record = {
        "gpu_mode": "none",
        "enable_managed_gpu": False,
        "mig_enabled": False,
        "gpu_instance_profile": None,
        "gpu_mig_strategy": None,
    }
    self.assertEqual(build(False, False), expected_cpu_record)

    # GPU pool without the fully-managed flag (driver bootstrap only).
    driver_only = build(True, False)
    self.assertEqual(driver_only["gpu_mode"], "managed")
    self.assertFalse(driver_only["enable_managed_gpu"])
    self.assertFalse(driver_only["mig_enabled"])

    # GPU pool with the fully-managed flag.
    full = build(True, True)
    self.assertEqual(full["gpu_mode"], "fully_managed")
    self.assertTrue(full["enable_managed_gpu"])

    # Fully managed with the mixed MIG strategy.
    with_mixed = build(True, True, "MIG1g", "mixed")
    self.assertEqual(with_mixed["gpu_mode"], "fully_managed")
    self.assertTrue(with_mixed["mig_enabled"])
    self.assertEqual(with_mixed["gpu_instance_profile"], "MIG1g")
    self.assertEqual(with_mixed["gpu_mig_strategy"], "mixed")

    # Fully managed with the single MIG strategy.
    with_single = build(True, True, "MIG1g", "single")
    self.assertEqual(with_single["gpu_mig_strategy"], "single")
    self.assertTrue(with_single["mig_enabled"])

    # MIG inputs on a pool that is not fully managed are discarded.
    mig_discarded = build(True, False, "MIG1g", "single")
    self.assertEqual(mig_discarded["gpu_mode"], "managed")
    self.assertFalse(mig_discarded["mig_enabled"])
    self.assertIsNone(mig_discarded["gpu_instance_profile"])
    self.assertIsNone(mig_discarded["gpu_mig_strategy"])

    # The managed flag has no effect when the pool is not a GPU pool.
    flag_without_gpu = build(False, True)
    self.assertEqual(flag_without_gpu["gpu_mode"], "none")
    self.assertFalse(flag_without_gpu["enable_managed_gpu"])

    # An unknown MIG strategy must raise.
    with self.assertRaises(ValueError):
        build(True, True, "MIG1g", "bogus")

def test_log_gpu_mode_console_echo(self):
    """GPU pools produce the metadata log line; non-GPU pools produce none."""
    emit = AKSClient._log_gpu_mode  # pylint: disable=protected-access
    gpu_record = {
        "gpu_mode": "fully_managed",
        "enable_managed_gpu": True,
        "mig_enabled": True,
        "gpu_instance_profile": "MIG1g",
        "gpu_mig_strategy": "mixed",
    }

    with self.assertLogs("clients.aks_client", level="INFO") as captured:
        emit(gpu_record)
    self.assertTrue(any("gpu_mode=fully_managed" in line for line in captured.output))
    self.assertTrue(any("gpu_mig_strategy=mixed" in line for line in captured.output))

    # A "none" mode must stay silent on the same logger.
    with self.assertNoLogs("clients.aks_client", level="INFO"):
        emit({"gpu_mode": "none"})

@mock.patch("clients.aks_client.time")
def test_scale_node_pool_records_gpu_mode_metadata(self, mock_time):
    """Scaling records gpu_mode and MIG fields in the operation metadata from its inputs."""
    pool_name = "h100fullmgd"
    target_count = 3

    mock_time.time.side_effect = [100, 150]

    # Existing pool as returned by both the SDK getter and get_node_pool.
    existing_pool = mock.MagicMock()
    existing_pool.count = 1
    existing_pool.vm_size = "Standard_NC40ads_H100_v5"
    existing_pool.as_dict.return_value = {"count": 1}
    self.mock_agent_pools.get.return_value = existing_pool
    self.aks_client.get_node_pool = mock.MagicMock(return_value=existing_pool)

    # Kubernetes-side verification hooks are stubbed out.
    self.mock_k8s.wait_for_nodes_ready.return_value = [mock.MagicMock()] * target_count
    self.mock_k8s.verify_managed_gpu_systemd_services = mock.MagicMock(return_value={})
    self.mock_k8s.verify_nvidia_smi_on_node = mock.MagicMock()
    self.mock_k8s.verify_mig_allocatable = mock.MagicMock(return_value={})

    outcome = self.aks_client.scale_node_pool(
        node_pool_name=pool_name,
        node_count=target_count,
        gpu_node_pool=True,
        enable_managed_gpu=True,
        gpu_instance_profile="MIG1g",
        gpu_mig_strategy="mixed",
    )

    self.assertTrue(outcome)
    # OperationContext(...) receives the metadata dict as its third positional argument.
    positional_args = self.mock_operation_context.call_args[0]
    recorded = positional_args[2]
    self.assertEqual(recorded["gpu_mode"], "fully_managed")
    self.assertTrue(recorded["enable_managed_gpu"])
    self.assertTrue(recorded["mig_enabled"])
    self.assertEqual(recorded["gpu_instance_profile"], "MIG1g")
    self.assertEqual(recorded["gpu_mig_strategy"], "mixed")


if __name__ == "__main__":
unittest.main()
Loading
Loading