diff --git a/examples/run_slurm_pretrain.sh b/examples/run_slurm_pretrain.sh index 04da35a4d..e17f42faa 100755 --- a/examples/run_slurm_pretrain.sh +++ b/examples/run_slurm_pretrain.sh @@ -34,6 +34,11 @@ export NNODES=${NNODES:-1} SCRIPT_DIR=$(dirname "$(realpath "${BASH_SOURCE[0]}")") +# Align EXP default with run_local_pretrain.sh to avoid unknown names +if [[ -z "${EXP:-}" ]]; then + export EXP="${SCRIPT_DIR}/megatron/exp_pretrain.yaml" +fi + export LOG_DIR=${LOG_DIR:-"./output"} LOG_FILE="${LOG_DIR}/log_slurm_pretrain.txt" mkdir -p "$LOG_DIR" diff --git a/primus/backends/megatron/training/mlflow_artifacts.py b/primus/backends/megatron/training/mlflow_artifacts.py new file mode 100644 index 000000000..0837eccc3 --- /dev/null +++ b/primus/backends/megatron/training/mlflow_artifacts.py @@ -0,0 +1,328 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +MLflow Artifact Logging Utilities + +This module provides functions to upload trace files and log files to MLflow +when MLflow tracking is enabled. + +Features: +- Upload profiler trace files from all profiled ranks (including multi-node) +- Upload log files from all levels and all ranks +- Supports both local and distributed training scenarios + +Note: + Multi-node training requires shared storage (e.g., NFS) for artifact uploads. + Only the last rank (world_size - 1) performs the upload, so it must have + access to trace and log files from all nodes. If using node-local storage, + only files from the uploading node will be uploaded. +""" + +import glob +import os +import traceback +from typing import List, Optional + +from primus.modules.module_utils import log_rank_last + +# Note: This module is called on the last rank (where MLflow is initialized). +# Using log_rank_last ensures messages are visible. For warnings, we prefix +# with [WARNING] since warning_rank_last doesn't exist. +try: + from mlflow.exceptions import MlflowException +except ModuleNotFoundError: + + class MlflowException(Exception): + """Fallback exception when mlflow isn't available.""" + + +def _log_warning(msg: str) -> None: + """Log a warning message on the last rank.""" + log_rank_last(f"[WARNING] {msg}") + + +def _get_all_trace_files(tensorboard_dir: Optional[str]) -> List[str]: + """ + Find all profiler trace files in the tensorboard directory. + + Trace files are typically named like: + - *.pt.trace.json + - *.pt.trace.json.gz + + Args: + tensorboard_dir: Path to the tensorboard directory containing trace files + + Returns: + List of paths to trace files + """ + if not tensorboard_dir or not os.path.exists(tensorboard_dir): + return [] + + trace_files = [] + # Look for PyTorch profiler trace files (both compressed and uncompressed) + patterns = ["*.pt.trace.json", "*.pt.trace.json.gz"] + # Escape directory path to handle special characters like [] in experiment names + escaped_dir = glob.escape(tensorboard_dir) + for pattern in patterns: + trace_files.extend(glob.glob(os.path.join(escaped_dir, pattern))) + trace_files.extend(glob.glob(os.path.join(escaped_dir, "**", pattern), recursive=True)) + + # Remove duplicates while preserving order + seen = set() + unique_files = [] + for f in trace_files: + if f not in seen: + seen.add(f) + unique_files.append(f) + + return unique_files + + +def _get_all_log_files(exp_root_path: Optional[str]) -> List[str]: + """ + Find all log files in the experiment logs directory. + + Log files are organized as: + - {exp_root_path}/logs/master/master-*.log + - {exp_root_path}/logs/{module_name}/rank-{rank}/*.log + + Args: + exp_root_path: Root path of the experiment + + Returns: + List of paths to log files + """ + if not exp_root_path: + return [] + + logs_dir = os.path.join(exp_root_path, "logs") + if not os.path.exists(logs_dir): + return [] + + log_files = [] + # Find all .log files recursively (escape path to handle special characters) + log_files.extend(glob.glob(os.path.join(glob.escape(logs_dir), "**", "*.log"), recursive=True)) + + return log_files + + +def upload_trace_files_to_mlflow( + mlflow_writer, + tensorboard_dir: str, + artifact_path: str = "traces", +) -> int: + """ + Upload all profiler trace files to MLflow as artifacts. + + This function collects trace files from the tensorboard directory and + uploads them to MLflow. In distributed settings, only the last rank + (world_size - 1) where MLflow writer is initialized should call this. + + Args: + mlflow_writer: The MLflow module instance (from get_mlflow_writer()) + tensorboard_dir: Path to the tensorboard directory containing trace files + artifact_path: MLflow artifact subdirectory for trace files + + Returns: + Number of trace files uploaded + """ + if mlflow_writer is None: + return 0 + + log_rank_last(f"[MLflow] Searching for trace files in: {tensorboard_dir}") + trace_files = _get_all_trace_files(tensorboard_dir) + if len(trace_files) > 5: + log_rank_last(f"[MLflow] Found {len(trace_files)} trace files: {trace_files[:5]}...") + else: + log_rank_last(f"[MLflow] Found {len(trace_files)} trace files: {trace_files}") + + if not trace_files: + log_rank_last("[MLflow] No trace files found to upload") + return 0 + + total_files = len(trace_files) + + # Warn about potentially long upload times for large uploads + if total_files > 10: + # Safely calculate total size (files may be deleted between discovery and size check) + total_size_bytes = 0 + for f in trace_files: + try: + total_size_bytes += os.path.getsize(f) + except OSError: + pass # File may have been deleted + total_size_mb = total_size_bytes / (1024 * 1024) + _log_warning( + f"[MLflow] Uploading {total_files} trace files ({total_size_mb:.1f} MB total). " + "This may take a while..." + ) + + uploaded_count = 0 + for trace_file in trace_files: + try: + # Get relative path from tensorboard_dir for artifact organization + rel_path = os.path.relpath(trace_file, tensorboard_dir) + # Determine artifact subdirectory based on file location + artifact_subpath = ( + os.path.join(artifact_path, os.path.dirname(rel_path)) + if os.path.dirname(rel_path) + else artifact_path + ) + + mlflow_writer.log_artifact(trace_file, artifact_path=artifact_subpath) + uploaded_count += 1 + # Progress logging with counter + log_rank_last( + f"[MLflow] Uploaded trace file ({uploaded_count}/{total_files}): " + f"{os.path.basename(trace_file)}" + ) + except (OSError, RuntimeError, ValueError, MlflowException) as e: + _log_warning(f"[MLflow] Failed to upload trace file {trace_file}: {type(e).__name__}: {e}") + _log_warning(traceback.format_exc().strip()) + + log_rank_last(f"[MLflow] Uploaded {uploaded_count}/{total_files} trace files to '{artifact_path}'") + return uploaded_count + + +def upload_log_files_to_mlflow( + mlflow_writer, + exp_root_path: str, + artifact_path: str = "logs", +) -> int: + """ + Upload all log files to MLflow as artifacts. + + This function collects log files from all ranks and all log levels + and uploads them to MLflow. The directory structure is preserved + in the artifact path. + + Args: + mlflow_writer: The MLflow module instance (from get_mlflow_writer()) + exp_root_path: Root path of the experiment + artifact_path: MLflow artifact subdirectory for log files + + Returns: + Number of log files uploaded + """ + if mlflow_writer is None: + return 0 + + log_files = _get_all_log_files(exp_root_path) + + if not log_files: + log_rank_last("[MLflow] No log files found to upload") + return 0 + + total_files = len(log_files) + + # Warn about potentially long upload times for large uploads + if total_files > 20: + # Safely calculate total size (files may be deleted between discovery and size check) + total_size_bytes = 0 + for f in log_files: + try: + total_size_bytes += os.path.getsize(f) + except OSError: + pass # File may have been deleted + total_size_mb = total_size_bytes / (1024 * 1024) + _log_warning( + f"[MLflow] Uploading {total_files} log files ({total_size_mb:.1f} MB total). " + "This may take a while..." + ) + + logs_base_dir = os.path.join(exp_root_path, "logs") + uploaded_count = 0 + + for log_file in log_files: + try: + # Preserve directory structure relative to logs base directory + rel_path = os.path.relpath(log_file, logs_base_dir) + artifact_subpath = ( + os.path.join(artifact_path, os.path.dirname(rel_path)) + if os.path.dirname(rel_path) + else artifact_path + ) + + mlflow_writer.log_artifact(log_file, artifact_path=artifact_subpath) + uploaded_count += 1 + except (OSError, RuntimeError, ValueError, MlflowException) as e: + _log_warning(f"[MLflow] Failed to upload log file {log_file}: {type(e).__name__}: {e}") + _log_warning(traceback.format_exc().strip()) + + log_rank_last(f"[MLflow] Uploaded {uploaded_count}/{total_files} log files to '{artifact_path}'") + return uploaded_count + + +def upload_artifacts_to_mlflow( + mlflow_writer, + tensorboard_dir: Optional[str] = None, + exp_root_path: Optional[str] = None, + upload_traces: bool = True, + upload_logs: bool = True, +) -> dict: + """ + Upload all artifacts (trace files and log files) to MLflow. + + This is the main entry point for uploading artifacts to MLflow. + It handles both trace files from profiling and log files from training. + + Note: + Multi-node training requires shared storage (e.g., NFS) for complete + artifact uploads. Only the last rank performs the upload, so it must + have filesystem access to trace/log files from all nodes. + + Args: + mlflow_writer: The MLflow module instance (from get_mlflow_writer()) + tensorboard_dir: Path to the tensorboard directory containing trace files + exp_root_path: Root path of the experiment for log files + upload_traces: Whether to upload trace files + upload_logs: Whether to upload log files + + Returns: + Dictionary with counts of uploaded files: + { + "traces": , + "logs": + } + """ + if mlflow_writer is None: + log_rank_last("[MLflow] MLflow writer not available, skipping artifact upload") + return {"traces": 0, "logs": 0} + + # Warn about multi-node shared storage requirement + try: + nnodes = int(os.environ.get("NNODES", os.environ.get("SLURM_NNODES", "1"))) + except ValueError: + nnodes = 1 + _log_warning("[MLflow] NNODES/SLURM_NNODES could not be parsed as integer; assuming 1 node.") + if nnodes > 1: + _log_warning( + f"[MLflow] Multi-node training detected ({nnodes} nodes). " + "Ensure shared storage (e.g., NFS) is used for complete artifact uploads. " + "Only files accessible from this node will be uploaded." + ) + + log_rank_last("[MLflow] Starting artifact upload to MLflow...") + log_rank_last(f"[MLflow] tensorboard_dir: {tensorboard_dir}") + log_rank_last(f"[MLflow] exp_root_path: {exp_root_path}") + log_rank_last(f"[MLflow] upload_traces: {upload_traces}, upload_logs: {upload_logs}") + + result = {"traces": 0, "logs": 0} + + if upload_traces and tensorboard_dir: + result["traces"] = upload_trace_files_to_mlflow( + mlflow_writer, tensorboard_dir, artifact_path="traces" + ) + + if upload_logs and exp_root_path: + result["logs"] = upload_log_files_to_mlflow(mlflow_writer, exp_root_path, artifact_path="logs") + + log_rank_last( + f"[MLflow] Artifact upload complete: {result['traces']} trace files, {result['logs']} log files" + ) + + return result diff --git a/primus/backends/megatron/training/mlflow_setup.py b/primus/backends/megatron/training/mlflow_setup.py new file mode 100644 index 000000000..33a15ae10 --- /dev/null +++ b/primus/backends/megatron/training/mlflow_setup.py @@ -0,0 +1,69 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +MLflow artifact upload utilities. + +This module provides functions for uploading artifacts (traces, logs) to MLflow. +Separated from global_vars.py to reduce merge conflicts. +""" + +from .global_vars import get_args, get_mlflow_writer + +_GLOBAL_EXP_ROOT_PATH = None + + +def set_exp_root_path(exp_root_path): + """Set the experiment root path for artifact logging.""" + global _GLOBAL_EXP_ROOT_PATH + _GLOBAL_EXP_ROOT_PATH = exp_root_path + + +def get_exp_root_path(): + """Return experiment root path. Can be None.""" + return _GLOBAL_EXP_ROOT_PATH + + +def reset_exp_root_path(): + """Reset the experiment root path to None.""" + global _GLOBAL_EXP_ROOT_PATH + _GLOBAL_EXP_ROOT_PATH = None + + +def upload_mlflow_artifacts( + upload_traces: bool = True, + upload_logs: bool = True, +): + """ + Upload trace files and log files to MLflow as artifacts. + + This should be called before ending the MLflow run to ensure all + artifacts are uploaded. Only the rank that initialized MLflow + (typically rank world_size - 1) should call this. + + Args: + upload_traces: Whether to upload profiler trace files + upload_logs: Whether to upload training log files + + Returns: + Dictionary with counts of uploaded files, or None if MLflow is not enabled + """ + mlflow_writer = get_mlflow_writer() + if mlflow_writer is None: + return None + + from .mlflow_artifacts import upload_artifacts_to_mlflow + + args = get_args() + exp_root_path = get_exp_root_path() + tensorboard_dir = getattr(args, "tensorboard_dir", None) + + return upload_artifacts_to_mlflow( + mlflow_writer=mlflow_writer, + tensorboard_dir=tensorboard_dir, + exp_root_path=exp_root_path, + upload_traces=upload_traces, + upload_logs=upload_logs, + ) diff --git a/primus/configs/modules/megatron/primus_megatron_module.yaml b/primus/configs/modules/megatron/primus_megatron_module.yaml index 0ec3a22b0..74f46f257 100644 --- a/primus/configs/modules/megatron/primus_megatron_module.yaml +++ b/primus/configs/modules/megatron/primus_megatron_module.yaml @@ -5,6 +5,10 @@ disable_wandb: true disable_mlflow: true mlflow_run_name: null mlflow_experiment_name: null +# NOTE: When disable_mlflow=false, traces and logs are uploaded by default. +# Set these to false if you only want metrics/params logged to MLflow. +mlflow_upload_traces: true # Upload profiler trace files to MLflow +mlflow_upload_logs: true # Upload training log files to MLflow disable_compile_dependencies: true # NOTE: # - If `use_rocm_mem_info = True`, ROCm memory information will be collected diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index 5db59188b..4b41844b7 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -148,6 +148,11 @@ set_primus_global_variables, set_train_start_time, ) +from primus.backends.megatron.training.mlflow_setup import ( + reset_exp_root_path, + set_exp_root_path, + upload_mlflow_artifacts, +) from primus.backends.megatron.training.tokenizer.tokenizer import build_tokenizer from primus.core.utils import checker, file_utils from primus.core.utils.rocm_mem_info import get_rocm_smi_mem_info @@ -174,6 +179,42 @@ set_train_start_time() +def _finalize_mlflow_run(args, mlflow_writer) -> None: + """ + Finalize MLflow run: sync ranks, upload artifacts, and end the run. + + This helper function consolidates the MLflow finalization logic to avoid + code duplication between normal training completion and exit conditions. + + Args: + args: Megatron arguments containing mlflow_upload_traces/logs settings + mlflow_writer: The MLflow writer instance (or None if not enabled) + """ + # Barrier to ensure all ranks have finished writing files before upload. + # Must run on ALL ranks to avoid deadlock (only last rank has mlflow_writer). + if dist.is_initialized(): + dist.barrier() + + if mlflow_writer is None: + reset_exp_root_path() + if dist.is_initialized(): + dist.barrier() + return + + # Upload artifacts before ending the run + mlflow_artifact_kwargs = {} + if hasattr(args, "mlflow_upload_traces"): + mlflow_artifact_kwargs["upload_traces"] = args.mlflow_upload_traces + if hasattr(args, "mlflow_upload_logs"): + mlflow_artifact_kwargs["upload_logs"] = args.mlflow_upload_logs + upload_mlflow_artifacts(**mlflow_artifact_kwargs) + mlflow_writer.end_run() + # Reset so subsequent runs in the same process don't use a stale path + reset_exp_root_path() + if dist.is_initialized(): + dist.barrier() + + class MegatronTrainer(BaseTrainer, BaseModule): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -395,6 +436,24 @@ def update_primus_config( else: log_rank_0(f"-{latest_file} does not exist, skip auto_continue_train.") + # If uploading traces (or tracelens) to MLflow, auto-enable profiling and tensorboard. + # Only when MLflow is enabled (disable_mlflow=False); we do not override disable_mlflow + # so MLflow remains opt-in and users with disable_mlflow: true are not surprised. + needs_profiling = ( + getattr(args, "mlflow_upload_traces", False) + or getattr(args, "mlflow_upload_tracelens_report", False) + ) and not args.disable_mlflow + if needs_profiling: + if not getattr(args, "profile", False): + args.profile = True + debug_rank_0("Auto-enabled profile=True for mlflow trace upload") + if not getattr(args, "use_pytorch_profiler", False): + args.use_pytorch_profiler = True + debug_rank_0("Auto-enabled use_pytorch_profiler=True for mlflow trace upload") + if getattr(args, "disable_tensorboard", True): + args.disable_tensorboard = False + debug_rank_0("Auto-enabled tensorboard (disable_tensorboard=False) for profiler trace output") + # tensorboard if not args.disable_tensorboard: tb_path = os.path.abspath(os.path.join(exp_root_path, "tensorboard")) @@ -752,6 +811,8 @@ def initialize_megatron( set_global_variables(args, build_tokenizer=False) log_rank_0(f"-set_primus_global_variables...") set_primus_global_variables(args) + # Set exp_root_path for MLflow artifact upload (needed before training starts) + set_exp_root_path(self.exp_root_path) args = get_args() # set tokenizer @@ -1120,8 +1181,7 @@ def run(self, *args, **kwargs): ft_integration.on_checkpointing_end(is_async_finalization=True) mlflow_writer = get_mlflow_writer() - if mlflow_writer: - mlflow_writer.end_run() + _finalize_mlflow_run(args, mlflow_writer) one_logger and one_logger.log_metrics({"app_finish_time": one_logger_utils.get_timestamp_in_ms()}) @@ -1564,8 +1624,7 @@ def get_e2e_base_metrics(): if wandb_writer: wandb_writer.finish() mlflow_writer = get_mlflow_writer() - if mlflow_writer: - mlflow_writer.end_run() + _finalize_mlflow_run(args, mlflow_writer) ft_integration.shutdown() sys.exit(exit_code) diff --git a/tests/unit_tests/backends/megatron/test_mlflow_artifacts.py b/tests/unit_tests/backends/megatron/test_mlflow_artifacts.py new file mode 100644 index 000000000..60eda4703 --- /dev/null +++ b/tests/unit_tests/backends/megatron/test_mlflow_artifacts.py @@ -0,0 +1,272 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Unit tests for MLflow artifact upload utilities. + +Focus areas: + 1. File discovery logic (_get_all_trace_files, _get_all_log_files) + 2. Upload functions with various scenarios (no files, multiple files, errors) + 3. Glob escaping for special characters in paths + 4. Relative path handling for artifact organization +""" + +from unittest.mock import MagicMock, patch + +from primus.backends.megatron.training.mlflow_artifacts import ( + _get_all_log_files, + _get_all_trace_files, + upload_artifacts_to_mlflow, + upload_log_files_to_mlflow, + upload_trace_files_to_mlflow, +) + + +class TestGetAllTraceFiles: + """Test trace file discovery logic.""" + + def test_finds_json_trace_files(self, tmp_path): + """Should find .pt.trace.json files.""" + trace_file = tmp_path / "rank_0_step_2.pt.trace.json" + trace_file.touch() + + files = _get_all_trace_files(str(tmp_path)) + + assert len(files) == 1 + assert str(trace_file) in files + + def test_finds_gzipped_trace_files(self, tmp_path): + """Should find .pt.trace.json.gz files.""" + trace_file = tmp_path / "rank_0_step_2.pt.trace.json.gz" + trace_file.touch() + + files = _get_all_trace_files(str(tmp_path)) + + assert len(files) == 1 + assert str(trace_file) in files + + def test_finds_nested_trace_files(self, tmp_path): + """Should find trace files in subdirectories.""" + subdir = tmp_path / "subdir" + subdir.mkdir() + trace_file = subdir / "rank_1.pt.trace.json" + trace_file.touch() + + files = _get_all_trace_files(str(tmp_path)) + + assert len(files) == 1 + assert str(trace_file) in files + + def test_returns_empty_for_nonexistent_dir(self): + """Should return empty list for non-existent directory.""" + files = _get_all_trace_files("/nonexistent/path") + + assert files == [] + + def test_returns_empty_for_none(self): + """Should return empty list for None input.""" + files = _get_all_trace_files(None) + + assert files == [] + + def test_handles_special_characters_in_path(self, tmp_path): + """Should handle paths with special glob characters like [].""" + # Create directory with brackets in name (common in experiment names) + special_dir = tmp_path / "exp[rank0]_test" + special_dir.mkdir() + trace_file = special_dir / "trace.pt.trace.json" + trace_file.touch() + + files = _get_all_trace_files(str(special_dir)) + + assert len(files) == 1 + assert str(trace_file) in files + + def test_deduplicates_files(self, tmp_path): + """Should not return duplicate file paths.""" + trace_file = tmp_path / "rank_0.pt.trace.json" + trace_file.touch() + + files = _get_all_trace_files(str(tmp_path)) + + # Each file should appear only once + assert len(files) == len(set(files)) + + +class TestGetAllLogFiles: + """Test log file discovery logic.""" + + def test_finds_log_files(self, tmp_path): + """Should find .log files in logs directory.""" + logs_dir = tmp_path / "logs" + logs_dir.mkdir() + log_file = logs_dir / "training.log" + log_file.touch() + + files = _get_all_log_files(str(tmp_path)) + + assert len(files) == 1 + assert str(log_file) in files + + def test_finds_nested_log_files(self, tmp_path): + """Should find log files in nested directories.""" + logs_dir = tmp_path / "logs" / "rank-0" + logs_dir.mkdir(parents=True) + log_file = logs_dir / "debug.log" + log_file.touch() + + files = _get_all_log_files(str(tmp_path)) + + assert len(files) == 1 + assert str(log_file) in files + + def test_returns_empty_when_no_logs_dir(self, tmp_path): + """Should return empty list when logs directory doesn't exist.""" + files = _get_all_log_files(str(tmp_path)) + + assert files == [] + + def test_returns_empty_for_none(self): + """Should return empty list for None input.""" + files = _get_all_log_files(None) + + assert files == [] + + +class TestUploadTraceFilesToMlflow: + """Test trace file upload functionality.""" + + def test_returns_zero_when_no_writer(self, tmp_path): + """Should return 0 when mlflow_writer is None.""" + count = upload_trace_files_to_mlflow(None, str(tmp_path)) + + assert count == 0 + + def test_returns_zero_when_no_files(self, tmp_path): + """Should return 0 when no trace files found.""" + mlflow_mock = MagicMock() + + count = upload_trace_files_to_mlflow(mlflow_mock, str(tmp_path)) + + assert count == 0 + mlflow_mock.log_artifact.assert_not_called() + + @patch("primus.backends.megatron.training.mlflow_artifacts.log_rank_last") + @patch("primus.backends.megatron.training.mlflow_artifacts._log_warning") + def test_uploads_trace_files(self, mock_warning, mock_log, tmp_path): + """Should upload trace files and return count.""" + trace_file = tmp_path / "rank_0.pt.trace.json" + trace_file.touch() + mlflow_mock = MagicMock() + + count = upload_trace_files_to_mlflow(mlflow_mock, str(tmp_path)) + + assert count == 1 + mlflow_mock.log_artifact.assert_called_once() + + @patch("primus.backends.megatron.training.mlflow_artifacts.log_rank_last") + @patch("primus.backends.megatron.training.mlflow_artifacts._log_warning") + def test_handles_upload_error(self, mock_warning, mock_log, tmp_path): + """Should continue on upload error and log warning.""" + trace_file = tmp_path / "rank_0.pt.trace.json" + trace_file.touch() + mlflow_mock = MagicMock() + mlflow_mock.log_artifact.side_effect = RuntimeError("Upload failed") + + count = upload_trace_files_to_mlflow(mlflow_mock, str(tmp_path)) + + assert count == 0 + mock_warning.assert_called() + + @patch("primus.backends.megatron.training.mlflow_artifacts.log_rank_last") + @patch("primus.backends.megatron.training.mlflow_artifacts._log_warning") + def test_preserves_relative_path(self, mock_warning, mock_log, tmp_path): + """Should preserve subdirectory structure in artifact path.""" + subdir = tmp_path / "subdir" + subdir.mkdir() + trace_file = subdir / "rank_0.pt.trace.json" + trace_file.touch() + mlflow_mock = MagicMock() + + upload_trace_files_to_mlflow(mlflow_mock, str(tmp_path)) + + # Check that artifact_path includes subdirectory + call_args = mlflow_mock.log_artifact.call_args + assert "subdir" in call_args.kwargs.get("artifact_path", "") + + +class TestUploadLogFilesToMlflow: + """Test log file upload functionality.""" + + def test_returns_zero_when_no_writer(self, tmp_path): + """Should return 0 when mlflow_writer is None.""" + count = upload_log_files_to_mlflow(None, str(tmp_path)) + + assert count == 0 + + def test_returns_zero_when_no_files(self, tmp_path): + """Should return 0 when no log files found.""" + mlflow_mock = MagicMock() + + count = upload_log_files_to_mlflow(mlflow_mock, str(tmp_path)) + + assert count == 0 + + @patch("primus.backends.megatron.training.mlflow_artifacts.log_rank_last") + @patch("primus.backends.megatron.training.mlflow_artifacts._log_warning") + def test_uploads_log_files(self, mock_warning, mock_log, tmp_path): + """Should upload log files and return count.""" + logs_dir = tmp_path / "logs" + logs_dir.mkdir() + log_file = logs_dir / "training.log" + log_file.touch() + mlflow_mock = MagicMock() + + count = upload_log_files_to_mlflow(mlflow_mock, str(tmp_path)) + + assert count == 1 + mlflow_mock.log_artifact.assert_called_once() + + +class TestUploadArtifactsToMlflow: + """Test main artifact upload entry point.""" + + def test_returns_zeros_when_no_writer(self): + """Should return zero counts when mlflow_writer is None.""" + result = upload_artifacts_to_mlflow(None) + + assert result == {"traces": 0, "logs": 0} + + @patch("primus.backends.megatron.training.mlflow_artifacts.log_rank_last") + @patch("primus.backends.megatron.training.mlflow_artifacts._log_warning") + def test_respects_upload_traces_flag(self, mock_warning, mock_log, tmp_path): + """Should skip trace upload when upload_traces=False.""" + trace_file = tmp_path / "rank_0.pt.trace.json" + trace_file.touch() + mlflow_mock = MagicMock() + + result = upload_artifacts_to_mlflow( + mlflow_mock, + tensorboard_dir=str(tmp_path), + upload_traces=False, + upload_logs=False, + ) + + assert result["traces"] == 0 + mlflow_mock.log_artifact.assert_not_called() + + @patch("primus.backends.megatron.training.mlflow_artifacts.log_rank_last") + @patch("primus.backends.megatron.training.mlflow_artifacts._log_warning") + def test_warns_for_multi_node(self, mock_warning, mock_log, tmp_path, monkeypatch): + """Should warn when multi-node training is detected.""" + monkeypatch.setenv("NNODES", "2") + mlflow_mock = MagicMock() + + upload_artifacts_to_mlflow(mlflow_mock, tensorboard_dir=str(tmp_path)) + + # Check that warning was called with multi-node message + warning_calls = [str(call) for call in mock_warning.call_args_list] + assert any("Multi-node" in str(call) for call in warning_calls)