Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ansible-runner/project
48 changes: 48 additions & 0 deletions modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,63 @@
"""

import sys
from enum import Enum
from pathlib import Path
from typing import Optional

import click
from loguru import logger
import ansible_runner

from .utils.graphql import GraphQlClient

# Define context settings to support -h for help across all commands
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])

# Default ansible-runner private data directory (relative to the working dir).
COACT_ANSIBLE_RUNNER_PATH = './ansible-runner/'


class AnsibleRunner:
    """Run Ansible playbooks via ``ansible_runner`` and inspect their events.

    Can be mixed into a class that provides ``self.logger`` or instantiated
    on its own; in the latter case the module-level loguru ``logger`` is used.
    """

    # Prefix used to build the ansible-runner ``ident`` for each run.
    ident = None

    def _log(self):
        """Return the instance ``logger`` if one is set, else the module logger."""
        own = getattr(self, 'logger', None)
        # The module-level name is only looked up when no instance logger
        # exists, so hosts that supply their own logger are unaffected.
        return own if own is not None else logger

    def run_playbook(
        self,
        playbook: str,
        private_data_dir: str = COACT_ANSIBLE_RUNNER_PATH,
        tags: str = 'all',
        dry_run: bool = False,
        **kwargs
    ) -> "Optional[ansible_runner.runner.Runner]":
        """Run ``playbook`` and return the finished runner.

        Args:
            playbook: Playbook path handed to ``ansible_runner.run``.
            private_data_dir: ansible-runner private data directory.
            tags: Tags passed through to the run.
            dry_run: When true, only log a warning and return ``None``.
            **kwargs: Forwarded to the playbook as extra variables.

        Returns:
            The runner object, or ``None`` for a dry run.

        Raises:
            Exception: If the run finishes with a non-zero return code.
        """
        if dry_run:
            self._log().warning(f"not running playbook {playbook}")
            return None

        name = Path(playbook).name
        r = ansible_runner.run(
            private_data_dir=private_data_dir,
            playbook=playbook,
            tags=tags,
            extravars=kwargs,
            suppress_env_files=True,
            ident=f'{self.ident}_{name}:{tags}',
            # Callback always returns None.
            cancel_callback=lambda: None,
        )
        self._log().debug(r.stats)
        if r.rc != 0:
            raise Exception(
                f"AnsibleRunner failed: playbook {playbook} exited with rc={r.rc}"
            )
        return r

    def playbook_events(self, runner: "ansible_runner.runner.Runner"):
        """Yield the ``event_data`` dict of every runner event that has one."""
        for e in runner.events:
            if 'event_data' in e:
                yield e['event_data']

    def playbook_task_res(
        self,
        runner: "ansible_runner.runner.Runner",
        play: str,
        task: str,
    ) -> "Optional[dict]":
        """Return ``res`` of the first event matching ``play`` and ``task``.

        Returns ``None`` when no event with a ``res`` entry matches.
        """
        for e in self.playbook_events(runner):
            if 'play' in e and play == e['play'] and 'task' in e and task == e['task'] and 'res' in e:
                return e['res']
        return None

class GraphQlMixin:
"""
Expand Down
39 changes: 14 additions & 25 deletions modules/coact.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from urllib.parse import urlparse

# Import base classes from modules.base
from .base import GraphQlMixin, common_options, graphql_options, configure_logging_from_verbose
from .base import AnsibleRunner, GraphQlMixin, common_options, graphql_options, configure_logging_from_verbose
from .utils.graphql import GraphQlClient

# get local timezone
Expand All @@ -49,11 +49,6 @@ class OveragePoint(TypedDict):
change: bool
purchased_nodes: int

# Typed record pairing a facility and cluster name with a node count.
FacilityNodeUsage = TypedDict(
    'FacilityNodeUsage',
    {
        'facility': str,
        'cluster': str,
        'nodes': int,
    },
)


def parse_datetime(value: Any, timezone=_now.timezone, force_tz: bool = False):
"""Parse various datetime formats into pendulum DateTime objects."""
Expand Down Expand Up @@ -898,11 +893,8 @@ def overage(

def toggle_job_blocking(point: OveragePoint, execute: bool = False) -> bool:
"""Enable/disable job blocking for overaged allocations."""
template = Template("sacctmgr modify -i account name=$facility:_regular_@$cluster set GrpTRES=node=$nodes")

# Determine node count based on blocking state
if point['over']:
# Blocking: set to 0
nodes = 0
else:
# Unblocking: use purchased nodes or fallback to unlimited
Expand All @@ -916,24 +908,21 @@ def toggle_job_blocking(point: OveragePoint, execute: bool = False) -> bool:
logger.warning(f"Invalid node count {nodes} for {point['facility']}@{point['cluster']}, using unlimited")
nodes = -1

facility_usage = FacilityNodeUsage(
facility=point['facility'],
cluster=point['cluster'],
nodes=nodes
)

logger.info(f"Job blocking toggle for {facility_usage['facility']}@{facility_usage['cluster']}: nodes={nodes} (over={point['over']}, execute={execute})")
cmd = template.safe_substitute(**facility_usage)
logger.info(f"Command: {cmd}")
logger.info(f"Job blocking toggle for {point['facility']}@{point['cluster']}: nodes={nodes} (over={point['over']}, execute={execute})")

if execute:
try:
result = subprocess.check_output(cmd.split())
for line in result.split(b"\n"):
if line.strip():
logger.debug(f"sacctmgr output: {line.decode().strip()}")
except subprocess.CalledProcessError as e:
logger.error(f"Failed to toggle job blocking: {e}")
runner = AnsibleRunner()
runner.ident = "toggle_job_blocking"

r = runner.run_playbook(
playbook='coact/slurm/toggle-job-blocking.yaml',
facility=point['facility'],
cluster=point['cluster'],
nodes=nodes,
)
logger.debug(r.stats)
if r.rc != 0:
logger.error(f"Failed to toggle job blocking for {point['facility']}@{point['cluster']}")
return False

return True
Expand Down
51 changes: 1 addition & 50 deletions modules/coactd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from typing import Any, Optional, List
from math import ceil
from timeit import default_timer as timer
from pathlib import Path

import click
import pendulum as pdl
Expand All @@ -26,20 +25,15 @@
import smtplib
from email.message import EmailMessage

import ansible_runner

# Import base classes from modules.base
from .base import GraphQlMixin, common_options, configure_logging_from_verbose
from .base import AnsibleRunner, common_options, configure_logging_from_verbose
from .utils.graphql import GraphQlSubscriber

# Using loguru logger

# Define context settings to support -h for help
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])

COACT_ANSIBLE_RUNNER_PATH = './ansible-runner/'


# Order of base classes is important: str must come before Enum
class RequestStatus(str, Enum):
APPROVED = 'Approved'
Expand All @@ -49,49 +43,6 @@ class RequestStatus(str, Enum):
INCOMPLETE = 'Incomplete'


class AnsibleRunner:
    """Run Ansible playbooks via ``ansible_runner`` and inspect their events.

    Can be mixed into a class that provides ``self.logger`` or instantiated
    on its own; in the latter case the module-level loguru ``logger`` is used.
    """

    # Prefix used to build the ansible-runner ``ident`` for each run.
    ident = None

    def _log(self):
        """Return the instance ``logger`` if one is set, else the module logger."""
        own = getattr(self, 'logger', None)
        # The module-level name is only looked up when no instance logger
        # exists, so hosts that supply their own logger are unaffected.
        return own if own is not None else logger

    def run_playbook(
        self,
        playbook: str,
        private_data_dir: str = COACT_ANSIBLE_RUNNER_PATH,
        tags: str = 'all',
        dry_run: bool = False,
        **kwargs
    ) -> "Optional[ansible_runner.runner.Runner]":
        """Run ``playbook`` and return the finished runner.

        Args:
            playbook: Playbook path handed to ``ansible_runner.run``.
            private_data_dir: ansible-runner private data directory.
            tags: Tags passed through to the run.
            dry_run: When true, only log a warning and return ``None``.
            **kwargs: Forwarded to the playbook as extra variables.

        Returns:
            The runner object, or ``None`` for a dry run.

        Raises:
            Exception: If the run finishes with a non-zero return code.
        """
        if dry_run:
            self._log().warning(f"not running playbook {playbook}")
            return None

        name = Path(playbook).name
        r = ansible_runner.run(
            private_data_dir=private_data_dir,
            playbook=playbook,
            tags=tags,
            extravars=kwargs,
            suppress_env_files=True,
            ident=f'{self.ident}_{name}:{tags}',
            # Callback always returns None.
            cancel_callback=lambda: None,
        )
        self._log().debug(r.stats)
        if r.rc != 0:
            raise Exception(
                f"AnsibleRunner failed: playbook {playbook} exited with rc={r.rc}"
            )
        return r

    def playbook_events(self, runner: "ansible_runner.runner.Runner"):
        """Yield the ``event_data`` dict of every runner event that has one."""
        for e in runner.events:
            if 'event_data' in e:
                yield e['event_data']

    def playbook_task_res(
        self,
        runner: "ansible_runner.runner.Runner",
        play: str,
        task: str,
    ) -> "Optional[dict]":
        """Return ``res`` of the first event matching ``play`` and ``task``.

        Returns ``None`` when no event with a ``res`` entry matches.
        """
        for e in self.playbook_events(runner):
            if 'play' in e and play == e['play'] and 'task' in e and task == e['task'] and 'res' in e:
                return e['res']
        return None


class EmailRunner:
"""Mixin class for sending emails."""
# Using loguru logger
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"pendulum>=3.2.0",
"tzlocal>=5.3.1",
"loguru>=0.7.0",
"setuptools<70", # Required by ansible-runner for pkg_resources (deprecated in v70+)
]

[project.optional-dependencies]
Expand Down
34 changes: 17 additions & 17 deletions tests/test_node_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
Unit tests for node allocation functionality.
"""

from unittest.mock import patch
from unittest.mock import patch, MagicMock

import pytest

from modules.coact import toggle_job_blocking, OveragePoint, FacilityUsage
from modules.coact import toggle_job_blocking, FacilityUsage


def create_graphql_response(usage_percent: float, nodes: int):
Expand Down Expand Up @@ -108,15 +106,16 @@ def test_facility_lifecycle_goes_over_blocks_recovers_and_restores_nodes():
assert overage_point['purchased_nodes'] == purchased_nodes, "OveragePoint should have purchasedNodes from format_data"

# Mock the Ansible playbook run that sets nodes to 0
with patch('modules.coact.subprocess.check_output') as mock_subprocess:
mock_subprocess.return_value = b"Modified account\n"
with patch('modules.coact.AnsibleRunner.run_playbook') as mock_run:
mock_run.return_value = MagicMock(rc=0, stats={})
result = toggle_job_blocking(overage_point, execute=True)

# Verify blocking command was issued
assert result is True
called_args = mock_subprocess.call_args[0][0]
assert "GrpTRES=node=0" in called_args # Jobs blocked
assert f"name={facility}:_regular_@{cluster}" in called_args
call_kwargs = mock_run.call_args.kwargs
assert call_kwargs['nodes'] == 0
assert call_kwargs['facility'] == facility
assert call_kwargs['cluster'] == cluster

# After blocking, sacctmgr shows GrpNodes=0 (but GraphQL still has purchasedNodes)
sacctmgr_blocked = b"""lcls:_regular_@ada|0|1000|1000
Expand Down Expand Up @@ -159,16 +158,17 @@ def test_facility_lifecycle_goes_over_blocks_recovers_and_restores_nodes():
assert recovery_point['change'] is True # Need to unblock
assert recovery_point['purchased_nodes'] == purchased_nodes, "OveragePoint should have purchasedNodes from coact-api"

with patch('modules.coact.subprocess.check_output') as mock_subprocess:
mock_subprocess.return_value = b"Modified account\n"
with patch('modules.coact.AnsibleRunner.run_playbook') as mock_run:
mock_run.return_value = MagicMock(rc=0, stats={})
result = toggle_job_blocking(recovery_point, execute=True)

# Verify unblocking command uses original purchased nodes from coact-api
assert result is True
called_args = mock_subprocess.call_args[0][0]
assert f"GrpTRES=node={purchased_nodes}" in called_args # CRITICAL: restores 256, not -1
assert "GrpTRES=node=-1" not in called_args # NOT unlimited
assert f"name={facility}:_regular_@{cluster}" in called_args
call_kwargs = mock_run.call_args.kwargs
assert call_kwargs['nodes'] == purchased_nodes # CRITICAL: restores 256, not -1
assert call_kwargs['nodes'] != -1 # NOT unlimited
assert call_kwargs['facility'] == facility
assert call_kwargs['cluster'] == cluster

# Verify final state: sacctmgr shows nodes restored
sacctmgr_restored = b"""lcls:_regular_@ada|256|1000|1000
Expand Down
8 changes: 5 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.