primus/backends/megatron/training/global_vars.py (30 additions, 0 deletions)
@@ -85,6 +85,36 @@ def _set_mlflow_writer(args):
_GLOBAL_MLFLOW_WRITER = mlflow


def end_mlflow_run(status="FINISHED", termination_reason=None):
"""End MLflow run with specified status.

Args:
status: MLflow run status - "FINISHED", "FAILED", or "KILLED".
termination_reason: Optional free-form string tag recorded as
"termination_reason" on the run.
"""
global _GLOBAL_MLFLOW_WRITER

if _GLOBAL_MLFLOW_WRITER is None:
return

try:
# Optionally attach a coarse termination reason tag for debugging.
if termination_reason is not None:
try:
_GLOBAL_MLFLOW_WRITER.set_tag("termination_reason", termination_reason)
except Exception:
# Ignore tagging failures; status update is more important.
pass

_GLOBAL_MLFLOW_WRITER.end_run(status=status)
except Exception:
# Swallow MLflow/network errors to avoid masking the original failure.
pass
finally:
_GLOBAL_MLFLOW_WRITER = None


def unset_global_variables():
"""Unset global vars."""

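For readers skimming the new helper above: a minimal usage sketch, assuming the MLflow writer has already been initialized on this rank (e.g. via `_set_mlflow_writer`); `run_training()` is a hypothetical stand-in for the caller's training entry point.

```python
from primus.backends.megatron.training.global_vars import end_mlflow_run


def run_training():  # hypothetical training entry point, for illustration only
    ...


try:
    run_training()
except Exception as exc:
    # Record a terminal status plus a coarse reason tag, then re-raise the original error.
    end_mlflow_run(status="FAILED", termination_reason=type(exc).__name__)
    raise
else:
    end_mlflow_run(status="FINISHED", termination_reason="clean_finish")
```

Because the helper clears `_GLOBAL_MLFLOW_WRITER` and swallows MLflow errors, calling it twice (or with no active writer) is a no-op rather than a crash.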
primus/modules/trainer/megatron/trainer.py (136 additions, 104 deletions)
@@ -144,6 +144,7 @@
from primus.backends.megatron.core.transformer.moe.moe_utils import track_moe_metrics
from primus.backends.megatron.model_provider import primus_model_provider
from primus.backends.megatron.training.global_vars import (
end_mlflow_run,
get_mlflow_writer,
get_train_start_time,
set_primus_global_variables,
@@ -1015,122 +1016,153 @@ def run(self, *args, **kwargs):
one_logger = get_one_logger()
args = get_args()

if args.pp_warmup:
from .utils import pp_warmup
mlflow_status = None
termination_reason = None

log_rank_0(
"warmup on each rank in parallel to decrease "
"the first iter time, especially when pp degree is large"
)
timers = get_timers()
timers("pp-warmup", log_level=0).start(barrier=True)
pp_warmup(args, self.config, self.model, self.optimizer)
timers("pp-warmup").stop()
timers.log(["pp-warmup"], barrier=True)

process_non_loss_data_func = None
non_loss_data_func = None
if not args.skip_train:
log_rank_0("training ...")

if args.dataloader_type == "cyclic" and args.retro_project_dir:
assert args.retro_cyclic_train_iters is not None
args.train_iters = args.retro_cyclic_train_iters
log_rank_0("retro cyclic train iters : %d" % args.train_iters)

iteration = 0
if args.do_train and args.train_iters > 0:
iteration, num_floating_point_operations_so_far = self.train(
self.forward_step,
self.model,
self.optimizer,
self.opt_param_scheduler,
self.train_data_iterator,
self.valid_data_iterator,
process_non_loss_data_func,
self.config,
self.checkpointing_context,
non_loss_data_func,
try:
if args.pp_warmup:
from .utils import pp_warmup

log_rank_0(
"warmup on each rank in parallel to decrease "
"the first iter time, especially when pp degree is large"
)
timers = get_timers()
timers("pp-warmup", log_level=0).start(barrier=True)
pp_warmup(args, self.config, self.model, self.optimizer)
timers("pp-warmup").stop()
timers.log(["pp-warmup"], barrier=True)

process_non_loss_data_func = None
non_loss_data_func = None
if not args.skip_train:
log_rank_0("training ...")

if args.dataloader_type == "cyclic" and args.retro_project_dir:
assert args.retro_cyclic_train_iters is not None
args.train_iters = args.retro_cyclic_train_iters
log_rank_0("retro cyclic train iters : %d" % args.train_iters)

iteration = 0
if args.do_train and args.train_iters > 0:
iteration, num_floating_point_operations_so_far = self.train(
self.forward_step,
self.model,
self.optimizer,
self.opt_param_scheduler,
self.train_data_iterator,
self.valid_data_iterator,
process_non_loss_data_func,
self.config,
self.checkpointing_context,
non_loss_data_func,
)

print_datetime("after training is done")
print_datetime("after training is done")

if (
args.save
and iteration != 0
and iteration % args.save_interval != 0
and not args.disable_last_saving
):
save_checkpoint(
iteration,
self.model,
self.optimizer,
self.opt_param_scheduler,
num_floating_point_operations_so_far,
self.checkpointing_context,
train_data_iterator=self.train_data_iterator,
preprocess_common_state_dict_fn=preprocess_common_state_dict,
)
if (
args.save
and iteration != 0
and iteration % args.save_interval != 0
and not args.disable_last_saving
):
save_checkpoint(
iteration,
self.model,
self.optimizer,
self.opt_param_scheduler,
num_floating_point_operations_so_far,
self.checkpointing_context,
train_data_iterator=self.train_data_iterator,
preprocess_common_state_dict_fn=preprocess_common_state_dict,
)

one_logger and one_logger.log_metrics(
{"app_train_loop_finish_time": one_logger_utils.get_timestamp_in_ms()}
)
one_logger and one_logger.log_metrics(
{"app_train_loop_finish_time": one_logger_utils.get_timestamp_in_ms()}
)

else:
log_rank_0("skipping training (--skip-train is on) ...")
else:
log_rank_0("skipping training (--skip-train is on) ...")

iteration = args.iteration
iteration = args.iteration

if args.do_valid:
prefix = f"iteration {iteration} on validation set"
evaluate_and_print_results(
prefix,
self.forward_step,
self.valid_data_iterator,
self.model,
iteration,
process_non_loss_data_func,
self.config,
verbose=True,
write_to_tensorboard=not args.skip_train,
non_loss_data_func=non_loss_data_func,
)
if args.do_valid:
prefix = f"iteration {iteration} on validation set"
evaluate_and_print_results(
prefix,
self.forward_step,
self.valid_data_iterator,
self.model,
iteration,
process_non_loss_data_func,
self.config,
verbose=True,
write_to_tensorboard=not args.skip_train,
non_loss_data_func=non_loss_data_func,
)

if args.do_test:
prefix = f"iteration {iteration} on test set"
evaluate_and_print_results(
prefix,
self.forward_step,
self.test_data_iterator,
self.model,
iteration,
process_non_loss_data_func,
self.config,
verbose=True,
write_to_tensorboard=not args.skip_train,
non_loss_data_func=non_loss_data_func,
)
if args.do_test:
prefix = f"iteration {iteration} on test set"
evaluate_and_print_results(
prefix,
self.forward_step,
self.test_data_iterator,
self.model,
iteration,
process_non_loss_data_func,
self.config,
verbose=True,
write_to_tensorboard=not args.skip_train,
non_loss_data_func=non_loss_data_func,
)

wandb_writer = get_wandb_writer()
if wandb_writer:
wandb_writer.finish()
wandb_writer = get_wandb_writer()
if wandb_writer:
wandb_writer.finish()

ft_integration.on_checkpointing_start()
maybe_finalize_async_save(blocking=True, terminate=True)
ft_integration.on_checkpointing_end(is_async_finalization=True)
ft_integration.on_checkpointing_start()
maybe_finalize_async_save(blocking=True, terminate=True)
ft_integration.on_checkpointing_end(is_async_finalization=True)

mlflow_writer = get_mlflow_writer()
if mlflow_writer:
mlflow_writer.end_run()
mlflow_status = "FINISHED"
termination_reason = "clean_finish"

one_logger and one_logger.log_metrics({"app_finish_time": one_logger_utils.get_timestamp_in_ms()})
one_logger and one_logger.log_metrics({"app_finish_time": one_logger_utils.get_timestamp_in_ms()})

ft_integration.shutdown()
one_logger_utils.finish()
ft_integration.shutdown()
one_logger_utils.finish()

except KeyboardInterrupt:
mlflow_status = "KILLED"
termination_reason = "keyboard_interrupt"
raise
except SystemExit as e:
            # sys.exit() raises SystemExit (not an Exception). Preserve behavior, but
            # still mark the MLflow run with a meaningful terminal status in `finally`.
            # sys.exit(None) exits with status 0; a non-int code (e.g. a message string) exits with status 1.
            exit_code = e.code if isinstance(e.code, int) else (0 if e.code is None else 1)
            mlflow_status = "FINISHED" if exit_code == 0 else "FAILED"
            termination_reason = f"system_exit_{exit_code}"
raise
except Exception as e:
mlflow_status = "FAILED"
termination_reason = type(e).__name__ or "unknown_exception"
raise
finally:
# Best-effort finalization. Never mask the original exception.
try:
if args.rank == (args.world_size - 1) and mlflow_status is not None:
end_mlflow_run(status=mlflow_status, termination_reason=termination_reason)
except Exception:
# Best-effort: MLflow finalization must never mask the original training error.
pass

# clean up torch pg resources on exit
if dist.is_initialized():
dist.destroy_process_group()
try:
# clean up torch pg resources on exit
if dist.is_initialized():
dist.destroy_process_group()
except Exception:
# Best-effort cleanup: ignore teardown errors to avoid masking the original failure.
pass
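The exception wiring above reduces to a small status/reason mapping. Below is a self-contained sketch of that mapping (no MLflow or torch needed), using the same statuses as the diff; the handling of a non-int `SystemExit` code is this sketch's assumption.

```python
def classify_exit(exc: BaseException) -> tuple[str, str]:
    # Mirrors the mapping used in run(): KeyboardInterrupt -> KILLED,
    # SystemExit -> FINISHED/FAILED depending on the exit code, anything else -> FAILED.
    if isinstance(exc, KeyboardInterrupt):
        return "KILLED", "keyboard_interrupt"
    if isinstance(exc, SystemExit):
        code = exc.code if isinstance(exc.code, int) else (0 if exc.code is None else 1)
        return ("FINISHED" if code == 0 else "FAILED"), f"system_exit_{code}"
    return "FAILED", type(exc).__name__


assert classify_exit(SystemExit(0)) == ("FINISHED", "system_exit_0")
assert classify_exit(KeyboardInterrupt()) == ("KILLED", "keyboard_interrupt")
assert classify_exit(RuntimeError("oom")) == ("FAILED", "RuntimeError")
```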

def train(
self,
@@ -1563,9 +1595,9 @@ def get_e2e_base_metrics():
wandb_writer = get_wandb_writer()
if wandb_writer:
wandb_writer.finish()
mlflow_writer = get_mlflow_writer()
if mlflow_writer:
mlflow_writer.end_run()
                # Mark the run as finished if the exit code is 0; otherwise mark it as failed.
                mlflow_status = "FINISHED" if exit_code == 0 else "FAILED"
                end_mlflow_run(status=mlflow_status, termination_reason="exit_condition")
ft_integration.shutdown()
sys.exit(exit_code)

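Not part of this diff, but handy for verifying the new behavior: the status and tag written by `end_mlflow_run` can be read back with the standard MLflow client. The run ID below is a placeholder.

```python
import mlflow

run = mlflow.get_run("<RUN_ID>")  # placeholder: the run logged by the training job
print(run.info.status)                           # e.g. "FINISHED", "FAILED", or "KILLED"
print(run.data.tags.get("termination_reason"))   # e.g. "clean_finish" or "keyboard_interrupt"
```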