From a0afc82ecec71519f659c55f42008588728d64ca Mon Sep 17 00:00:00 2001 From: Mikael Strauss Date: Wed, 10 Dec 2025 16:52:29 +0000 Subject: [PATCH 1/3] mlflow ending status --- .../backends/megatron/training/global_vars.py | 30 +++ primus/modules/trainer/megatron/trainer.py | 217 ++++++++++-------- 2 files changed, 146 insertions(+), 101 deletions(-) diff --git a/primus/backends/megatron/training/global_vars.py b/primus/backends/megatron/training/global_vars.py index b23016d46..56fc19151 100644 --- a/primus/backends/megatron/training/global_vars.py +++ b/primus/backends/megatron/training/global_vars.py @@ -62,6 +62,36 @@ def _set_mlflow_writer(args): _GLOBAL_MLFLOW_WRITER = mlflow +def end_mlflow_run(status="FINISHED", termination_reason=None): + """End MLflow run with specified status. + + Args: + status: MLflow run status - "FINISHED", "FAILED", or "KILLED". + termination_reason: Optional free-form string tag recorded as + "termination_reason" on the run. + """ + global _GLOBAL_MLFLOW_WRITER + + if _GLOBAL_MLFLOW_WRITER is None: + return + + try: + # Optionally attach a coarse termination reason tag for debugging. + if termination_reason is not None: + try: + _GLOBAL_MLFLOW_WRITER.set_tag("termination_reason", termination_reason) + except Exception: + # Ignore tagging failures; status update is more important. + pass + + _GLOBAL_MLFLOW_WRITER.end_run(status=status) + except Exception: + # Swallow MLflow/network errors to avoid masking the original failure. 
+ pass + finally: + _GLOBAL_MLFLOW_WRITER = None + + def unset_global_variables(): """Unset global vars.""" diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index 19783a798..012040b73 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -143,6 +143,7 @@ from primus.backends.megatron.core.transformer.moe.moe_utils import track_moe_metrics from primus.backends.megatron.model_provider import primus_model_provider from primus.backends.megatron.training.global_vars import ( + end_mlflow_run, get_mlflow_writer, set_primus_global_variables, ) @@ -1475,122 +1476,136 @@ def run(self, *args, **kwargs): one_logger = get_one_logger() args = get_args() - if args.pp_warmup: - from .utils import pp_warmup + try: + if args.pp_warmup: + from .utils import pp_warmup - log_rank_0( - "warmup on each rank in parallel to decrease " - "the first iter time, especially when pp degree is large" - ) - timers = get_timers() - timers("pp-warmup", log_level=0).start(barrier=True) - pp_warmup(args, self.config, self.model, self.optimizer) - timers("pp-warmup").stop() - timers.log(["pp-warmup"], barrier=True) - - process_non_loss_data_func = None - non_loss_data_func = None - if not args.skip_train: - log_rank_0("training ...") - - if args.dataloader_type == "cyclic" and args.retro_project_dir: - assert args.retro_cyclic_train_iters is not None - args.train_iters = args.retro_cyclic_train_iters - log_rank_0("retro cyclic train iters : %d" % args.train_iters) - - iteration = 0 - if args.do_train and args.train_iters > 0: - iteration, num_floating_point_operations_so_far = self.train( - self.forward_step, - self.model, - self.optimizer, - self.opt_param_scheduler, - self.train_data_iterator, - self.valid_data_iterator, - process_non_loss_data_func, - self.config, - self.checkpointing_context, - non_loss_data_func, + log_rank_0( + "warmup on each rank in parallel to decrease " + "the first iter 
time, especially when pp degree is large" ) + timers = get_timers() + timers("pp-warmup", log_level=0).start(barrier=True) + pp_warmup(args, self.config, self.model, self.optimizer) + timers("pp-warmup").stop() + timers.log(["pp-warmup"], barrier=True) + + process_non_loss_data_func = None + non_loss_data_func = None + if not args.skip_train: + log_rank_0("training ...") + + if args.dataloader_type == "cyclic" and args.retro_project_dir: + assert args.retro_cyclic_train_iters is not None + args.train_iters = args.retro_cyclic_train_iters + log_rank_0("retro cyclic train iters : %d" % args.train_iters) + + iteration = 0 + if args.do_train and args.train_iters > 0: + iteration, num_floating_point_operations_so_far = self.train( + self.forward_step, + self.model, + self.optimizer, + self.opt_param_scheduler, + self.train_data_iterator, + self.valid_data_iterator, + process_non_loss_data_func, + self.config, + self.checkpointing_context, + non_loss_data_func, + ) - print_datetime("after training is done") + print_datetime("after training is done") - if ( - args.save - and iteration != 0 - and iteration % args.save_interval != 0 - and not args.disable_last_saving - ): - save_checkpoint( - iteration, - self.model, - self.optimizer, - self.opt_param_scheduler, - num_floating_point_operations_so_far, - self.checkpointing_context, - train_data_iterator=self.train_data_iterator, - preprocess_common_state_dict_fn=preprocess_common_state_dict, - ) + if ( + args.save + and iteration != 0 + and iteration % args.save_interval != 0 + and not args.disable_last_saving + ): + save_checkpoint( + iteration, + self.model, + self.optimizer, + self.opt_param_scheduler, + num_floating_point_operations_so_far, + self.checkpointing_context, + train_data_iterator=self.train_data_iterator, + preprocess_common_state_dict_fn=preprocess_common_state_dict, + ) - one_logger and one_logger.log_metrics( - {"app_train_loop_finish_time": one_logger_utils.get_timestamp_in_ms()} - ) + one_logger and 
one_logger.log_metrics( + {"app_train_loop_finish_time": one_logger_utils.get_timestamp_in_ms()} + ) - else: - log_rank_0("skipping training (--skip-train is on) ...") + else: + log_rank_0("skipping training (--skip-train is on) ...") - iteration = args.iteration + iteration = args.iteration - if args.do_valid: - prefix = f"iteration {iteration} on validation set" - evaluate_and_print_results( - prefix, - self.forward_step, - self.valid_data_iterator, - self.model, - iteration, - process_non_loss_data_func, - self.config, - verbose=True, - write_to_tensorboard=not args.skip_train, - non_loss_data_func=non_loss_data_func, - ) + if args.do_valid: + prefix = f"iteration {iteration} on validation set" + evaluate_and_print_results( + prefix, + self.forward_step, + self.valid_data_iterator, + self.model, + iteration, + process_non_loss_data_func, + self.config, + verbose=True, + write_to_tensorboard=not args.skip_train, + non_loss_data_func=non_loss_data_func, + ) - if args.do_test: - prefix = f"iteration {iteration} on test set" - evaluate_and_print_results( - prefix, - self.forward_step, - self.test_data_iterator, - self.model, - iteration, - process_non_loss_data_func, - self.config, - verbose=True, - write_to_tensorboard=not args.skip_train, - non_loss_data_func=non_loss_data_func, - ) + if args.do_test: + prefix = f"iteration {iteration} on test set" + evaluate_and_print_results( + prefix, + self.forward_step, + self.test_data_iterator, + self.model, + iteration, + process_non_loss_data_func, + self.config, + verbose=True, + write_to_tensorboard=not args.skip_train, + non_loss_data_func=non_loss_data_func, + ) - wandb_writer = get_wandb_writer() - if wandb_writer: - wandb_writer.finish() + wandb_writer = get_wandb_writer() + if wandb_writer: + wandb_writer.finish() - ft_integration.on_checkpointing_start() - maybe_finalize_async_save(blocking=True, terminate=True) - ft_integration.on_checkpointing_end(is_async_finalization=True) + 
ft_integration.on_checkpointing_start() + maybe_finalize_async_save(blocking=True, terminate=True) + ft_integration.on_checkpointing_end(is_async_finalization=True) - mlflow_writer = get_mlflow_writer() - if mlflow_writer: - mlflow_writer.end_run() + # End MLflow run with clean "FINISHED" status + if args.rank == (args.world_size - 1): + end_mlflow_run(status="FINISHED", termination_reason="clean_finish") - one_logger and one_logger.log_metrics({"app_finish_time": one_logger_utils.get_timestamp_in_ms()}) + one_logger and one_logger.log_metrics({"app_finish_time": one_logger_utils.get_timestamp_in_ms()}) - ft_integration.shutdown() - one_logger_utils.finish() + ft_integration.shutdown() + one_logger_utils.finish() + + # clean up torch pg resources on exit + if dist.is_initialized(): + dist.destroy_process_group() + + except KeyboardInterrupt: + # User interrupted the run (e.g., Ctrl+C). Mark as FAILED on MLflow host rank, + # with a specific termination_reason indicating it was user-killed. + if args.rank == (args.world_size - 1): + end_mlflow_run(status="FAILED", termination_reason="keyboard_interrupt") + raise + except Exception as e: + # Any other unhandled error: mark run as FAILED on the MLflow host rank. 
+ if args.rank == (args.world_size - 1): + end_mlflow_run(status="FAILED", termination_reason=reason) + raise - # clean up torch pg resources on exit - if dist.is_initialized(): - dist.destroy_process_group() def train( self, From dc876e4f99d68a02402a784621a08b19c5fba31b Mon Sep 17 00:00:00 2001 From: mvstrauss Date: Tue, 20 Jan 2026 07:27:55 +0000 Subject: [PATCH 2/3] fix(megatron): ensure mlflow run ends on exit --- primus/modules/trainer/megatron/trainer.py | 49 ++++++++++++++-------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index 012040b73..512dae50e 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -1476,6 +1476,9 @@ def run(self, *args, **kwargs): one_logger = get_one_logger() args = get_args() + mlflow_status = None + termination_reason = None + try: if args.pp_warmup: from .utils import pp_warmup @@ -1581,31 +1584,43 @@ def run(self, *args, **kwargs): maybe_finalize_async_save(blocking=True, terminate=True) ft_integration.on_checkpointing_end(is_async_finalization=True) - # End MLflow run with clean "FINISHED" status - if args.rank == (args.world_size - 1): - end_mlflow_run(status="FINISHED", termination_reason="clean_finish") + mlflow_status = "FINISHED" + termination_reason = "clean_finish" one_logger and one_logger.log_metrics({"app_finish_time": one_logger_utils.get_timestamp_in_ms()}) ft_integration.shutdown() one_logger_utils.finish() - # clean up torch pg resources on exit - if dist.is_initialized(): - dist.destroy_process_group() - except KeyboardInterrupt: - # User interrupted the run (e.g., Ctrl+C). Mark as FAILED on MLflow host rank, - # with a specific termination_reason indicating it was user-killed. 
-            if args.rank == (args.world_size - 1):
-                end_mlflow_run(status="FAILED", termination_reason="keyboard_interrupt")
+            mlflow_status = "KILLED"
+            termination_reason = "keyboard_interrupt"
+            raise
+        except SystemExit as e:
+            # sys.exit() raises SystemExit (not an Exception). Preserve behavior, but
+            # still mark MLflow run with a meaningful terminal status in `finally`.
+            # e.code may be None (success), an int, or any object (failure, exits 1).
+            exit_code = e.code if isinstance(e.code, int) else (0 if e.code is None else 1)
+            mlflow_status = "FINISHED" if exit_code == 0 else "FAILED"
+            termination_reason = f"system_exit_{exit_code}"
             raise
         except Exception as e:
-            # Any other unhandled error: mark run as FAILED on the MLflow host rank.
-            if args.rank == (args.world_size - 1):
-                end_mlflow_run(status="FAILED", termination_reason=reason)
+            mlflow_status = "FAILED"
+            termination_reason = type(e).__name__ or "unknown_exception"
             raise
+        finally:
+            # Best-effort finalization. Never mask the original exception.
+            try:
+                if args.rank == (args.world_size - 1) and mlflow_status is not None:
+                    end_mlflow_run(status=mlflow_status, termination_reason=termination_reason)
+            except Exception:
+                pass
+
+            try:
+                # clean up torch pg resources on exit
+                if dist.is_initialized():
+                    dist.destroy_process_group()
+            except Exception:
+                pass
 
     def train(
         self,
@@ -2038,9 +2053,9 @@ def get_e2e_base_metrics():
             wandb_writer = get_wandb_writer()
             if wandb_writer:
                 wandb_writer.finish()
-            mlflow_writer = get_mlflow_writer()
-            if mlflow_writer:
-                mlflow_writer.end_run()
+            # Mark run as finished if exit code is 0; otherwise failed. 
+            mlflow_status = "FINISHED" if exit_code == 0 else "FAILED"
+            end_mlflow_run(status=mlflow_status, termination_reason=f"exit_code_{exit_code}")
             ft_integration.shutdown()
             sys.exit(exit_code)

From 06045f7d367a2090983685c99070658151f71a88 Mon Sep 17 00:00:00 2001
From: mvstrauss
Date: Tue, 20 Jan 2026 07:50:30 +0000
Subject: [PATCH 3/3] chore(megatron): document best-effort teardown

---
 primus/modules/trainer/megatron/trainer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py
index 98fede09e..75e55aae8 100644
--- a/primus/modules/trainer/megatron/trainer.py
+++ b/primus/modules/trainer/megatron/trainer.py
@@ -1153,6 +1153,7 @@ def run(self, *args, **kwargs):
                 if args.rank == (args.world_size - 1) and mlflow_status is not None:
                     end_mlflow_run(status=mlflow_status, termination_reason=termination_reason)
             except Exception:
+                # Best-effort: MLflow finalization must never mask the original training error.
                 pass
 
             try:
@@ -1160,6 +1161,7 @@ def run(self, *args, **kwargs):
                 if dist.is_initialized():
                     dist.destroy_process_group()
             except Exception:
+                # Best-effort cleanup: ignore teardown errors to avoid masking the original failure.
                 pass
 
     def train(