diff --git a/primus/backends/megatron/training/global_vars.py b/primus/backends/megatron/training/global_vars.py
index 0d37ca79b..cadad0683 100644
--- a/primus/backends/megatron/training/global_vars.py
+++ b/primus/backends/megatron/training/global_vars.py
@@ -85,6 +85,36 @@ def _set_mlflow_writer(args):
     _GLOBAL_MLFLOW_WRITER = mlflow
 
 
+def end_mlflow_run(status="FINISHED", termination_reason=None):
+    """End MLflow run with specified status.
+
+    Args:
+        status: MLflow run status - "FINISHED", "FAILED", or "KILLED".
+        termination_reason: Optional free-form string tag recorded as
+            "termination_reason" on the run.
+    """
+    global _GLOBAL_MLFLOW_WRITER
+
+    if _GLOBAL_MLFLOW_WRITER is None:
+        return
+
+    try:
+        # Optionally attach a coarse termination reason tag for debugging.
+        if termination_reason is not None:
+            try:
+                _GLOBAL_MLFLOW_WRITER.set_tag("termination_reason", termination_reason)
+            except Exception:
+                # Ignore tagging failures; status update is more important.
+                pass
+
+        _GLOBAL_MLFLOW_WRITER.end_run(status=status)
+    except Exception:
+        # Swallow MLflow/network errors to avoid masking the original failure.
+        pass
+    finally:
+        _GLOBAL_MLFLOW_WRITER = None
+
+
 def unset_global_variables():
     """Unset global vars."""
diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py
index fe70aa729..d6d399346 100644
--- a/primus/modules/trainer/megatron/trainer.py
+++ b/primus/modules/trainer/megatron/trainer.py
@@ -144,6 +144,7 @@
 from primus.backends.megatron.core.transformer.moe.moe_utils import track_moe_metrics
 from primus.backends.megatron.model_provider import primus_model_provider
 from primus.backends.megatron.training.global_vars import (
+    end_mlflow_run,
     get_mlflow_writer,
     get_train_start_time,
     set_primus_global_variables,
@@ -1015,122 +1016,153 @@ def run(self, *args, **kwargs):
         one_logger = get_one_logger()
         args = get_args()
 
-        if args.pp_warmup:
-            from .utils import pp_warmup
+        mlflow_status = None
+        termination_reason = None
 
-            log_rank_0(
-                "warmup on each rank in parallel to decrease "
-                "the first iter time, especially when pp degree is large"
-            )
-            timers = get_timers()
-            timers("pp-warmup", log_level=0).start(barrier=True)
-            pp_warmup(args, self.config, self.model, self.optimizer)
-            timers("pp-warmup").stop()
-            timers.log(["pp-warmup"], barrier=True)
-
-        process_non_loss_data_func = None
-        non_loss_data_func = None
-        if not args.skip_train:
-            log_rank_0("training ...")
-
-            if args.dataloader_type == "cyclic" and args.retro_project_dir:
-                assert args.retro_cyclic_train_iters is not None
-                args.train_iters = args.retro_cyclic_train_iters
-                log_rank_0("retro cyclic train iters : %d" % args.train_iters)
-
-            iteration = 0
-            if args.do_train and args.train_iters > 0:
-                iteration, num_floating_point_operations_so_far = self.train(
-                    self.forward_step,
-                    self.model,
-                    self.optimizer,
-                    self.opt_param_scheduler,
-                    self.train_data_iterator,
-                    self.valid_data_iterator,
-                    process_non_loss_data_func,
-                    self.config,
-                    self.checkpointing_context,
-                    non_loss_data_func,
+        try:
+            if args.pp_warmup:
+                from .utils import pp_warmup
+
+                log_rank_0(
+                    "warmup on each rank in parallel to decrease "
+                    "the first iter time, especially when pp degree is large"
                 )
+                timers = get_timers()
+                timers("pp-warmup", log_level=0).start(barrier=True)
+                pp_warmup(args, self.config, self.model, self.optimizer)
+                timers("pp-warmup").stop()
+                timers.log(["pp-warmup"], barrier=True)
+
+            process_non_loss_data_func = None
+            non_loss_data_func = None
+            if not args.skip_train:
+                log_rank_0("training ...")
log_rank_0("training ...") + + if args.dataloader_type == "cyclic" and args.retro_project_dir: + assert args.retro_cyclic_train_iters is not None + args.train_iters = args.retro_cyclic_train_iters + log_rank_0("retro cyclic train iters : %d" % args.train_iters) + + iteration = 0 + if args.do_train and args.train_iters > 0: + iteration, num_floating_point_operations_so_far = self.train( + self.forward_step, + self.model, + self.optimizer, + self.opt_param_scheduler, + self.train_data_iterator, + self.valid_data_iterator, + process_non_loss_data_func, + self.config, + self.checkpointing_context, + non_loss_data_func, + ) - print_datetime("after training is done") + print_datetime("after training is done") - if ( - args.save - and iteration != 0 - and iteration % args.save_interval != 0 - and not args.disable_last_saving - ): - save_checkpoint( - iteration, - self.model, - self.optimizer, - self.opt_param_scheduler, - num_floating_point_operations_so_far, - self.checkpointing_context, - train_data_iterator=self.train_data_iterator, - preprocess_common_state_dict_fn=preprocess_common_state_dict, - ) + if ( + args.save + and iteration != 0 + and iteration % args.save_interval != 0 + and not args.disable_last_saving + ): + save_checkpoint( + iteration, + self.model, + self.optimizer, + self.opt_param_scheduler, + num_floating_point_operations_so_far, + self.checkpointing_context, + train_data_iterator=self.train_data_iterator, + preprocess_common_state_dict_fn=preprocess_common_state_dict, + ) - one_logger and one_logger.log_metrics( - {"app_train_loop_finish_time": one_logger_utils.get_timestamp_in_ms()} - ) + one_logger and one_logger.log_metrics( + {"app_train_loop_finish_time": one_logger_utils.get_timestamp_in_ms()} + ) - else: - log_rank_0("skipping training (--skip-train is on) ...") + else: + log_rank_0("skipping training (--skip-train is on) ...") - iteration = args.iteration + iteration = args.iteration - if args.do_valid: - prefix = f"iteration {iteration} on validation set" - evaluate_and_print_results( - prefix, - self.forward_step, - self.valid_data_iterator, - self.model, - iteration, - process_non_loss_data_func, - self.config, - verbose=True, - write_to_tensorboard=not args.skip_train, - non_loss_data_func=non_loss_data_func, - ) + if args.do_valid: + prefix = f"iteration {iteration} on validation set" + evaluate_and_print_results( + prefix, + self.forward_step, + self.valid_data_iterator, + self.model, + iteration, + process_non_loss_data_func, + self.config, + verbose=True, + write_to_tensorboard=not args.skip_train, + non_loss_data_func=non_loss_data_func, + ) - if args.do_test: - prefix = f"iteration {iteration} on test set" - evaluate_and_print_results( - prefix, - self.forward_step, - self.test_data_iterator, - self.model, - iteration, - process_non_loss_data_func, - self.config, - verbose=True, - write_to_tensorboard=not args.skip_train, - non_loss_data_func=non_loss_data_func, - ) + if args.do_test: + prefix = f"iteration {iteration} on test set" + evaluate_and_print_results( + prefix, + self.forward_step, + self.test_data_iterator, + self.model, + iteration, + process_non_loss_data_func, + self.config, + verbose=True, + write_to_tensorboard=not args.skip_train, + non_loss_data_func=non_loss_data_func, + ) - wandb_writer = get_wandb_writer() - if wandb_writer: - wandb_writer.finish() + wandb_writer = get_wandb_writer() + if wandb_writer: + wandb_writer.finish() - ft_integration.on_checkpointing_start() - maybe_finalize_async_save(blocking=True, terminate=True) - 
+            ft_integration.on_checkpointing_start()
+            maybe_finalize_async_save(blocking=True, terminate=True)
+            ft_integration.on_checkpointing_end(is_async_finalization=True)
 
-        mlflow_writer = get_mlflow_writer()
-        if mlflow_writer:
-            mlflow_writer.end_run()
+            mlflow_status = "FINISHED"
+            termination_reason = "clean_finish"
 
-        one_logger and one_logger.log_metrics({"app_finish_time": one_logger_utils.get_timestamp_in_ms()})
+            one_logger and one_logger.log_metrics({"app_finish_time": one_logger_utils.get_timestamp_in_ms()})
 
-        ft_integration.shutdown()
-        one_logger_utils.finish()
+            ft_integration.shutdown()
+            one_logger_utils.finish()
+
+        except KeyboardInterrupt:
+            mlflow_status = "KILLED"
+            termination_reason = "keyboard_interrupt"
+            raise
+        except SystemExit as e:
+            # sys.exit() raises SystemExit (not an Exception); re-raise to preserve
+            # behavior. Per CPython, code=None exits 0 and a non-int code exits 1.
+            exit_code = e.code if isinstance(e.code, int) else (0 if e.code is None else 1)
+            mlflow_status = "FINISHED" if exit_code == 0 else "FAILED"
+            termination_reason = f"system_exit_{exit_code}"
+            raise
+        except Exception as e:
+            mlflow_status = "FAILED"
+            termination_reason = type(e).__name__
+            raise
+        finally:
+            # Best-effort finalization. Never mask the original exception.
+            try:
+                if args.rank == (args.world_size - 1) and mlflow_status is not None:
+                    end_mlflow_run(status=mlflow_status, termination_reason=termination_reason)
+            except Exception:
+                # end_mlflow_run swallows MLflow errors itself; this guards the surrounding code.
+                pass
 
-        # clean up torch pg resources on exit
-        if dist.is_initialized():
-            dist.destroy_process_group()
+        try:
+            # clean up torch pg resources on exit
+            if dist.is_initialized():
+                dist.destroy_process_group()
+        except Exception:
+            # Best-effort cleanup: ignore teardown errors to avoid masking the original failure.
+            pass
 
     def train(
         self,
@@ -1563,9 +1595,9 @@ def get_e2e_base_metrics():
         wandb_writer = get_wandb_writer()
         if wandb_writer:
             wandb_writer.finish()
-        mlflow_writer = get_mlflow_writer()
-        if mlflow_writer:
-            mlflow_writer.end_run()
+        # Mark run as finished if exit code is 0; otherwise failed.
+        mlflow_status = "FINISHED" if exit_code == 0 else "FAILED"
+        end_mlflow_run(status=mlflow_status, termination_reason="exit_condition")
 
         ft_integration.shutdown()
         sys.exit(exit_code)
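Reviewer note: below is a minimal, self-contained sketch (not part of the patch) of the contract end_mlflow_run is expected to satisfy. The _StubMlflow class is hypothetical and stands in for the mlflow module normally stored in _GLOBAL_MLFLOW_WRITER; it assumes the global_vars module imports standalone. The assertions cover the three behaviors the trainer now relies on: tagging is best-effort, the run is ended at most once, and the global is always cleared.

# Hypothetical stand-in for the mlflow module; not a real MLflow client.
class _StubMlflow:
    def __init__(self, fail_tagging=False):
        self.fail_tagging = fail_tagging
        self.tags = {}
        self.ended_with = None

    def set_tag(self, key, value):
        if self.fail_tagging:
            raise RuntimeError("tracking server unreachable")
        self.tags[key] = value

    def end_run(self, status):
        self.ended_with = status


import primus.backends.megatron.training.global_vars as gv

# Failure path: tag recorded, run ended with FAILED, global cleared.
writer = _StubMlflow()
gv._GLOBAL_MLFLOW_WRITER = writer
gv.end_mlflow_run(status="FAILED", termination_reason="ValueError")
assert writer.ended_with == "FAILED"
assert writer.tags["termination_reason"] == "ValueError"
assert gv._GLOBAL_MLFLOW_WRITER is None

# A second call is a no-op, so the finally block cannot double-end a run.
gv.end_mlflow_run(status="KILLED")
assert writer.ended_with == "FAILED"

# Tagging failures are swallowed; the status update still goes through.
writer = _StubMlflow(fail_tagging=True)
gv._GLOBAL_MLFLOW_WRITER = writer
gv.end_mlflow_run(status="KILLED", termination_reason="keyboard_interrupt")
assert writer.ended_with == "KILLED"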