diff --git a/tests/conftest.py b/tests/conftest.py index 7dd7149fe..06a693366 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -67,6 +67,7 @@ def slurm_system(tmp_path: Path) -> SlurmSystem: ) system.scheduler = "slurm" system.monitor_interval = 0 + system.supports_gpu_directives_cache = True return system diff --git a/tests/systems/slurm/test_system.py b/tests/systems/slurm/test_system.py index 0ad79b61f..df2cc5d9f 100644 --- a/tests/systems/slurm/test_system.py +++ b/tests/systems/slurm/test_system.py @@ -16,7 +16,6 @@ import re from pathlib import Path -from typing import Dict, List from unittest.mock import Mock, patch import pytest @@ -145,14 +144,28 @@ def grouped_nodes() -> dict[SlurmNodeState, list[SlurmNode]]: return grouped_nodes -def test_get_available_nodes_exceeding_limit_no_callstack( - slurm_system: SlurmSystem, grouped_nodes: Dict[SlurmNodeState, List[SlurmNode]], caplog -): +def test_get_available_nodes_exceeding_limit_no_callstack(slurm_system: SlurmSystem, caplog): group_name = "group1" partition_name = "main" num_nodes = 5 + empty_grouped_nodes = { + SlurmNodeState.IDLE: [], + SlurmNodeState.COMPLETING: [], + SlurmNodeState.ALLOCATED: [], + } + + mod_path = "cloudai.systems.slurm.slurm_system.SlurmSystem" + with ( + patch(f"{mod_path}.update", return_value=None) as mock_update, + patch(f"{mod_path}.group_nodes_by_state", return_value=empty_grouped_nodes) as mock_group_nodes_by_state, + ): + slurm_system.get_available_nodes_from_group(partition_name, group_name, num_nodes) - slurm_system.get_available_nodes_from_group(partition_name, group_name, num_nodes) + mock_update.assert_called_once() + mock_group_nodes_by_state.assert_called_once() + args, kwargs = mock_group_nodes_by_state.call_args + assert args == (partition_name, group_name) + assert kwargs in ({}, {"exclude_nodes": None}) log_message = "CloudAI is requesting 5 nodes from the group 'group1', but only 0 nodes are available." assert log_message in caplog.text @@ -513,10 +526,27 @@ def test_per_step_isolation(self, mock_get_nodes: Mock, slurm_system: SlurmSyste def test_supports_gpu_directives( mock_fetch_command_output, scontrol_output: str, expected_support: bool, slurm_system: SlurmSystem ): + slurm_system.supports_gpu_directives_cache = None mock_fetch_command_output.return_value = (scontrol_output, "") assert slurm_system.supports_gpu_directives == expected_support +@patch("cloudai.systems.slurm.slurm_system.SlurmSystem.fetch_command_output") +def test_supports_gpu_directives_defaults_to_true_on_probe_error( + mock_fetch_command_output, slurm_system: SlurmSystem, caplog: pytest.LogCaptureFixture +): + slurm_system.supports_gpu_directives_cache = None + stderr = "scontrol failed" + mock_fetch_command_output.return_value = ("", stderr) + + with caplog.at_level("WARNING"): + assert slurm_system.supports_gpu_directives is True + + assert slurm_system.supports_gpu_directives_cache is True + assert f"Error checking GPU support: {stderr}" in caplog.text + mock_fetch_command_output.assert_called_once_with("scontrol show config") + + @pytest.mark.parametrize( "cache_value", [True, False],