From 80e2d2672cc609bb592592df059a722538a7cd32 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Wed, 10 Dec 2025 00:34:53 +0000 Subject: [PATCH 01/23] initial commit --- .../configs/MI300X/mamba_370M-pretrain.yaml | 94 +++++++++++++++++++ .../MI300X/mamba_hybrid_2.8B-pretrain.yaml | 84 +++++++++++++++++ .../models/megatron/language_model.yaml | 1 + .../configs/models/megatron/mamba_1.4B.yaml | 15 +++ .../configs/models/megatron/mamba_370M.yaml | 15 +++ .../configs/models/megatron/mamba_base.yaml | 41 ++++++++ .../models/megatron/mamba_hybrid_2.8B.yaml | 29 ++++++ primus/core/utils/import_utils.py | 41 +++++--- .../trainer/lightmegatron/pre_trainer.py | 41 ++++++-- primus/modules/trainer/megatron/trainer.py | 8 +- 10 files changed, 344 insertions(+), 25 deletions(-) create mode 100644 examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml create mode 100644 examples/megatron/configs/MI300X/mamba_hybrid_2.8B-pretrain.yaml create mode 100644 primus/configs/models/megatron/mamba_1.4B.yaml create mode 100644 primus/configs/models/megatron/mamba_370M.yaml create mode 100644 primus/configs/models/megatron/mamba_base.yaml create mode 100644 primus/configs/models/megatron/mamba_hybrid_2.8B.yaml diff --git a/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml new file mode 100644 index 000000000..f1a60cde0 --- /dev/null +++ b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml @@ -0,0 +1,94 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:mamba_370M-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: mamba_370M.yaml + overrides: + # log + wandb_project: "Primus_Mamba_Pretrain" + # disable_wandb: false + # disable_tensorboard: false + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 50 + micro_batch_size: 4 + global_batch_size: 256 + + seq_length: 2048 + max_position_embeddings: 2048 + + lr: 3.0e-4 + min_lr: 0.0 + lr_warmup_iters: 100 + lr_decay_iters: null + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.02 + norm_epsilon: 1.0e-5 + + # Mamba-specific: must provide spec + spec: ['megatron.core.models.mamba.mamba_layer_specs', 'mamba_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-1B + + # Mamba SSM parameters + is_hybrid_model: false + hybrid_attention_ratio: 0.0 + hybrid_mlp_ratio: 0.0 + mamba_state_dim: 16 + mamba_head_dim: 64 + mamba_num_groups: 8 + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + no_load_optim: null + no_load_rng: null + save: null + save_interval: 20000 + no_save_optim: null + no_save_rng: null + disable_last_saving: true + ckpt_format: torch + + # Turbo - may need to disable for Mamba if not supported + enable_primus_turbo: false + use_turbo_attention: false + use_turbo_grouped_mlp: false + + # Cross entropy flags + # cross_entropy_fusion_impl: "native" + # cross_entropy_loss_fusion: false + diff --git a/examples/megatron/configs/MI300X/mamba_hybrid_2.8B-pretrain.yaml b/examples/megatron/configs/MI300X/mamba_hybrid_2.8B-pretrain.yaml new file mode 100644 index 000000000..58d800401 --- /dev/null +++ b/examples/megatron/configs/MI300X/mamba_hybrid_2.8B-pretrain.yaml @@ -0,0 +1,84 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:mamba_hybrid_2.8B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: mamba_hybrid_2.8B.yaml + overrides: + # log + wandb_project: "Primus_Mamba_Hybrid_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 2 + global_batch_size: 128 + + seq_length: 4096 + max_position_embeddings: 4096 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.02 + norm_epsilon: 1.0e-5 + + # Mamba-specific: must provide spec + spec: ['megatron.core.models.mamba.mamba_layer_specs', 'mamba_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-1B + + # Hybrid Mamba+Attention parameters + is_hybrid_model: true + hybrid_attention_ratio: 0.125 + hybrid_mlp_ratio: 0.0 + mamba_state_dim: 16 + mamba_head_dim: 64 + mamba_num_groups: 8 + + # parallel + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: true + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch + + # Turbo - disable for Mamba layers, but attention layers may benefit + enable_primus_turbo: false + use_turbo_attention: false + use_turbo_grouped_mlp: false + diff --git a/primus/configs/models/megatron/language_model.yaml b/primus/configs/models/megatron/language_model.yaml index 8c6f24e48..3de9cbff5 100755 --- a/primus/configs/models/megatron/language_model.yaml +++ b/primus/configs/models/megatron/language_model.yaml @@ -8,6 +8,7 @@ extends: # model architecture use_legacy_models: false deprecated_use_mcore_models: false +model_type: gpt # gpt or mamba num_layers: 24 encoder_num_layers: null decoder_num_layers: null diff --git a/primus/configs/models/megatron/mamba_1.4B.yaml b/primus/configs/models/megatron/mamba_1.4B.yaml new file mode 100644 index 000000000..91e255443 --- /dev/null +++ b/primus/configs/models/megatron/mamba_1.4B.yaml @@ -0,0 +1,15 @@ +bases: + - mamba_base.yaml + +# Mamba 1.4B configuration + +tokenizer_type: GPT2BPETokenizer +vocab_size: 50257 + +# Model size parameters +num_layers: 48 +hidden_size: 2048 +ffn_hidden_size: null + +max_position_embeddings: 2048 + diff --git a/primus/configs/models/megatron/mamba_370M.yaml b/primus/configs/models/megatron/mamba_370M.yaml new file mode 100644 index 000000000..b10523b2e --- /dev/null +++ b/primus/configs/models/megatron/mamba_370M.yaml @@ -0,0 +1,15 @@ +bases: + - mamba_base.yaml + +# Mamba 370M configuration + +tokenizer_type: GPT2BPETokenizer +vocab_size: 50257 + +# Model size parameters +num_layers: 48 +hidden_size: 1024 +ffn_hidden_size: null + +max_position_embeddings: 2048 + diff --git a/primus/configs/models/megatron/mamba_base.yaml b/primus/configs/models/megatron/mamba_base.yaml new file mode 100644 index 000000000..f658c8371 --- /dev/null +++ b/primus/configs/models/megatron/mamba_base.yaml @@ -0,0 +1,41 @@ +bases: + - language_model.yaml + +# Mamba-specific configuration +# Note: Mamba-specific parameters (spec, is_hybrid_model, mamba_state_dim, etc.) +# must be set in the pretrain config overrides, not here + +model_type: mamba +use_legacy_models: false + +# Position embeddings - Mamba typically doesn't use position embeddings +position_embedding_type: rope +use_rotary_position_embeddings: false + +# Tokenizer (should be set in specific model configs) +tokenizer_type: HuggingFaceTokenizer +tokenizer_model: null + +# Model architecture +num_layers: 24 +hidden_size: 1024 +ffn_hidden_size: null + +# Standard transformer settings that may be used by hybrid models +num_attention_heads: 16 +attention_dropout: 0.0 +hidden_dropout: 0.0 + +# Embeddings +untie_embeddings_and_output_weights: true + +# Other settings +apply_residual_connection_post_layernorm: false +add_bias_linear: false +swiglu: false + +# Normalization +norm_epsilon: 1.0e-5 + +# Initialization +init_method_std: 0.02 diff --git a/primus/configs/models/megatron/mamba_hybrid_2.8B.yaml b/primus/configs/models/megatron/mamba_hybrid_2.8B.yaml new file mode 100644 index 000000000..f2fd20cba --- /dev/null +++ b/primus/configs/models/megatron/mamba_hybrid_2.8B.yaml @@ -0,0 +1,29 @@ +bases: + - mamba_base.yaml + +# Mamba 2.8B configuration with hybrid attention layers + +tokenizer_type: GPT2BPETokenizer +vocab_size: 50257 + +# Model size parameters +num_layers: 64 +hidden_size: 2560 +ffn_hidden_size: 6827 # ~2.67x hidden_size + +# Attention parameters (for hybrid layers) +num_attention_heads: 32 +group_query_attention: true +num_query_groups: 8 + +# Hybrid configuration: override mamba_base defaults +hybrid_attention_ratio: 0.125 +is_hybrid_model: true + +# For hybrid models, position embeddings may be useful +position_embedding_type: rope +rotary_base: 10000 +rotary_percent: 1.0 + +max_position_embeddings: 4096 + diff --git a/primus/core/utils/import_utils.py b/primus/core/utils/import_utils.py index 2ccd8ebed..6e67de8a4 100644 --- a/primus/core/utils/import_utils.py +++ b/primus/core/utils/import_utils.py @@ -34,25 +34,40 @@ def lazy_import(paths, symbol, log_prefix="[Primus]"): raise ImportError(f"{log_prefix} {symbol} not found in any of: {paths}") -def get_model_provider(): +def get_model_provider(model_type="gpt"): """ - Resolve model_provider across Megatron versions. + Resolve model_provider across Megatron versions and model types. - - New: model_provider + gpt_builder + Args: + model_type (str): Type of model - 'gpt' or 'mamba'. Defaults to 'gpt'. + + - New: model_provider + gpt_builder/mamba_builder - Mid: model_provider only - - Old: pretrain_gpt.model_provider + - Old: pretrain_gpt.model_provider / pretrain_mamba.model_provider """ # Try to import model_provider - model_provider = lazy_import( - ["model_provider", "pretrain_gpt"], "model_provider", log_prefix="[Primus][MegatronCompat]" - ) + if model_type == "mamba": + model_provider = lazy_import( + ["model_provider", "pretrain_mamba"], "model_provider", log_prefix="[Primus][MegatronCompat]" + ) + # Try to import mamba_builder (for Mamba models) + try: + mamba_builder = lazy_import(["mamba_builders"], "mamba_builder", log_prefix="[Primus][MegatronCompat]") + return partial(model_provider, mamba_builder) + except ImportError: + return model_provider + else: + # Default GPT behavior + model_provider = lazy_import( + ["model_provider", "pretrain_gpt"], "model_provider", log_prefix="[Primus][MegatronCompat]" + ) - # Try to import gpt_builder (only exists in newer versions) - try: - gpt_builder = lazy_import(["gpt_builders"], "gpt_builder", log_prefix="[Primus][MegatronCompat]") - return partial(model_provider, gpt_builder) - except ImportError: - return model_provider + # Try to import gpt_builder (only exists in newer versions) + try: + gpt_builder = lazy_import(["gpt_builders"], "gpt_builder", log_prefix="[Primus][MegatronCompat]") + return partial(model_provider, gpt_builder) + except ImportError: + return model_provider def get_custom_fsdp(): diff --git a/primus/modules/trainer/lightmegatron/pre_trainer.py b/primus/modules/trainer/lightmegatron/pre_trainer.py index 0a12460fa..973421933 100644 --- a/primus/modules/trainer/lightmegatron/pre_trainer.py +++ b/primus/modules/trainer/lightmegatron/pre_trainer.py @@ -38,15 +38,36 @@ def run(self, *args, **kwargs): from megatron.core.enums import ModelType from megatron.training import inprocess_restart, pretrain - from pretrain_gpt import forward_step, train_valid_test_datasets_provider + from megatron.training import get_args - train_valid_test_datasets_provider.is_distributed = True - wrapped_pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) + # Determine model type from config (gpt or mamba) + megatron_args = get_args() + model_type = getattr(megatron_args, 'model_type', 'gpt') + log_rank_0(f"Detected model_type: {model_type}") - wrapped_pretrain( - train_valid_test_datasets_provider, - get_model_provider(), - ModelType.encoder_or_decoder, - forward_step, - store=store, - ) + if model_type == 'mamba': + # Import from pretrain_mamba + from pretrain_mamba import forward_step, train_valid_test_datasets_provider + train_valid_test_datasets_provider.is_distributed = True + wrapped_pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) + + wrapped_pretrain( + train_valid_test_datasets_provider, + get_model_provider(model_type='mamba'), + ModelType.encoder_or_decoder, + forward_step, + store=store, + ) + else: + # Default to GPT + from pretrain_gpt import forward_step, train_valid_test_datasets_provider + train_valid_test_datasets_provider.is_distributed = True + wrapped_pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) + + wrapped_pretrain( + train_valid_test_datasets_provider, + get_model_provider(model_type='gpt'), + ModelType.encoder_or_decoder, + forward_step, + store=store, + ) diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index fe70aa729..8e58f6be2 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -496,11 +496,15 @@ def update_primus_config( args.valid_data_path = None args.test_data_path = None + # Determine model type (gpt or mamba) + model_type = getattr(args, 'model_type', 'gpt') + log_rank_0(f"-detected model_type: {model_type}") + if args.final_logit_softcapping is not None and args.final_logit_softcapping > 0.0: log_rank_0(f"-enable final_logit_softcapping: {args.final_logit_softcapping}") - self.model_provider = functools.partial(primus_model_provider, get_model_provider()) + self.model_provider = functools.partial(primus_model_provider, get_model_provider(model_type=model_type)) else: - self.model_provider = get_model_provider() + self.model_provider = get_model_provider(model_type=model_type) if args.router_logit_softcapping is not None and args.router_logit_softcapping > 0.0: log_rank_0(f"-enable router_logit_softcapping: {args.router_logit_softcapping}") From d23d79f7ec7cfd03718b8152840abf4a3b7cd0fe Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Tue, 16 Dec 2025 14:05:15 +0000 Subject: [PATCH 02/23] set self.lr_warmup_steps < self.lr_decay_steps --- examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml index f1a60cde0..9557e7860 100644 --- a/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml +++ b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml @@ -31,8 +31,8 @@ modules: lr: 3.0e-4 min_lr: 0.0 - lr_warmup_iters: 100 - lr_decay_iters: null + lr_warmup_iters: 50000 + lr_decay_iters: 73192188 lr_decay_style: cosine weight_decay: 0.1 adam_beta1: 0.9 From 3381850e1f1824d810d8afec82f01a3f7d8e7b1f Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Thu, 18 Dec 2025 15:58:30 +0000 Subject: [PATCH 03/23] unwrap model to remove loss_mask parameter --- .../modules/trainer/megatron/pre_trainer.py | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/primus/modules/trainer/megatron/pre_trainer.py b/primus/modules/trainer/megatron/pre_trainer.py index 1c8539dd5..efb197f7e 100644 --- a/primus/modules/trainer/megatron/pre_trainer.py +++ b/primus/modules/trainer/megatron/pre_trainer.py @@ -242,6 +242,20 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=Fals assert ( args.overlap_moe_expert_parallel_comm ), "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan" + + # Schedule plan building is only supported for GPT models + # Check if this is a Mamba model + unwrapped_model = model + while hasattr(unwrapped_model, 'module'): + unwrapped_model = unwrapped_model.module + model_class_name = unwrapped_model.__class__.__name__ + + if 'Mamba' in model_class_name: + raise NotImplementedError( + "Schedule plan building is not supported for Mamba models. " + "Please disable overlap_moe_expert_parallel_comm for Mamba." + ) + if args.patch_moe_overlap: assert ( not args.delay_wgrad_compute @@ -267,8 +281,23 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=Fals ) return schedule_plan, partial(self.loss_func, loss_mask) else: - output_tensor = model( - tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask - ) + # Check if model supports loss_mask parameter + # MambaModel doesn't accept loss_mask, but GPTModel does + # Unwrap the model to get the actual model class + unwrapped_model = model + while hasattr(unwrapped_model, 'module'): + unwrapped_model = unwrapped_model.module + model_class_name = unwrapped_model.__class__.__name__ + + if 'Mamba' in model_class_name: + # MambaModel doesn't accept loss_mask parameter + output_tensor = model( + tokens, position_ids, attention_mask, labels=labels + ) + else: + # GPTModel and other models accept loss_mask parameter + output_tensor = model( + tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask + ) return output_tensor, partial(self.loss_func, loss_mask) From 277f3e1c5276187d6edec460e5197873adf1a67c Mon Sep 17 00:00:00 2001 From: Mingyu Yang Date: Tue, 13 Jan 2026 19:07:59 +0000 Subject: [PATCH 04/23] add zebra-llama (hybrid mla mamba model) support --- .../configs/MI300X/mamba_370M-pretrain.yaml | 10 +- .../MI300X/zebra_llama_8B-pretrain.yaml | 70 +++ .../megatron/core/models/hybrid/__init__.py | 11 + .../core/models/hybrid/hybrid_block.py | 399 ++++++++++++++++++ .../hybrid/hybrid_mamba_mla_layer_specs.py | 178 ++++++++ .../models/megatron/language_model.yaml | 15 + .../configs/models/megatron/mamba_370M.yaml | 6 +- .../configs/models/megatron/mamba_base.yaml | 9 +- .../models/megatron/zebra_llama_8B.yaml | 41 ++ .../modules/megatron/trainer_base.yaml | 23 +- 10 files changed, 731 insertions(+), 31 deletions(-) create mode 100644 examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml create mode 100644 primus/backends/megatron/core/models/hybrid/__init__.py create mode 100644 primus/backends/megatron/core/models/hybrid/hybrid_block.py create mode 100644 primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py create mode 100644 primus/configs/models/megatron/zebra_llama_8B.yaml diff --git a/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml index 9557e7860..d5bb62e71 100644 --- a/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml +++ b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml @@ -46,15 +46,7 @@ modules: # Tokenizer tokenizer_type: HuggingFaceTokenizer - tokenizer_model: meta-llama/Llama-3.2-1B - - # Mamba SSM parameters - is_hybrid_model: false - hybrid_attention_ratio: 0.0 - hybrid_mlp_ratio: 0.0 - mamba_state_dim: 16 - mamba_head_dim: 64 - mamba_num_groups: 8 + tokenizer_model: EleutherAI/gpt-neox-20b # parallel tensor_model_parallel_size: 1 diff --git a/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml new file mode 100644 index 000000000..ff74b6cf4 --- /dev/null +++ b/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_8B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_8B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_8B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 2 + global_batch_size: 128 + + seq_length: 4096 + max_position_embeddings: 4096 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-1B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: true + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch + diff --git a/primus/backends/megatron/core/models/hybrid/__init__.py b/primus/backends/megatron/core/models/hybrid/__init__.py new file mode 100644 index 000000000..799e3b0ce --- /dev/null +++ b/primus/backends/megatron/core/models/hybrid/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +"""Hybrid Mamba+MLA layer specifications for Megatron-LM.""" + +from .hybrid_mamba_mla_layer_specs import hybrid_inference_stack_spec, hybrid_stack_spec + +__all__ = ["hybrid_stack_spec", "hybrid_inference_stack_spec"] + diff --git a/primus/backends/megatron/core/models/hybrid/hybrid_block.py b/primus/backends/megatron/core/models/hybrid/hybrid_block.py new file mode 100644 index 000000000..c39957fe6 --- /dev/null +++ b/primus/backends/megatron/core/models/hybrid/hybrid_block.py @@ -0,0 +1,399 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024, Tri Dao, Albert Gu. + +# Some of this code was adopted from https://github.com/state-spaces/mamba/ +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. + +from contextlib import nullcontext +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from torch import Tensor, nn + +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding +from megatron.core.enums import Fp8Recipe +from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.fp8_utils import get_fp8_context +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols as LayerSymbols +from megatron.core.ssm.mamba_hybrid_layer_allocation import allocate_layers +from megatron.core.transformer import TransformerConfig + +# CudaGraphScope is not available in older Megatron versions +try: + from megatron.core.transformer.enums import CudaGraphScope + HAS_CUDA_GRAPH_SCOPE = True +except ImportError: + CudaGraphScope = None + HAS_CUDA_GRAPH_SCOPE = False + +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_layer import TransformerLayer +from megatron.core.transformer.utils import sharded_state_dict_default +from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor + + +@dataclass +class HybridStackSubmodules: + """ + A class for the module specs for the MambaStack. + """ + + mamba_layer: Union[ModuleSpec, type] = IdentityOp + attention_layer: Union[ModuleSpec, type] = IdentityOp + mlp_layer: Union[ModuleSpec, type] = IdentityOp + moe_layer: Union[ModuleSpec, type] = IdentityOp + + +class HybridStack(MegatronModule): + """ + Constructor for the HybridStack class. + + Args: + config (TransformerConfig): the model configuration + submodules (MambaStackSubmodules): the submodules for the stack + residual_in_fp32 (bool, optional): whether to do residual connections + in fp32. Defaults to False. + pre_process (bool, optional): whether to include an embedding layer. + Defaults to True. + hybrid_attention_ratio (float, optional): the target ratio of attention layers to + total layers. Defaults to 0.0. + hybrid_mlp_ratio (float, optional): the target ratio of mlp layers to total + layers. Defaults to 0.0. + hybrid_override_pattern (str, optional): the hybrid layer pattern to override + with. Defaults to None. + post_layer_norm (bool, optional): whether to include a final layer norm. + Defaults to True. + post_process (bool, optional): whether to include an output layer. + Defaults to True. + device (optional): the device to use. Defaults to None. + dtype (optional): the data type to use. Defaults to None. + pg_collection (ProcessGroupCollection): the required model communication + process groups to use. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: HybridStackSubmodules, + residual_in_fp32=False, + pre_process: bool = True, + hybrid_attention_ratio: float = 0.0, + hybrid_mlp_ratio: float = 0.0, + hybrid_override_pattern: str = None, + post_layer_norm: bool = True, + post_process: bool = True, + device=None, + dtype=None, + pg_collection: ProcessGroupCollection = None, + ) -> None: + super().__init__(config=config) + self.residual_in_fp32 = residual_in_fp32 + self.pre_process = pre_process + self.post_layer_norm = post_layer_norm + self.post_process = post_process + + assert pg_collection is not None, "pg_collection must be provided for MambaStack" + + self.pp_group = pg_collection.pp + self.tp_group = pg_collection.tp + + # Required for pipeline parallel schedules + self.input_tensor = None + + self.hybrid_attention_ratio = hybrid_attention_ratio + self.hybrid_mlp_ratio = hybrid_mlp_ratio + self.hybrid_override_pattern = hybrid_override_pattern + + # Customized layer allocation + # hybrid_mlp_ratio is not used in this hybrid stack. + # It is by default to be always followed by mamba or mla (i.e., mamba + MLP or MLA + MLP) + # By setting hybrid_attention_ratio, attention layers are by default to be distributed uniformly. + self.layer_type_list = self.allocate_layers( + self.config.num_layers, + self.hybrid_attention_ratio, + ) + + pp_layer_offset = 0 + if self.pp_group.size() > 1: + pp_layer_offset, self.layer_type_list = self._select_layers_for_pipeline_parallel( + self.layer_type_list + ) + + print(f"layer_type_list: {self.layer_type_list}") + + self.layers = nn.ModuleList() + for i, layer_type in enumerate(self.layer_type_list): + fp8_init_context = get_fp8_context(self.config, i + pp_layer_offset, is_init=True) + with fp8_init_context: + if layer_type == LayerSymbols.MAMBA: + layer = build_module( + submodules.mamba_layer, + config=self.config, + residual_in_fp32=residual_in_fp32, + layer_number=i + 1, + pg_collection=pg_collection, + ) + elif layer_type == LayerSymbols.ATTENTION: + # Transformer layers apply their own pp_layer_offset + layer = build_module( + submodules.attention_layer, + config=self.config, + layer_number=i + 1, + pg_collection=pg_collection, + ) + elif layer_type == LayerSymbols.MLP: + # Transformer layers apply their own pp_layer_offset + layer = build_module( + submodules.mlp_layer, + config=self.config, + layer_number=i + 1, + pg_collection=pg_collection, + ) + elif layer_type == LayerSymbols.MOE: + # Transformer layers apply their own pp_layer_offset + layer = build_module( + submodules.moe_layer, config=self.config, layer_number=i + 1 + ) + else: + assert False, "unexpected layer_type" + self.layers.append(layer) + + # Required for activation recomputation + self.num_layers_per_pipeline_rank = len(self.layers) + + if self.post_process and self.post_layer_norm: + # Final layer norm before output. + self.final_norm = TENorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + def allocate_layers(self, num_layers, hybrid_attention_ratio): + layer_type_list = [] + num_attention_layers = int(num_layers//2 * hybrid_attention_ratio) + num_mamba_layers = num_layers//2 - num_attention_layers + num_mamba_per_attention_layer = num_mamba_layers // num_attention_layers + + if hybrid_attention_ratio <= 0.5: + base_block = [LayerSymbols.ATTENTION, LayerSymbols.MLP] + [LayerSymbols.MAMBA, LayerSymbols.MLP] * num_mamba_per_attention_layer + layer_type_list += base_block*num_attention_layers + layer_type_list += [LayerSymbols.MAMBA, LayerSymbols.MLP] * (num_mamba_layers % num_attention_layers) + else: + base_block = [LayerSymbols.ATTENTION, LayerSymbols.MLP] + [LayerSymbols.MAMBA, LayerSymbols.MLP] + layer_type_list += [LayerSymbols.ATTENTION, LayerSymbols.MLP] * (num_attention_layers - num_mamba_layers) + layer_type_list += base_block*num_mamba_layers + return layer_type_list + + def _select_layers_for_pipeline_parallel(self, layer_type_list): + num_layers_per_pipeline_rank = self.config.num_layers // self.pp_group.size() + + assert self.config.virtual_pipeline_model_parallel_size is None, ( + "The Mamba hybrid model does not currently support " + "virtual/interleaved pipeline parallelism" + ) + + offset = self.pp_group.rank() * num_layers_per_pipeline_rank + selected_list = layer_type_list[offset : offset + num_layers_per_pipeline_rank] + + return offset, selected_list + + def set_input_tensor(self, input_tensor: Tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def mamba_state_shapes_per_request(self) -> Optional[Tuple[Tuple[int], Tuple[int]]]: + """ + Returns the Mamba conv and ssm states shapes per input sequence + if this block contains Mamba layers (this may not be the case with PP > 1). + """ + for layer_type, layer in zip(self.layer_type_list, self.layers): + if layer_type == LayerSymbols.MAMBA: + return layer.mamba_state_shapes_per_request() + return None + + def forward( + self, + hidden_states: Union[Tensor, WrappedTensor], + attention_mask: Tensor, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Tensor] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ): + """ + Forward function of the MambaStack class. + + It either returns the Loss values if labels are given or the + final hidden units + + Args: + hidden_states (Union[Tensor, WrappedTensor]): the input tensor. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): the attention mask. + inference_context (BaseInferenceContext): the inference parameters. + rotary_pos_emb (Tensor, optional): the rotary positional embeddings. + Defaults to None. + Returns: + Tensor: the output tensor. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if inference_context and inference_context.is_static_batching(): + # NOTE(bnorick): match BaseInferenceContext attributes for + # mamba_ssm.utils.generation.BaseInferenceContext, + # this hack supports eval + inference_context.max_seqlen = inference_context.max_sequence_length + inference_context.seqlen_offset = inference_context.sequence_len_offset + + if ( + ( + ( + HAS_CUDA_GRAPH_SCOPE + and self.config.cuda_graph_impl == "local" + and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope + ) + or self.config.flash_decode + ) + and inference_context + and inference_context.is_static_batching() + and not self.training + ): + current_batch_size = hidden_states.shape[1] + sequence_len_offset = torch.tensor( + [inference_context.sequence_len_offset] * current_batch_size, + dtype=torch.int32, + device='cuda', + ) + else: + sequence_len_offset = None + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed + outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() + + with outer_fp8_context: + for layer in self.layers: + inner_fp8_context = ( + get_fp8_context(self.config, layer.layer_number - 1) + if use_inner_fp8_context + else nullcontext() + ) + with inner_fp8_context: + if isinstance(layer, TransformerLayer): + hidden_states, _ = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + sequence_len_offset=sequence_len_offset, + ) + else: # MambaLayer + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_context=inference_context, + ) + + # The attention layer (currently a simplified transformer layer) + # outputs a tuple of (hidden_states, context). Context is intended + # for cross-attention, and is not needed in our model. + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + # Final layer norm. + if self.post_process and self.post_layer_norm: + hidden_states = self.final_norm(hidden_states) + + # Ensure that the tensor passed between pipeline parallel stages is + # viewless. See related notes in TransformerBlock and TransformerLayer + output = make_viewless_tensor( + inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True + ) + + return hidden_states + + def sharded_state_dict( + self, + prefix: str = '', + sharded_offsets: Optional[tuple] = None, + metadata: Optional[dict] = None, + ) -> ShardedStateDict: + """ + Returns a sharded state dictionary for the current object. + + This function constructs a sharded state dictionary by iterating over the layers + in the current object, computing the sharded state dictionary for each layer, + and combining the results into a single dictionary. + + Parameters: + prefix (str): The prefix to use for the state dictionary keys. + sharded_offsets (tuple): The sharded offsets to use for the state dictionary. + metadata (dict): Additional metadata to use when computing the sharded state dictionary. + + Returns: + dict: The sharded state dictionary for the current object. + """ + + sharded_state_dict = {} + layer_prefix = f'{prefix}layers.' + + for local_layer_idx, layer in enumerate(self.layers): + + global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1 + state_dict_prefix = ( + f'{layer_prefix}{local_layer_idx}.' # module list index in MambaBlock + ) + + sharded_prefix = f'{layer_prefix}{global_layer_offset}.' + sharded_pp_offset = [] + + layer_sharded_state_dict = layer.sharded_state_dict( + state_dict_prefix, sharded_pp_offset, metadata + ) + + replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) + + sharded_state_dict.update(layer_sharded_state_dict) + + # Add modules other than self.layers + for name, module in self.named_children(): + if not module is self.layers: + sharded_state_dict.update( + sharded_state_dict_default( + module, + f'{prefix}{name}.', + sharded_offsets, + metadata, + tp_group=self.tp_group, + ) + ) + + return sharded_state_dict \ No newline at end of file diff --git a/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py b/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py new file mode 100644 index 000000000..4777da8a0 --- /dev/null +++ b/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py @@ -0,0 +1,178 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TELinear, + TENorm, + TERowParallelLinear, +) +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec +from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules +from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules +from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules +from megatron.core.ssm.mlp_layer import MLPLayer +# Import HybridStack from relative path +from primus.backends.megatron.core.models.hybrid.hybrid_block import ( + HybridStack, + HybridStackSubmodules, +) +from megatron.core.transformer.multi_latent_attention import ( + MLASelfAttention, + MLASelfAttentionSubmodules, +) + +# Inference layers may not be available in older Megatron versions +# They're only used in hybrid_inference_stack_spec, not the training spec +try: + from megatron.core.tensor_parallel import ( + InferenceLayerNormColumnParallelLinear, + InferenceRowParallelLinear, + ) + HAS_INFERENCE_LAYERS = True +except ImportError: + # Fallback to regular layers for inference spec + InferenceLayerNormColumnParallelLinear = TELayerNormColumnParallelLinear + InferenceRowParallelLinear = TERowParallelLinear + HAS_INFERENCE_LAYERS = False + +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + +moe = get_moe_module_spec( + use_te=True, + num_experts=8, # Can be any positive integer (must not be None). + moe_grouped_gemm=True, + moe_use_legacy_grouped_gemm=False, +) + +hybrid_stack_spec = ModuleSpec( + module=HybridStack, + submodules=HybridStackSubmodules( + mamba_layer=ModuleSpec( + module=MambaLayer, + submodules=MambaLayerSubmodules( + mixer=ModuleSpec( + module=MambaMixer, + params={ + "expand": 1, + "d_conv": 4, + }, + submodules=MambaMixerSubmodules( + in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear + ), + ), + mamba_bda=get_bias_dropout_add, + ), + ), + attention_layer = ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=MLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=TEColumnParallelLinear, + linear_q_down_proj=TELinear, + linear_q_up_proj=TELayerNormColumnParallelLinear, + linear_kv_down_proj=TELinear, + linear_kv_up_proj=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + kv_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + ), + ), + mlp_layer=ModuleSpec( + module=MLPLayer, + submodules=TransformerLayerSubmodules( + pre_mlp_layernorm=TENorm, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + moe_layer=ModuleSpec( + # TODO (rwaleffe): change this to be an "MoELayer" to work with CudaGraphs? + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add + ), + ), + ), +) + +hybrid_inference_stack_spec = ModuleSpec( + module=HybridStack, + submodules=HybridStackSubmodules( + mamba_layer=ModuleSpec( + module=MambaLayer, + submodules=MambaLayerSubmodules( + mixer=ModuleSpec( + module=MambaMixer, + submodules=MambaMixerSubmodules( + in_proj=InferenceLayerNormColumnParallelLinear, + out_proj=InferenceRowParallelLinear, + ), + ), + mamba_bda=get_bias_dropout_add, + ), + ), + # Started with spec from gpt_layer_specs.py (with MLP removed) + # Using the TE spec because we had problems getting the non-TE spec + # working + attention_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=InferenceLayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=InferenceRowParallelLinear, + ), + ), + self_attn_bda=get_bias_dropout_add, + ), + ), + # Started with spec from gpt_layer_specs.py + # Using the TE spec because we had problems getting the non-TE spec + # working + mlp_layer=ModuleSpec( + module=MLPLayer, + submodules=TransformerLayerSubmodules( + pre_mlp_layernorm=TENorm, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=InferenceLayerNormColumnParallelLinear, + linear_fc2=InferenceRowParallelLinear, + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + moe_layer=ModuleSpec( + # TODO (rwaleffe): change this to be an "MoELayer" to work with CudaGraphs? + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add + ), + ), + ), +) \ No newline at end of file diff --git a/primus/configs/models/megatron/language_model.yaml b/primus/configs/models/megatron/language_model.yaml index 3de9cbff5..a360cc695 100755 --- a/primus/configs/models/megatron/language_model.yaml +++ b/primus/configs/models/megatron/language_model.yaml @@ -100,6 +100,21 @@ rotary_scaling_factor: 1.0 # float mscale: 1.0 # float mscale_all_dim: 1.0 # float +# Mamba layer configuration +mamba_state_dim: 128 +mamba_head_dim: 64 +mamba_num_groups: 8 +mamba_num_heads: null +mamba_expand: 2 +mamba_d_conv: 4 +disable_mamba_mem_eff_path: false + +# Hybrid model configuration +is_hybrid_model: false # bool +hybrid_attention_ratio: 0.0 # float range [0,0, 1.0] +hybrid_mlp_ratio: 0.0 # float range [0,0, 1.0] +hybrid_override_pattern: null # str + # MTP mtp_num_layers: null # int mtp_loss_scaling_factor: 0.1 # float diff --git a/primus/configs/models/megatron/mamba_370M.yaml b/primus/configs/models/megatron/mamba_370M.yaml index b10523b2e..99656a98b 100644 --- a/primus/configs/models/megatron/mamba_370M.yaml +++ b/primus/configs/models/megatron/mamba_370M.yaml @@ -2,7 +2,6 @@ bases: - mamba_base.yaml # Mamba 370M configuration - tokenizer_type: GPT2BPETokenizer vocab_size: 50257 @@ -10,6 +9,7 @@ vocab_size: 50257 num_layers: 48 hidden_size: 1024 ffn_hidden_size: null - -max_position_embeddings: 2048 +mamba_state_dim: 16 +mamba_head_dim: 64 +mamba_num_groups: 8 diff --git a/primus/configs/models/megatron/mamba_base.yaml b/primus/configs/models/megatron/mamba_base.yaml index f658c8371..acda1d8b0 100644 --- a/primus/configs/models/megatron/mamba_base.yaml +++ b/primus/configs/models/megatron/mamba_base.yaml @@ -16,18 +16,13 @@ use_rotary_position_embeddings: false tokenizer_type: HuggingFaceTokenizer tokenizer_model: null -# Model architecture -num_layers: 24 -hidden_size: 1024 -ffn_hidden_size: null - # Standard transformer settings that may be used by hybrid models -num_attention_heads: 16 +is_hybrid_model: false attention_dropout: 0.0 hidden_dropout: 0.0 # Embeddings -untie_embeddings_and_output_weights: true +untie_embeddings_and_output_weights: false # Other settings apply_residual_connection_post_layernorm: false diff --git a/primus/configs/models/megatron/zebra_llama_8B.yaml b/primus/configs/models/megatron/zebra_llama_8B.yaml new file mode 100644 index 000000000..c97fc9de9 --- /dev/null +++ b/primus/configs/models/megatron/zebra_llama_8B.yaml @@ -0,0 +1,41 @@ +bases: + - mamba_base.yaml + +# Zebra Llama 8B configuration +tokenizer_type: HuggingFaceTokenizer +tokenizer_model: meta-llama/Llama-3.2-1B + +# Model size parameters +num_layers: 64 +hidden_size: 4096 +ffn_hidden_size: 14436 + +# Mamba parameters +is_hybrid_model: true +hybrid_attention_ratio: 0.25 +mamba_state_dim: 128 +mamba_head_dim: 128 +mamba_num_groups: 8 + +# MLA parameters +# Disable standard GQA - MLA uses its own compression via LoRA +group_query_attention: false +swiglu: true +num_query_groups: null +multi_latent_attention: true +num_attention_heads: 32 +q_lora_rank: 2048 # Query LoRA rank +kv_lora_rank: 160 # Key-Value LoRA rank +qk_head_dim: 64 # Query-Key head dimension +qk_pos_emb_head_dim: 64 # Positional embedding head dimension +v_head_dim: 128 # Value head dimension +rotary_scaling_factor: 1.0 +mscale: 1.0 +mscale_all_dim: 1.0 + +# MLA uses its own internal positional encoding +rotary_base: 500000 +position_embedding_type: none +add_position_embedding: true +use_rotary_position_embeddings: false +max_position_embeddings: 131072 diff --git a/primus/configs/modules/megatron/trainer_base.yaml b/primus/configs/modules/megatron/trainer_base.yaml index 76be0c7ba..ba89bd7d1 100755 --- a/primus/configs/modules/megatron/trainer_base.yaml +++ b/primus/configs/modules/megatron/trainer_base.yaml @@ -305,17 +305,17 @@ rerun_mode: disabled # str: 'disabled', 'validate_results', 'report_stats' # Experimental features enable_experimental: false -# Hybrid model configuration -hybrid_attention_ratio: 0.0 # float range [0,0, 1.0] -hybrid_mlp_ratio: 0.0 # float range [0,0, 1.0] -hybrid_override_pattern: null # str - -# Mamba layer configuration -mamba_state_dim: 128 -mamba_head_dim: 64 -mamba_num_groups: 8 -mamba_num_heads: null -disable_mamba_mem_eff_path: false +# # Hybrid model configuration +# hybrid_attention_ratio: 0.0 # float range [0,0, 1.0] +# hybrid_mlp_ratio: 0.0 # float range [0,0, 1.0] +# hybrid_override_pattern: null # str + +# # Mamba layer configuration +# mamba_state_dim: 128 +# mamba_head_dim: 64 +# mamba_num_groups: 8 +# mamba_num_heads: null +# disable_mamba_mem_eff_path: false # Args of precision-aware optimizer use_precision_aware_optimizer: false @@ -405,7 +405,6 @@ indexer_log_interval: 1000 enable_ft_package: false calc_ft_timeouts: false run_workload_inspector_server: false -is_hybrid_model: false heterogeneous_layers_config_path: null heterogeneous_layers_config_encoded_json: null From 11b22c6f84f623eec2a7eadab256d0c71ffe7f6a Mon Sep 17 00:00:00 2001 From: Mingyu Yang Date: Thu, 22 Jan 2026 22:21:47 +0000 Subject: [PATCH 05/23] add Zebra-Llama 3B configurations --- .../MI300X/zebra_llama_3B-pretrain.yaml | 70 +++++++++++++++++++ .../models/megatron/zebra_llama_3B.yaml | 41 +++++++++++ 2 files changed, 111 insertions(+) create mode 100644 examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml create mode 100644 primus/configs/models/megatron/zebra_llama_3B.yaml diff --git a/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml new file mode 100644 index 000000000..e41cefcb5 --- /dev/null +++ b/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_3B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_3B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_3B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 2 + global_batch_size: 128 + + seq_length: 4096 + max_position_embeddings: 4096 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-1B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: true + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch + diff --git a/primus/configs/models/megatron/zebra_llama_3B.yaml b/primus/configs/models/megatron/zebra_llama_3B.yaml new file mode 100644 index 000000000..da1aef953 --- /dev/null +++ b/primus/configs/models/megatron/zebra_llama_3B.yaml @@ -0,0 +1,41 @@ +bases: + - mamba_base.yaml + +# Zebra Llama 8B configuration +tokenizer_type: HuggingFaceTokenizer +tokenizer_model: meta-llama/Llama-3.2-1B + +# Model size parameters +num_layers: 56 +hidden_size: 3072 +ffn_hidden_size: 8192 + +# Mamba parameters +is_hybrid_model: true +hybrid_attention_ratio: 0.25 +mamba_state_dim: 128 +mamba_head_dim: 128 +mamba_num_groups: 8 + +# MLA parameters +# Disable standard GQA - MLA uses its own compression via LoRA +group_query_attention: false +swiglu: true +num_query_groups: null +multi_latent_attention: true +num_attention_heads: 24 +q_lora_rank: 1536 # Query LoRA rank +kv_lora_rank: 128 # Key-Value LoRA rank +qk_head_dim: 64 # Query-Key head dimension +qk_pos_emb_head_dim: 64 # Positional embedding head dimension +v_head_dim: 128 # Value head dimension +rotary_scaling_factor: 1.0 +mscale: 1.0 +mscale_all_dim: 1.0 + +# MLA uses its own internal positional encoding +rotary_base: 500000 +position_embedding_type: none +add_position_embedding: true +use_rotary_position_embeddings: false +max_position_embeddings: 131072 From 1dec95e0ac7165b2e19d4d784cbcb673f1aa6311 Mon Sep 17 00:00:00 2001 From: Mingyu Yang Date: Fri, 23 Jan 2026 20:57:51 +0000 Subject: [PATCH 06/23] add Zebra-Llama 1B configs and remove unused configs --- ...rain.yaml => zebra_llama_1B-pretrain.yaml} | 26 +++--------- .../models/megatron/mamba_hybrid_2.8B.yaml | 29 ------------- .../models/megatron/zebra_llama_1B.yaml | 41 +++++++++++++++++++ 3 files changed, 47 insertions(+), 49 deletions(-) rename examples/megatron/configs/MI300X/{mamba_hybrid_2.8B-pretrain.yaml => zebra_llama_1B-pretrain.yaml} (66%) delete mode 100644 primus/configs/models/megatron/mamba_hybrid_2.8B.yaml create mode 100644 primus/configs/models/megatron/zebra_llama_1B.yaml diff --git a/examples/megatron/configs/MI300X/mamba_hybrid_2.8B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml similarity index 66% rename from examples/megatron/configs/MI300X/mamba_hybrid_2.8B-pretrain.yaml rename to examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml index 58d800401..28f84dd03 100644 --- a/examples/megatron/configs/MI300X/mamba_hybrid_2.8B-pretrain.yaml +++ b/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml @@ -1,6 +1,6 @@ work_group: ${PRIMUS_TEAM:amd} user_name: ${PRIMUS_USER:root} -exp_name: ${PRIMUS_EXP_NAME:mamba_hybrid_2.8B-pretrain} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_1B-pretrain} workspace: ${PRIMUS_WORKSPACE:./output} modules: @@ -9,10 +9,10 @@ modules: config: pre_trainer.yaml # model to run - model: mamba_hybrid_2.8B.yaml + model: zebra_llama_1B.yaml overrides: # log - wandb_project: "Primus_Mamba_Hybrid_Pretrain" + wandb_project: "Primus_Zebra_Llama_1B_Pretrain" stderr_sink_level: DEBUG eval_iters: 0 @@ -36,26 +36,17 @@ modules: adam_beta1: 0.9 adam_beta2: 0.95 eod_mask_loss: true - init_method_std: 0.02 - norm_epsilon: 1.0e-5 # Mamba-specific: must provide spec - spec: ['megatron.core.models.mamba.mamba_layer_specs', 'mamba_stack_spec'] + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] # Tokenizer tokenizer_type: HuggingFaceTokenizer tokenizer_model: meta-llama/Llama-3.2-1B - # Hybrid Mamba+Attention parameters - is_hybrid_model: true - hybrid_attention_ratio: 0.125 - hybrid_mlp_ratio: 0.0 - mamba_state_dim: 16 - mamba_head_dim: 64 - mamba_num_groups: 8 - # parallel - tensor_model_parallel_size: 2 + tensor_model_parallel_size: 1 pipeline_model_parallel_size: 1 expert_model_parallel_size: 1 overlap_grad_reduce: true @@ -77,8 +68,3 @@ modules: disable_last_saving: true ckpt_format: torch - # Turbo - disable for Mamba layers, but attention layers may benefit - enable_primus_turbo: false - use_turbo_attention: false - use_turbo_grouped_mlp: false - diff --git a/primus/configs/models/megatron/mamba_hybrid_2.8B.yaml b/primus/configs/models/megatron/mamba_hybrid_2.8B.yaml deleted file mode 100644 index f2fd20cba..000000000 --- a/primus/configs/models/megatron/mamba_hybrid_2.8B.yaml +++ /dev/null @@ -1,29 +0,0 @@ -bases: - - mamba_base.yaml - -# Mamba 2.8B configuration with hybrid attention layers - -tokenizer_type: GPT2BPETokenizer -vocab_size: 50257 - -# Model size parameters -num_layers: 64 -hidden_size: 2560 -ffn_hidden_size: 6827 # ~2.67x hidden_size - -# Attention parameters (for hybrid layers) -num_attention_heads: 32 -group_query_attention: true -num_query_groups: 8 - -# Hybrid configuration: override mamba_base defaults -hybrid_attention_ratio: 0.125 -is_hybrid_model: true - -# For hybrid models, position embeddings may be useful -position_embedding_type: rope -rotary_base: 10000 -rotary_percent: 1.0 - -max_position_embeddings: 4096 - diff --git a/primus/configs/models/megatron/zebra_llama_1B.yaml b/primus/configs/models/megatron/zebra_llama_1B.yaml new file mode 100644 index 000000000..ff54e98d0 --- /dev/null +++ b/primus/configs/models/megatron/zebra_llama_1B.yaml @@ -0,0 +1,41 @@ +bases: + - mamba_base.yaml + +# Zebra Llama 8B configuration +tokenizer_type: HuggingFaceTokenizer +tokenizer_model: meta-llama/Llama-3.2-1B + +# Model size parameters +num_layers: 32 +hidden_size: 2048 +ffn_hidden_size: 8192 + +# Mamba parameters +is_hybrid_model: true +hybrid_attention_ratio: 0.25 +mamba_state_dim: 64 +mamba_head_dim: 64 +mamba_num_groups: 8 + +# MLA parameters +# Disable standard GQA - MLA uses its own compression via LoRA +group_query_attention: false +swiglu: true +num_query_groups: null +multi_latent_attention: true +num_attention_heads: 32 +q_lora_rank: 1344 # Query LoRA rank +kv_lora_rank: 128 # Key-Value LoRA rank +qk_head_dim: 32 # Query-Key head dimension +qk_pos_emb_head_dim: 32 # Positional embedding head dimension +v_head_dim: 64 # Value head dimension +rotary_scaling_factor: 1.0 +mscale: 1.0 +mscale_all_dim: 1.0 + +# MLA uses its own internal positional encoding +rotary_base: 500000 +position_embedding_type: none +add_position_embedding: true +use_rotary_position_embeddings: false +max_position_embeddings: 131072 From 34508db369b25ac9a000c8f64226227130e13386 Mon Sep 17 00:00:00 2001 From: Mingyu Yang Date: Fri, 23 Jan 2026 21:01:59 +0000 Subject: [PATCH 07/23] remove unused configs --- primus/configs/models/megatron/mamba_1.4B.yaml | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 primus/configs/models/megatron/mamba_1.4B.yaml diff --git a/primus/configs/models/megatron/mamba_1.4B.yaml b/primus/configs/models/megatron/mamba_1.4B.yaml deleted file mode 100644 index 91e255443..000000000 --- a/primus/configs/models/megatron/mamba_1.4B.yaml +++ /dev/null @@ -1,15 +0,0 @@ -bases: - - mamba_base.yaml - -# Mamba 1.4B configuration - -tokenizer_type: GPT2BPETokenizer -vocab_size: 50257 - -# Model size parameters -num_layers: 48 -hidden_size: 2048 -ffn_hidden_size: null - -max_position_embeddings: 2048 - From 2f3ab4956d16a9deab1dc756821d37a07062c505 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Thu, 29 Jan 2026 02:07:40 +0000 Subject: [PATCH 08/23] Set submodule mamba to track enable-primus-hybrid-models branch --- .gitmodules | 4 ++++ third_party/mamba | 1 + 2 files changed, 5 insertions(+) create mode 160000 third_party/mamba diff --git a/.gitmodules b/.gitmodules index 15658588b..0af8e5ed1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -15,3 +15,7 @@ [submodule "third_party/Megatron-Bridge"] path = third_party/Megatron-Bridge url = https://github.com/NVIDIA-NeMo/Megatron-Bridge.git +[submodule "third_party/mamba"] + path = third_party/mamba + url = https://github.com/AndreasKaratzas/mamba.git + branch = enable-primus-hybrid-models diff --git a/third_party/mamba b/third_party/mamba new file mode 160000 index 000000000..4d67c534d --- /dev/null +++ b/third_party/mamba @@ -0,0 +1 @@ +Subproject commit 4d67c534d10b7299194ede55b36718cc7a1dd472 From 159e441db4d511ee8933ee445b08f4267b2b3707 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Tue, 3 Feb 2026 23:27:43 +0000 Subject: [PATCH 09/23] set moe_layer_freq default value of 1 --- primus/modules/trainer/megatron/trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index 8e58f6be2..edd04fdb6 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -483,8 +483,10 @@ def update_primus_config( if args.iterations_to_skip is None: args.iterations_to_skip = [] - # support moe_freq_type - if isinstance(args.moe_layer_freq, str): + # support moe_freq_type - ensure moe_layer_freq has a default value + if not hasattr(args, 'moe_layer_freq'): + args.moe_layer_freq = 1 + elif isinstance(args.moe_layer_freq, str): try: args.moe_layer_freq = eval(args.moe_layer_freq) except Exception: From f798e77953aabcb655b580edf882a94f6ce86e27 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Tue, 3 Feb 2026 23:46:45 +0000 Subject: [PATCH 10/23] set final_logit_softcapping and router_logit_softcapping to null --- primus/modules/trainer/megatron/trainer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index edd04fdb6..e09c4b477 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -502,6 +502,12 @@ def update_primus_config( model_type = getattr(args, 'model_type', 'gpt') log_rank_0(f"-detected model_type: {model_type}") + # Ensure softcapping attributes have default values + if not hasattr(args, 'final_logit_softcapping'): + args.final_logit_softcapping = None + if not hasattr(args, 'router_logit_softcapping'): + args.router_logit_softcapping = None + if args.final_logit_softcapping is not None and args.final_logit_softcapping > 0.0: log_rank_0(f"-enable final_logit_softcapping: {args.final_logit_softcapping}") self.model_provider = functools.partial(primus_model_provider, get_model_provider(model_type=model_type)) From d7f4faf82335f53efb1c125ad0443ca029d11623 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Wed, 4 Feb 2026 00:50:29 +0000 Subject: [PATCH 11/23] use mamba builder --- primus/configs/models/megatron/mamba_370M.yaml | 2 ++ primus/configs/models/megatron/zebra_llama_1B.yaml | 1 + primus/configs/models/megatron/zebra_llama_3B.yaml | 1 + primus/configs/models/megatron/zebra_llama_8B.yaml | 1 + primus/modules/trainer/megatron/trainer.py | 9 +++++++-- 5 files changed, 12 insertions(+), 2 deletions(-) diff --git a/primus/configs/models/megatron/mamba_370M.yaml b/primus/configs/models/megatron/mamba_370M.yaml index 99656a98b..52ccad1e2 100644 --- a/primus/configs/models/megatron/mamba_370M.yaml +++ b/primus/configs/models/megatron/mamba_370M.yaml @@ -2,12 +2,14 @@ bases: - mamba_base.yaml # Mamba 370M configuration +model_type: mamba # CRITICAL: Mamba models must use mamba model type tokenizer_type: GPT2BPETokenizer vocab_size: 50257 # Model size parameters num_layers: 48 hidden_size: 1024 +num_attention_heads: 16 # Required by Megatron validation, even for pure Mamba models ffn_hidden_size: null mamba_state_dim: 16 mamba_head_dim: 64 diff --git a/primus/configs/models/megatron/zebra_llama_1B.yaml b/primus/configs/models/megatron/zebra_llama_1B.yaml index ff54e98d0..152834260 100644 --- a/primus/configs/models/megatron/zebra_llama_1B.yaml +++ b/primus/configs/models/megatron/zebra_llama_1B.yaml @@ -2,6 +2,7 @@ bases: - mamba_base.yaml # Zebra Llama 8B configuration +model_type: mamba # CRITICAL: Hybrid models must use mamba model type tokenizer_type: HuggingFaceTokenizer tokenizer_model: meta-llama/Llama-3.2-1B diff --git a/primus/configs/models/megatron/zebra_llama_3B.yaml b/primus/configs/models/megatron/zebra_llama_3B.yaml index da1aef953..a57487732 100644 --- a/primus/configs/models/megatron/zebra_llama_3B.yaml +++ b/primus/configs/models/megatron/zebra_llama_3B.yaml @@ -2,6 +2,7 @@ bases: - mamba_base.yaml # Zebra Llama 8B configuration +model_type: mamba # CRITICAL: Hybrid models must use mamba model type tokenizer_type: HuggingFaceTokenizer tokenizer_model: meta-llama/Llama-3.2-1B diff --git a/primus/configs/models/megatron/zebra_llama_8B.yaml b/primus/configs/models/megatron/zebra_llama_8B.yaml index c97fc9de9..1ebfa55a0 100644 --- a/primus/configs/models/megatron/zebra_llama_8B.yaml +++ b/primus/configs/models/megatron/zebra_llama_8B.yaml @@ -2,6 +2,7 @@ bases: - mamba_base.yaml # Zebra Llama 8B configuration +model_type: mamba # CRITICAL: Hybrid models must use mamba model type tokenizer_type: HuggingFaceTokenizer tokenizer_model: meta-llama/Llama-3.2-1B diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index e09c4b477..818fa96ee 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -502,7 +502,7 @@ def update_primus_config( model_type = getattr(args, 'model_type', 'gpt') log_rank_0(f"-detected model_type: {model_type}") - # Ensure softcapping attributes have default values + # Ensure required attributes have safe defaults if missing from config if not hasattr(args, 'final_logit_softcapping'): args.final_logit_softcapping = None if not hasattr(args, 'router_logit_softcapping'): @@ -512,7 +512,10 @@ def update_primus_config( log_rank_0(f"-enable final_logit_softcapping: {args.final_logit_softcapping}") self.model_provider = functools.partial(primus_model_provider, get_model_provider(model_type=model_type)) else: - self.model_provider = get_model_provider(model_type=model_type) + log_rank_0(f"-getting model provider for model_type={model_type}") + model_provider = get_model_provider(model_type=model_type) + log_rank_0(f"-model_provider: {model_provider}") + self.model_provider = model_provider if args.router_logit_softcapping is not None and args.router_logit_softcapping > 0.0: log_rank_0(f"-enable router_logit_softcapping: {args.router_logit_softcapping}") @@ -879,6 +882,8 @@ def setup_model_and_optimizer( log_rank_0(f"use te backend...") log_rank_0(f"-run get_model") + log_rank_0(f"-model_provider_func: {model_provider_func}") + log_rank_0(f"-model_type: {model_type}") model = get_model(model_provider_func, model_type) log_rank_0(model) # get_megatron_optimizer will use the ddp_config From d5bdbb8e40dfc343224bd1502a44b99a8bf93ab5 Mon Sep 17 00:00:00 2001 From: Mingyu Yang Date: Thu, 5 Feb 2026 00:27:08 +0000 Subject: [PATCH 12/23] adjust zebra-llama architecture and training --- .../MI300X/zebra_llama_1B-pretrain.yaml | 7 ++- .../MI300X/zebra_llama_3B-pretrain.yaml | 9 +-- .../MI300X/zebra_llama_8B-pretrain.yaml | 9 +-- .../megatron/core/models/hybrid/__init__.py | 4 +- .../hybrid/hybrid_mamba_mla_layer_specs.py | 62 ------------------- .../models/megatron/language_model.yaml | 1 + .../models/megatron/zebra_llama_3B.yaml | 3 +- .../models/megatron/zebra_llama_8B.yaml | 3 +- 8 files changed, 21 insertions(+), 77 deletions(-) diff --git a/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml index 28f84dd03..1aaa9219e 100644 --- a/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml +++ b/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml @@ -22,10 +22,11 @@ modules: train_iters: 100 micro_batch_size: 2 - global_batch_size: 128 + global_batch_size: 16 - seq_length: 4096 - max_position_embeddings: 4096 + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 lr: 2.0e-4 min_lr: 2.0e-5 diff --git a/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml index e41cefcb5..c70800a6f 100644 --- a/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml +++ b/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml @@ -22,10 +22,11 @@ modules: train_iters: 100 micro_batch_size: 2 - global_batch_size: 128 + global_batch_size: 16 - seq_length: 4096 - max_position_embeddings: 4096 + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 lr: 2.0e-4 min_lr: 2.0e-5 @@ -43,7 +44,7 @@ modules: # Tokenizer tokenizer_type: HuggingFaceTokenizer - tokenizer_model: meta-llama/Llama-3.2-1B + tokenizer_model: meta-llama/Llama-3.2-3B # parallel tensor_model_parallel_size: 1 diff --git a/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml index ff74b6cf4..e477d36ee 100644 --- a/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml +++ b/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml @@ -22,10 +22,11 @@ modules: train_iters: 100 micro_batch_size: 2 - global_batch_size: 128 + global_batch_size: 16 - seq_length: 4096 - max_position_embeddings: 4096 + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 lr: 2.0e-4 min_lr: 2.0e-5 @@ -43,7 +44,7 @@ modules: # Tokenizer tokenizer_type: HuggingFaceTokenizer - tokenizer_model: meta-llama/Llama-3.2-1B + tokenizer_model: meta-llama/Llama-3.1-8B # parallel tensor_model_parallel_size: 1 diff --git a/primus/backends/megatron/core/models/hybrid/__init__.py b/primus/backends/megatron/core/models/hybrid/__init__.py index 799e3b0ce..0cd249162 100644 --- a/primus/backends/megatron/core/models/hybrid/__init__.py +++ b/primus/backends/megatron/core/models/hybrid/__init__.py @@ -5,7 +5,7 @@ """Hybrid Mamba+MLA layer specifications for Megatron-LM.""" -from .hybrid_mamba_mla_layer_specs import hybrid_inference_stack_spec, hybrid_stack_spec +from .hybrid_mamba_mla_layer_specs import hybrid_stack_spec -__all__ = ["hybrid_stack_spec", "hybrid_inference_stack_spec"] +__all__ = ["hybrid_stack_spec"] diff --git a/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py b/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py index 4777da8a0..b3c1ce90f 100644 --- a/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py +++ b/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py @@ -96,7 +96,6 @@ mlp_layer=ModuleSpec( module=MLPLayer, submodules=TransformerLayerSubmodules( - pre_mlp_layernorm=TENorm, mlp=ModuleSpec( module=MLP, submodules=MLPSubmodules( @@ -114,65 +113,4 @@ ), ), ), -) - -hybrid_inference_stack_spec = ModuleSpec( - module=HybridStack, - submodules=HybridStackSubmodules( - mamba_layer=ModuleSpec( - module=MambaLayer, - submodules=MambaLayerSubmodules( - mixer=ModuleSpec( - module=MambaMixer, - submodules=MambaMixerSubmodules( - in_proj=InferenceLayerNormColumnParallelLinear, - out_proj=InferenceRowParallelLinear, - ), - ), - mamba_bda=get_bias_dropout_add, - ), - ), - # Started with spec from gpt_layer_specs.py (with MLP removed) - # Using the TE spec because we had problems getting the non-TE spec - # working - attention_layer=ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=InferenceLayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=InferenceRowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - ), - ), - # Started with spec from gpt_layer_specs.py - # Using the TE spec because we had problems getting the non-TE spec - # working - mlp_layer=ModuleSpec( - module=MLPLayer, - submodules=TransformerLayerSubmodules( - pre_mlp_layernorm=TENorm, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=InferenceLayerNormColumnParallelLinear, - linear_fc2=InferenceRowParallelLinear, - ), - ), - mlp_bda=get_bias_dropout_add, - ), - ), - moe_layer=ModuleSpec( - # TODO (rwaleffe): change this to be an "MoELayer" to work with CudaGraphs? - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add - ), - ), - ), ) \ No newline at end of file diff --git a/primus/configs/models/megatron/language_model.yaml b/primus/configs/models/megatron/language_model.yaml index a360cc695..50296107e 100755 --- a/primus/configs/models/megatron/language_model.yaml +++ b/primus/configs/models/megatron/language_model.yaml @@ -22,6 +22,7 @@ num_query_groups: null add_position_embedding: false position_embedding_type: learned_absolute max_position_embeddings: null +original_max_position_embeddings: null untie_embeddings_and_output_weights: true ffn_hidden_size: null diff --git a/primus/configs/models/megatron/zebra_llama_3B.yaml b/primus/configs/models/megatron/zebra_llama_3B.yaml index a57487732..0e82f2667 100644 --- a/primus/configs/models/megatron/zebra_llama_3B.yaml +++ b/primus/configs/models/megatron/zebra_llama_3B.yaml @@ -10,6 +10,7 @@ tokenizer_model: meta-llama/Llama-3.2-1B num_layers: 56 hidden_size: 3072 ffn_hidden_size: 8192 +normalization: "RMSNorm" # Mamba parameters is_hybrid_model: true @@ -39,4 +40,4 @@ rotary_base: 500000 position_embedding_type: none add_position_embedding: true use_rotary_position_embeddings: false -max_position_embeddings: 131072 +original_max_position_embeddings: 131072 diff --git a/primus/configs/models/megatron/zebra_llama_8B.yaml b/primus/configs/models/megatron/zebra_llama_8B.yaml index 1ebfa55a0..01eb6befa 100644 --- a/primus/configs/models/megatron/zebra_llama_8B.yaml +++ b/primus/configs/models/megatron/zebra_llama_8B.yaml @@ -10,6 +10,7 @@ tokenizer_model: meta-llama/Llama-3.2-1B num_layers: 64 hidden_size: 4096 ffn_hidden_size: 14436 +normalization: "RMSNorm" # Mamba parameters is_hybrid_model: true @@ -39,4 +40,4 @@ rotary_base: 500000 position_embedding_type: none add_position_embedding: true use_rotary_position_embeddings: false -max_position_embeddings: 131072 +original_max_position_embeddings: 131072 From 6bc8f600ba3e891fcc9d73ee44d7d06f6d1b4ad6 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Wed, 4 Feb 2026 20:32:33 -0800 Subject: [PATCH 13/23] Potential fix for pull request finding 'Unused local variable' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --- primus/backends/megatron/core/models/hybrid/hybrid_block.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/primus/backends/megatron/core/models/hybrid/hybrid_block.py b/primus/backends/megatron/core/models/hybrid/hybrid_block.py index c39957fe6..28be92a47 100644 --- a/primus/backends/megatron/core/models/hybrid/hybrid_block.py +++ b/primus/backends/megatron/core/models/hybrid/hybrid_block.py @@ -334,11 +334,11 @@ def forward( # Ensure that the tensor passed between pipeline parallel stages is # viewless. See related notes in TransformerBlock and TransformerLayer - output = make_viewless_tensor( + return make_viewless_tensor( inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) - return hidden_states + def sharded_state_dict( self, From 7e43595fd29c65a511661a66f7bbc2f9fdffac85 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Fri, 6 Feb 2026 01:23:04 +0000 Subject: [PATCH 14/23] code lint with pre-commit --- .../configs/MI300X/mamba_370M-pretrain.yaml | 3 +- .../MI300X/zebra_llama_1B-pretrain.yaml | 5 +- .../MI300X/zebra_llama_3B-pretrain.yaml | 5 +- .../MI300X/zebra_llama_8B-pretrain.yaml | 5 +- .../megatron/core/models/hybrid/__init__.py | 1 - .../core/models/hybrid/hybrid_block.py | 64 ++++++++++--------- .../hybrid/hybrid_mamba_mla_layer_specs.py | 25 ++++---- .../configs/models/megatron/mamba_370M.yaml | 1 - .../configs/models/megatron/mamba_base.yaml | 2 +- .../models/megatron/zebra_llama_1B.yaml | 6 +- .../models/megatron/zebra_llama_3B.yaml | 6 +- .../models/megatron/zebra_llama_8B.yaml | 6 +- primus/core/utils/import_utils.py | 4 +- .../trainer/lightmegatron/pre_trainer.py | 13 ++-- .../modules/trainer/megatron/pre_trainer.py | 20 +++--- primus/modules/trainer/megatron/trainer.py | 12 ++-- 16 files changed, 91 insertions(+), 87 deletions(-) diff --git a/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml index d5bb62e71..469913761 100644 --- a/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml +++ b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml @@ -43,7 +43,7 @@ modules: # Mamba-specific: must provide spec spec: ['megatron.core.models.mamba.mamba_layer_specs', 'mamba_stack_spec'] - + # Tokenizer tokenizer_type: HuggingFaceTokenizer tokenizer_model: EleutherAI/gpt-neox-20b @@ -83,4 +83,3 @@ modules: # Cross entropy flags # cross_entropy_fusion_impl: "native" # cross_entropy_loss_fusion: false - diff --git a/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml index 1aaa9219e..2fa9d8fe5 100644 --- a/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml +++ b/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml @@ -41,11 +41,11 @@ modules: # Mamba-specific: must provide spec # Use custom hybrid Mamba+MLA spec spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] - + # Tokenizer tokenizer_type: HuggingFaceTokenizer tokenizer_model: meta-llama/Llama-3.2-1B - + # parallel tensor_model_parallel_size: 1 pipeline_model_parallel_size: 1 @@ -68,4 +68,3 @@ modules: save_interval: 10000 disable_last_saving: true ckpt_format: torch - diff --git a/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml index c70800a6f..05b0290ab 100644 --- a/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml +++ b/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml @@ -41,11 +41,11 @@ modules: # Mamba-specific: must provide spec # Use custom hybrid Mamba+MLA spec spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] - + # Tokenizer tokenizer_type: HuggingFaceTokenizer tokenizer_model: meta-llama/Llama-3.2-3B - + # parallel tensor_model_parallel_size: 1 pipeline_model_parallel_size: 1 @@ -68,4 +68,3 @@ modules: save_interval: 10000 disable_last_saving: true ckpt_format: torch - diff --git a/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml index e477d36ee..e5c02b91e 100644 --- a/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml +++ b/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml @@ -41,11 +41,11 @@ modules: # Mamba-specific: must provide spec # Use custom hybrid Mamba+MLA spec spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] - + # Tokenizer tokenizer_type: HuggingFaceTokenizer tokenizer_model: meta-llama/Llama-3.1-8B - + # parallel tensor_model_parallel_size: 1 pipeline_model_parallel_size: 1 @@ -68,4 +68,3 @@ modules: save_interval: 10000 disable_last_saving: true ckpt_format: torch - diff --git a/primus/backends/megatron/core/models/hybrid/__init__.py b/primus/backends/megatron/core/models/hybrid/__init__.py index 0cd249162..05d2d673a 100644 --- a/primus/backends/megatron/core/models/hybrid/__init__.py +++ b/primus/backends/megatron/core/models/hybrid/__init__.py @@ -8,4 +8,3 @@ from .hybrid_mamba_mla_layer_specs import hybrid_stack_spec __all__ = ["hybrid_stack_spec"] - diff --git a/primus/backends/megatron/core/models/hybrid/hybrid_block.py b/primus/backends/megatron/core/models/hybrid/hybrid_block.py index 28be92a47..3c4d7def8 100644 --- a/primus/backends/megatron/core/models/hybrid/hybrid_block.py +++ b/primus/backends/megatron/core/models/hybrid/hybrid_block.py @@ -10,8 +10,6 @@ from typing import Optional, Tuple, Union import torch -from torch import Tensor, nn - from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding from megatron.core.enums import Fp8Recipe @@ -20,12 +18,13 @@ from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols as LayerSymbols -from megatron.core.ssm.mamba_hybrid_layer_allocation import allocate_layers from megatron.core.transformer import TransformerConfig +from torch import Tensor, nn # CudaGraphScope is not available in older Megatron versions try: from megatron.core.transformer.enums import CudaGraphScope + HAS_CUDA_GRAPH_SCOPE = True except ImportError: CudaGraphScope = None @@ -36,7 +35,11 @@ from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.transformer.utils import sharded_state_dict_default -from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor +from megatron.core.utils import ( + WrappedTensor, + deprecate_inference_params, + make_viewless_tensor, +) @dataclass @@ -110,9 +113,9 @@ def __init__( self.hybrid_attention_ratio = hybrid_attention_ratio self.hybrid_mlp_ratio = hybrid_mlp_ratio self.hybrid_override_pattern = hybrid_override_pattern - + # Customized layer allocation - # hybrid_mlp_ratio is not used in this hybrid stack. + # hybrid_mlp_ratio is not used in this hybrid stack. # It is by default to be always followed by mamba or mla (i.e., mamba + MLP or MLA + MLP) # By setting hybrid_attention_ratio, attention layers are by default to be distributed uniformly. self.layer_type_list = self.allocate_layers( @@ -125,7 +128,7 @@ def __init__( pp_layer_offset, self.layer_type_list = self._select_layers_for_pipeline_parallel( self.layer_type_list ) - + print(f"layer_type_list: {self.layer_type_list}") self.layers = nn.ModuleList() @@ -158,9 +161,7 @@ def __init__( ) elif layer_type == LayerSymbols.MOE: # Transformer layers apply their own pp_layer_offset - layer = build_module( - submodules.moe_layer, config=self.config, layer_number=i + 1 - ) + layer = build_module(submodules.moe_layer, config=self.config, layer_number=i + 1) else: assert False, "unexpected layer_type" self.layers.append(layer) @@ -175,28 +176,35 @@ def __init__( hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon, ) + def allocate_layers(self, num_layers, hybrid_attention_ratio): layer_type_list = [] - num_attention_layers = int(num_layers//2 * hybrid_attention_ratio) - num_mamba_layers = num_layers//2 - num_attention_layers + num_attention_layers = int(num_layers // 2 * hybrid_attention_ratio) + num_mamba_layers = num_layers // 2 - num_attention_layers num_mamba_per_attention_layer = num_mamba_layers // num_attention_layers if hybrid_attention_ratio <= 0.5: - base_block = [LayerSymbols.ATTENTION, LayerSymbols.MLP] + [LayerSymbols.MAMBA, LayerSymbols.MLP] * num_mamba_per_attention_layer - layer_type_list += base_block*num_attention_layers - layer_type_list += [LayerSymbols.MAMBA, LayerSymbols.MLP] * (num_mamba_layers % num_attention_layers) + base_block = [LayerSymbols.ATTENTION, LayerSymbols.MLP] + [ + LayerSymbols.MAMBA, + LayerSymbols.MLP, + ] * num_mamba_per_attention_layer + layer_type_list += base_block * num_attention_layers + layer_type_list += [LayerSymbols.MAMBA, LayerSymbols.MLP] * ( + num_mamba_layers % num_attention_layers + ) else: base_block = [LayerSymbols.ATTENTION, LayerSymbols.MLP] + [LayerSymbols.MAMBA, LayerSymbols.MLP] - layer_type_list += [LayerSymbols.ATTENTION, LayerSymbols.MLP] * (num_attention_layers - num_mamba_layers) - layer_type_list += base_block*num_mamba_layers + layer_type_list += [LayerSymbols.ATTENTION, LayerSymbols.MLP] * ( + num_attention_layers - num_mamba_layers + ) + layer_type_list += base_block * num_mamba_layers return layer_type_list def _select_layers_for_pipeline_parallel(self, layer_type_list): num_layers_per_pipeline_rank = self.config.num_layers // self.pp_group.size() assert self.config.virtual_pipeline_model_parallel_size is None, ( - "The Mamba hybrid model does not currently support " - "virtual/interleaved pipeline parallelism" + "The Mamba hybrid model does not currently support " "virtual/interleaved pipeline parallelism" ) offset = self.pp_group.rank() * num_layers_per_pipeline_rank @@ -285,7 +293,7 @@ def forward( sequence_len_offset = torch.tensor( [inference_context.sequence_len_offset] * current_batch_size, dtype=torch.int32, - device='cuda', + device="cuda", ) else: sequence_len_offset = None @@ -338,11 +346,9 @@ def forward( inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) - - def sharded_state_dict( self, - prefix: str = '', + prefix: str = "", sharded_offsets: Optional[tuple] = None, metadata: Optional[dict] = None, ) -> ShardedStateDict: @@ -363,16 +369,14 @@ def sharded_state_dict( """ sharded_state_dict = {} - layer_prefix = f'{prefix}layers.' + layer_prefix = f"{prefix}layers." for local_layer_idx, layer in enumerate(self.layers): global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1 - state_dict_prefix = ( - f'{layer_prefix}{local_layer_idx}.' # module list index in MambaBlock - ) + state_dict_prefix = f"{layer_prefix}{local_layer_idx}." # module list index in MambaBlock - sharded_prefix = f'{layer_prefix}{global_layer_offset}.' + sharded_prefix = f"{layer_prefix}{global_layer_offset}." sharded_pp_offset = [] layer_sharded_state_dict = layer.sharded_state_dict( @@ -389,11 +393,11 @@ def sharded_state_dict( sharded_state_dict.update( sharded_state_dict_default( module, - f'{prefix}{name}.', + f"{prefix}{name}.", sharded_offsets, metadata, tp_group=self.tp_group, ) ) - return sharded_state_dict \ No newline at end of file + return sharded_state_dict diff --git a/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py b/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py index b3c1ce90f..cb801809d 100644 --- a/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py +++ b/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py @@ -8,22 +8,22 @@ TENorm, TERowParallelLinear, ) -from megatron.core.transformer.identity_op import IdentityOp from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec -from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules from megatron.core.ssm.mlp_layer import MLPLayer +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.multi_latent_attention import ( + MLASelfAttention, + MLASelfAttentionSubmodules, +) + # Import HybridStack from relative path from primus.backends.megatron.core.models.hybrid.hybrid_block import ( HybridStack, HybridStackSubmodules, ) -from megatron.core.transformer.multi_latent_attention import ( - MLASelfAttention, - MLASelfAttentionSubmodules, -) # Inference layers may not be available in older Megatron versions # They're only used in hybrid_inference_stack_spec, not the training spec @@ -32,6 +32,7 @@ InferenceLayerNormColumnParallelLinear, InferenceRowParallelLinear, ) + HAS_INFERENCE_LAYERS = True except ImportError: # Fallback to regular layers for inference spec @@ -39,11 +40,13 @@ InferenceRowParallelLinear = TERowParallelLinear HAS_INFERENCE_LAYERS = False -from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.transformer.transformer_layer import ( + TransformerLayer, + TransformerLayerSubmodules, +) moe = get_moe_module_spec( use_te=True, @@ -60,7 +63,7 @@ submodules=MambaLayerSubmodules( mixer=ModuleSpec( module=MambaMixer, - params={ + params={ "expand": 1, "d_conv": 4, }, @@ -71,7 +74,7 @@ mamba_bda=get_bias_dropout_add, ), ), - attention_layer = ModuleSpec( + attention_layer=ModuleSpec( module=TransformerLayer, submodules=TransformerLayerSubmodules( input_layernorm=TENorm, @@ -113,4 +116,4 @@ ), ), ), -) \ No newline at end of file +) diff --git a/primus/configs/models/megatron/mamba_370M.yaml b/primus/configs/models/megatron/mamba_370M.yaml index 52ccad1e2..6665da3af 100644 --- a/primus/configs/models/megatron/mamba_370M.yaml +++ b/primus/configs/models/megatron/mamba_370M.yaml @@ -14,4 +14,3 @@ ffn_hidden_size: null mamba_state_dim: 16 mamba_head_dim: 64 mamba_num_groups: 8 - diff --git a/primus/configs/models/megatron/mamba_base.yaml b/primus/configs/models/megatron/mamba_base.yaml index acda1d8b0..d52fe6db2 100644 --- a/primus/configs/models/megatron/mamba_base.yaml +++ b/primus/configs/models/megatron/mamba_base.yaml @@ -2,7 +2,7 @@ bases: - language_model.yaml # Mamba-specific configuration -# Note: Mamba-specific parameters (spec, is_hybrid_model, mamba_state_dim, etc.) +# Note: Mamba-specific parameters (spec, is_hybrid_model, mamba_state_dim, etc.) # must be set in the pretrain config overrides, not here model_type: mamba diff --git a/primus/configs/models/megatron/zebra_llama_1B.yaml b/primus/configs/models/megatron/zebra_llama_1B.yaml index 152834260..d1afa9531 100644 --- a/primus/configs/models/megatron/zebra_llama_1B.yaml +++ b/primus/configs/models/megatron/zebra_llama_1B.yaml @@ -9,7 +9,7 @@ tokenizer_model: meta-llama/Llama-3.2-1B # Model size parameters num_layers: 32 hidden_size: 2048 -ffn_hidden_size: 8192 +ffn_hidden_size: 8192 # Mamba parameters is_hybrid_model: true @@ -25,7 +25,7 @@ swiglu: true num_query_groups: null multi_latent_attention: true num_attention_heads: 32 -q_lora_rank: 1344 # Query LoRA rank +q_lora_rank: 1344 # Query LoRA rank kv_lora_rank: 128 # Key-Value LoRA rank qk_head_dim: 32 # Query-Key head dimension qk_pos_emb_head_dim: 32 # Positional embedding head dimension @@ -34,7 +34,7 @@ rotary_scaling_factor: 1.0 mscale: 1.0 mscale_all_dim: 1.0 -# MLA uses its own internal positional encoding +# MLA uses its own internal positional encoding rotary_base: 500000 position_embedding_type: none add_position_embedding: true diff --git a/primus/configs/models/megatron/zebra_llama_3B.yaml b/primus/configs/models/megatron/zebra_llama_3B.yaml index 0e82f2667..23090841f 100644 --- a/primus/configs/models/megatron/zebra_llama_3B.yaml +++ b/primus/configs/models/megatron/zebra_llama_3B.yaml @@ -9,7 +9,7 @@ tokenizer_model: meta-llama/Llama-3.2-1B # Model size parameters num_layers: 56 hidden_size: 3072 -ffn_hidden_size: 8192 +ffn_hidden_size: 8192 normalization: "RMSNorm" # Mamba parameters @@ -26,7 +26,7 @@ swiglu: true num_query_groups: null multi_latent_attention: true num_attention_heads: 24 -q_lora_rank: 1536 # Query LoRA rank +q_lora_rank: 1536 # Query LoRA rank kv_lora_rank: 128 # Key-Value LoRA rank qk_head_dim: 64 # Query-Key head dimension qk_pos_emb_head_dim: 64 # Positional embedding head dimension @@ -35,7 +35,7 @@ rotary_scaling_factor: 1.0 mscale: 1.0 mscale_all_dim: 1.0 -# MLA uses its own internal positional encoding +# MLA uses its own internal positional encoding rotary_base: 500000 position_embedding_type: none add_position_embedding: true diff --git a/primus/configs/models/megatron/zebra_llama_8B.yaml b/primus/configs/models/megatron/zebra_llama_8B.yaml index 01eb6befa..0237a652d 100644 --- a/primus/configs/models/megatron/zebra_llama_8B.yaml +++ b/primus/configs/models/megatron/zebra_llama_8B.yaml @@ -9,7 +9,7 @@ tokenizer_model: meta-llama/Llama-3.2-1B # Model size parameters num_layers: 64 hidden_size: 4096 -ffn_hidden_size: 14436 +ffn_hidden_size: 14436 normalization: "RMSNorm" # Mamba parameters @@ -26,7 +26,7 @@ swiglu: true num_query_groups: null multi_latent_attention: true num_attention_heads: 32 -q_lora_rank: 2048 # Query LoRA rank +q_lora_rank: 2048 # Query LoRA rank kv_lora_rank: 160 # Key-Value LoRA rank qk_head_dim: 64 # Query-Key head dimension qk_pos_emb_head_dim: 64 # Positional embedding head dimension @@ -35,7 +35,7 @@ rotary_scaling_factor: 1.0 mscale: 1.0 mscale_all_dim: 1.0 -# MLA uses its own internal positional encoding +# MLA uses its own internal positional encoding rotary_base: 500000 position_embedding_type: none add_position_embedding: true diff --git a/primus/core/utils/import_utils.py b/primus/core/utils/import_utils.py index 6e67de8a4..34c2de16d 100644 --- a/primus/core/utils/import_utils.py +++ b/primus/core/utils/import_utils.py @@ -52,7 +52,9 @@ def get_model_provider(model_type="gpt"): ) # Try to import mamba_builder (for Mamba models) try: - mamba_builder = lazy_import(["mamba_builders"], "mamba_builder", log_prefix="[Primus][MegatronCompat]") + mamba_builder = lazy_import( + ["mamba_builders"], "mamba_builder", log_prefix="[Primus][MegatronCompat]" + ) return partial(model_provider, mamba_builder) except ImportError: return model_provider diff --git a/primus/modules/trainer/lightmegatron/pre_trainer.py b/primus/modules/trainer/lightmegatron/pre_trainer.py index 973421933..eafa75d5a 100644 --- a/primus/modules/trainer/lightmegatron/pre_trainer.py +++ b/primus/modules/trainer/lightmegatron/pre_trainer.py @@ -37,23 +37,23 @@ def run(self, *args, **kwargs): log_rank_0("run light-megatron") from megatron.core.enums import ModelType - from megatron.training import inprocess_restart, pretrain - from megatron.training import get_args + from megatron.training import get_args, inprocess_restart, pretrain # Determine model type from config (gpt or mamba) megatron_args = get_args() - model_type = getattr(megatron_args, 'model_type', 'gpt') + model_type = getattr(megatron_args, "model_type", "gpt") log_rank_0(f"Detected model_type: {model_type}") - if model_type == 'mamba': + if model_type == "mamba": # Import from pretrain_mamba from pretrain_mamba import forward_step, train_valid_test_datasets_provider + train_valid_test_datasets_provider.is_distributed = True wrapped_pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) wrapped_pretrain( train_valid_test_datasets_provider, - get_model_provider(model_type='mamba'), + get_model_provider(model_type="mamba"), ModelType.encoder_or_decoder, forward_step, store=store, @@ -61,12 +61,13 @@ def run(self, *args, **kwargs): else: # Default to GPT from pretrain_gpt import forward_step, train_valid_test_datasets_provider + train_valid_test_datasets_provider.is_distributed = True wrapped_pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) wrapped_pretrain( train_valid_test_datasets_provider, - get_model_provider(model_type='gpt'), + get_model_provider(model_type="gpt"), ModelType.encoder_or_decoder, forward_step, store=store, diff --git a/primus/modules/trainer/megatron/pre_trainer.py b/primus/modules/trainer/megatron/pre_trainer.py index efb197f7e..d34788b04 100644 --- a/primus/modules/trainer/megatron/pre_trainer.py +++ b/primus/modules/trainer/megatron/pre_trainer.py @@ -242,20 +242,20 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=Fals assert ( args.overlap_moe_expert_parallel_comm ), "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan" - + # Schedule plan building is only supported for GPT models # Check if this is a Mamba model unwrapped_model = model - while hasattr(unwrapped_model, 'module'): + while hasattr(unwrapped_model, "module"): unwrapped_model = unwrapped_model.module model_class_name = unwrapped_model.__class__.__name__ - - if 'Mamba' in model_class_name: + + if "Mamba" in model_class_name: raise NotImplementedError( "Schedule plan building is not supported for Mamba models. " "Please disable overlap_moe_expert_parallel_comm for Mamba." ) - + if args.patch_moe_overlap: assert ( not args.delay_wgrad_compute @@ -285,15 +285,13 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=Fals # MambaModel doesn't accept loss_mask, but GPTModel does # Unwrap the model to get the actual model class unwrapped_model = model - while hasattr(unwrapped_model, 'module'): + while hasattr(unwrapped_model, "module"): unwrapped_model = unwrapped_model.module model_class_name = unwrapped_model.__class__.__name__ - - if 'Mamba' in model_class_name: + + if "Mamba" in model_class_name: # MambaModel doesn't accept loss_mask parameter - output_tensor = model( - tokens, position_ids, attention_mask, labels=labels - ) + output_tensor = model(tokens, position_ids, attention_mask, labels=labels) else: # GPTModel and other models accept loss_mask parameter output_tensor = model( diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index 818fa96ee..e26311337 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -484,7 +484,7 @@ def update_primus_config( args.iterations_to_skip = [] # support moe_freq_type - ensure moe_layer_freq has a default value - if not hasattr(args, 'moe_layer_freq'): + if not hasattr(args, "moe_layer_freq"): args.moe_layer_freq = 1 elif isinstance(args.moe_layer_freq, str): try: @@ -499,18 +499,20 @@ def update_primus_config( args.test_data_path = None # Determine model type (gpt or mamba) - model_type = getattr(args, 'model_type', 'gpt') + model_type = getattr(args, "model_type", "gpt") log_rank_0(f"-detected model_type: {model_type}") # Ensure required attributes have safe defaults if missing from config - if not hasattr(args, 'final_logit_softcapping'): + if not hasattr(args, "final_logit_softcapping"): args.final_logit_softcapping = None - if not hasattr(args, 'router_logit_softcapping'): + if not hasattr(args, "router_logit_softcapping"): args.router_logit_softcapping = None if args.final_logit_softcapping is not None and args.final_logit_softcapping > 0.0: log_rank_0(f"-enable final_logit_softcapping: {args.final_logit_softcapping}") - self.model_provider = functools.partial(primus_model_provider, get_model_provider(model_type=model_type)) + self.model_provider = functools.partial( + primus_model_provider, get_model_provider(model_type=model_type) + ) else: log_rank_0(f"-getting model provider for model_type={model_type}") model_provider = get_model_provider(model_type=model_type) From a99a28f9a72dbae6310d3303a997b4eb8db663f9 Mon Sep 17 00:00:00 2001 From: Kailash Gogineni Date: Thu, 19 Feb 2026 07:21:32 -0800 Subject: [PATCH 15/23] [Docs] & [Feature]: Add Post-Training Documentation and Update Qwen3_32B Configs for MI300X & MI355X (#556) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit YF: Only SFT related config and Doc changes, bypassing unit CI tests ## Summary This PR introduces post-training documentation and updates Qwen3 32B model configuration files to support AMD MI300X and MI355X accelerators. --- ## Changes ### 📘 Documentation - **Added `posttraining.md`** - New comprehensive guide for post-training workflows - Covers setup instructions, configuration details, and usage examples - **Updated `docs/README.md`** - Added a new section referencing post-training documentation - Improved documentation organization and navigation --- ### ⚙️ Configuration Updates - **Updated Qwen3_32B model YAML configs** - Added/modified configurations optimized for: - MI300X - MI355X - Adjusted parameters for compatibility and stable execution --- ## Validation - Verified updated configs load and execute successfully on MI300X and MI355X environments - Confirmed documentation links and structure render correctly --- ## Checklist - [x] Added `posttraining.md` - [x] Updated `docs/README.md` - [x] Modified Qwen3_32B YAML configs - [x] Verified changes locally --- docs/README.md | 1 + docs/posttraining.md | 452 ++++++++++++++++++ .../MI300X/qwen3_32b_lora_posttrain.yaml | 58 +++ .../MI300X/qwen3_32b_sft_posttrain.yaml | 58 +++ .../MI355X/qwen3_32b_lora_posttrain.yaml | 6 +- .../MI355X/qwen3_32b_sft_posttrain.yaml | 6 +- .../modules/megatron_bridge/sft_trainer.yaml | 2 +- 7 files changed, 576 insertions(+), 7 deletions(-) create mode 100644 docs/posttraining.md create mode 100644 examples/megatron_bridge/configs/MI300X/qwen3_32b_lora_posttrain.yaml create mode 100644 examples/megatron_bridge/configs/MI300X/qwen3_32b_sft_posttrain.yaml diff --git a/docs/README.md b/docs/README.md index 45e28f26e..f2d22ad42 100644 --- a/docs/README.md +++ b/docs/README.md @@ -24,6 +24,7 @@ Guides for common workflows and features: In-depth technical documentation: +- **[Post-Training Guide](./posttraining.md)** - Fine-tuning with SFT and LoRA using Primus CLI - **[Performance Projection](./projection.md)** - Project training performance to multi-node configurations - **[Preflight](./preflight.md)** - Cluster diagnostics (host/GPU/network info + perf tests) - **[Benchmark Suite](./benchmark.md)** - GEMM, RCCL, end-to-end benchmarks and profiling diff --git a/docs/posttraining.md b/docs/posttraining.md new file mode 100644 index 000000000..27eb5ca01 --- /dev/null +++ b/docs/posttraining.md @@ -0,0 +1,452 @@ +# 🎓 Post-Training with Primus + +This guide demonstrates how to perform post-training (fine-tuning) using **Megatron Bridge** within the **Primus** framework. It covers both **Supervised Fine-Tuning (SFT)** and **Low-Rank Adaptation (LoRA)** methods for customizing pre-trained models. + +--- + +## 📚 Table of Contents + +- [🎓 Post-Training with Primus](#-post-training-with-primus) + - [📚 Table of Contents](#-table-of-contents) + - [🎯 Overview](#-overview) + - [⚙️ Supported Backends](#️-supported-backends) + - [🔧 Post-Training Methods](#-post-training-methods) + - [🚀 Quick Start](#-quick-start) + - [Prerequisites](#prerequisites) + - [Basic Usage](#basic-usage) + - [📝 Configuration Examples](#-configuration-examples) + - [Supervised Fine-Tuning (SFT)](#supervised-fine-tuning-sft) + - [LoRA Fine-Tuning](#lora-fine-tuning) + - [🖥️ Single Node Training](#️-single-node-training) + - [Direct Mode](#direct-mode) + - [Container Mode](#container-mode) + - [📊 Hardware-Specific Configurations](#-hardware-specific-configurations) + - [MI300X Configurations](#mi300x-configurations) + - [MI355X Configurations](#mi355x-configurations) + - [🎨 Customizing Training Parameters](#-customizing-training-parameters) + - [💡 Best Practices](#-best-practices) + - [🔍 Troubleshooting](#-troubleshooting) + +--- + +## 🎯 Overview + +Post-training (fine-tuning) allows you to adapt pre-trained foundation models to specific tasks or domains. Primus supports two primary fine-tuning approaches: + +- **Supervised Fine-Tuning (SFT)**: Full fine-tuning that updates all model parameters +- **LoRA (Low-Rank Adaptation)**: Parameter-efficient fine-tuning that only trains lightweight adapter modules + +--- + +## ⚙️ Supported Backends + +Post-training in Primus uses the **Megatron Bridge** backend: + +| Backend | Description | +| --------------- | --------------------------------------------------------------- | +| Megatron Bridge | Bridge implementation for fine-tuning Megatron-based models | + +--- + +## 🔧 Post-Training Methods + +| Method | Memory Usage | Training Speed | Use Case | +| ------ | ------------ | -------------- | ------------------------------------- | +| **SFT** | High | Slower | Maximum performance, full adaptation | +| **LoRA** | Low | Faster | Resource-efficient, quick iteration | + +**Key Differences:** +- **SFT** updates all model parameters, requiring more memory and compute +- **LoRA** trains only low-rank adapter matrices, significantly reducing resource requirements + +--- + +## 🚀 Quick Start + +### Prerequisites + +- AMD ROCm drivers (≥ 7.0) +- Docker (≥ 24.0) with ROCm support (recommended) +- AMD Instinct GPUs (MI300X, MI355X, etc.) +- Pre-trained model checkpoint (optional, for continued training) + +```bash +# Quick verification +rocm-smi && docker --version +``` + +### Basic Usage + +The general command structure for post-training: + +```bash +./runner/primus-cli train posttrain --config +``` + +**Example commands:** + +```bash +# SFT with direct mode +./runner/primus-cli direct train posttrain \ + --config ./examples/megatron_bridge/configs/MI355X/qwen3_32b_sft_posttrain.yaml + +# LoRA with direct mode +./runner/primus-cli direct train posttrain \ + --config ./examples/megatron_bridge/configs/MI355X/qwen3_32b_lora_posttrain.yaml +``` + +--- + +## 📝 Configuration Examples + +### Supervised Fine-Tuning (SFT) + +Full fine-tuning configuration example for **Qwen3 32B** on **MI355X**: + +```yaml +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:qwen3_32b_sft_posttrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + post_trainer: + framework: megatron_bridge + config: sft_trainer.yaml + model: qwen3_32b.yaml + + overrides: + # Parallelism configuration + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + context_parallel_size: 1 + sequence_parallel: false + + # Fine-tuning method + peft: "none" # Full fine-tuning + + # Training configuration + train_iters: 200 + global_batch_size: 8 + micro_batch_size: 1 + seq_length: 8192 + + # Optimizer configuration + finetune_lr: 5.0e-6 + min_lr: 0.0 + lr_warmup_iters: 50 + + # Precision + precision_config: bf16_mixed +``` + +**Configuration location:** `examples/megatron_bridge/configs/MI355X/qwen3_32b_sft_posttrain.yaml` + +### LoRA Fine-Tuning + +Parameter-efficient fine-tuning configuration for **Qwen3 32B** on **MI355X**: + +```yaml +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:qwen3_32b_lora_posttrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + post_trainer: + framework: megatron_bridge + config: sft_trainer.yaml + model: qwen3_32b.yaml + + overrides: + # Parallelism configuration + tensor_model_parallel_size: 1 # LoRA requires less parallelism + pipeline_model_parallel_size: 1 + context_parallel_size: 1 + sequence_parallel: false + + # Fine-tuning method + peft: lora # LoRA fine-tuning + + # Training configuration + train_iters: 200 + global_batch_size: 32 + micro_batch_size: 4 + seq_length: 8192 + + # Optimizer configuration + finetune_lr: 1.0e-4 # Higher LR for LoRA + min_lr: 0.0 + lr_warmup_iters: 50 + + # Precision + precision_config: bf16_mixed + + # Recompute configuration + recompute_granularity: full + recompute_method: uniform + recompute_num_layers: 1 +``` + +**Configuration location:** `examples/megatron_bridge/configs/MI355X/qwen3_32b_lora_posttrain.yaml` + +--- + +## 🖥️ Single Node Training + +### Direct Mode + +Best for local development or when running directly on bare metal with ROCm installed. + +**SFT Example:** +```bash +./runner/primus-cli direct train posttrain \ + --config ./examples/megatron_bridge/configs/MI355X/qwen3_32b_sft_posttrain.yaml +``` + +**LoRA Example:** +```bash +./runner/primus-cli direct train posttrain \ + --config ./examples/megatron_bridge/configs/MI355X/qwen3_32b_lora_posttrain.yaml +``` + +**MI300X Examples:** +```bash +# SFT on MI300X +./runner/primus-cli direct train posttrain \ + --config ./examples/megatron_bridge/configs/MI300X/qwen3_32b_sft_posttrain.yaml + +# LoRA on MI300X +./runner/primus-cli direct train posttrain \ + --config ./examples/megatron_bridge/configs/MI300X/qwen3_32b_lora_posttrain.yaml +``` + +### Container Mode + +Recommended for environment isolation and dependency management. + +**Pull Docker image:** +```bash +docker pull docker.io/rocm/primus:latest +``` + +**SFT Example:** +```bash +./runner/primus-cli container --image rocm/primus:latest \ + train posttrain \ + --config ./examples/megatron_bridge/configs/MI355X/qwen3_32b_sft_posttrain.yaml +``` + +**LoRA Example:** +```bash +./runner/primus-cli container --image rocm/primus:latest \ + train posttrain \ + --config ./examples/megatron_bridge/configs/MI355X/qwen3_32b_lora_posttrain.yaml +``` + +--- + +## 📊 Hardware-Specific Configurations + +### MI300X Configurations + +Available configurations for AMD Instinct MI300X GPUs: + +| Model | Method | Config File | TP | GBS | MBS | Seq Len | +| ---------- | ------ | ------------------------------------------- | -- | --- | --- | ------- | +| Qwen3 32B | SFT | `MI300X/qwen3_32b_sft_posttrain.yaml` | 2 | 8 | 2 | 8192 | +| Qwen3 32B | LoRA | `MI300X/qwen3_32b_lora_posttrain.yaml` | 1 | 32 | 2 | 8192 | + +**Example:** +```bash +./runner/primus-cli direct train posttrain \ + --config ./examples/megatron_bridge/configs/MI300X/qwen3_32b_sft_posttrain.yaml +``` + +### MI355X Configurations + +Available configurations for AMD Instinct MI355X GPUs: + +| Model | Method | Config File | TP | GBS | MBS | Seq Len | +| ------------ | ------ | ------------------------------------------- | -- | --- | --- | ------- | +| Qwen3 32B | SFT | `MI355X/qwen3_32b_sft_posttrain.yaml` | 1 | 8 | 1 | 8192 | +| Qwen3 32B | LoRA | `MI355X/qwen3_32b_lora_posttrain.yaml` | 1 | 32 | 4 | 8192 | + +**Legend:** +- **TP**: Tensor Parallelism Size +- **GBS**: Global Batch Size +- **MBS**: Micro Batch Size (per GPU) +- **Seq Len**: Sequence Length + +**Example:** +```bash +./runner/primus-cli direct train posttrain \ + --config ./examples/megatron_bridge/configs/MI355X/qwen3_32b_lora_posttrain.yaml +``` + +--- + +## 🎨 Customizing Training Parameters + +Key parameters you can customize in the YAML configuration: + +### Parallelism Settings +```yaml +tensor_model_parallel_size: 1 # Number of GPUs for tensor parallelism (1-8) +pipeline_model_parallel_size: 1 # Number of GPUs for pipeline parallelism +context_parallel_size: 1 # Context parallelism for long sequences +sequence_parallel: false # Enable sequence parallelism +``` + +### Training Hyperparameters +```yaml +train_iters: 200 # Total training iterations +global_batch_size: 8 # Global batch size (8-32 depending on config) +micro_batch_size: 1 # Batch size per GPU (1-4 depending on config) +seq_length: 2048 # Sequence length (2048-8192 depending on model) +eval_interval: 30 # Evaluate every N iterations +save_interval: 50 # Save checkpoint every N iterations +``` + +### Learning Rate Configuration +```yaml +finetune_lr: 1.0e-4 # Initial learning rate +min_lr: 0.0 # Minimum learning rate +lr_warmup_iters: 50 # Number of warmup iterations +lr_decay_iters: null # Learning rate decay iterations +``` + +### Fine-Tuning Method +```yaml +peft: lora # Options: "lora" or "none" (for full SFT) +packed_sequence: false # Enable packed sequences for efficiency +``` + +### Precision Configuration +```yaml +precision_config: bf16_mixed # Options: bf16_mixed, fp16_mixed, fp32 +``` + +### Memory Optimization +```yaml +recompute_granularity: full # Options: full, selective, null +recompute_method: uniform # Recompute strategy +recompute_num_layers: 1 # Number of layers to recompute +``` + +--- + +## 💡 Best Practices + +### Choosing Between SFT and LoRA + +**Use SFT when:** +- You need maximum model performance +- You have sufficient GPU memory +- Training time is not critical +- You want full model adaptation + +**Use LoRA when:** +- GPU memory is limited +- You need fast iteration cycles +- Training multiple task-specific adapters +- Parameter efficiency is important + +### Parallelism Configuration + +**For SFT:** +- Use higher `tensor_model_parallel_size` for large models (e.g., TP=8 for 70B) +- Consider pipeline parallelism for very large models +- Examples: + - 32B model: TP=1-2 (MI300X: TP=2, MI355X: TP=1) + - 70B model: TP=8 + +**For LoRA:** +- Lower `tensor_model_parallel_size` due to reduced memory +- LoRA can fit larger models with less parallelism +- Examples: + - 32B model: TP=1 + - 70B model: TP=8 (still requires high TP due to model size) + +### Learning Rate Guidelines + +- **SFT**: Use lower learning rates (5e-6 to 1e-5) +- **LoRA**: Use higher learning rates (1e-4 to 5e-4) +- Always use warmup for stable training + +### Batch Size Recommendations + +- Start with `global_batch_size: 8` for SFT development +- LoRA can use higher batch sizes (e.g., 32) due to lower memory usage +- Increase for production: 64, 128, or higher +- Adjust `micro_batch_size` (1-4) based on GPU memory and sequence length +- Longer sequences (8192) may require higher `micro_batch_size` for efficiency + +--- + +## 🔍 Troubleshooting + +### Out of Memory (OOM) Errors + +**For SFT:** +1. Increase `tensor_model_parallel_size` +2. Reduce `micro_batch_size` +3. Enable gradient checkpointing: + ```yaml + recompute_granularity: full + recompute_method: uniform + recompute_num_layers: 1 + ``` +4. Reduce `seq_length` + +**For LoRA:** +1. LoRA should have lower memory usage; verify `peft: lora` is set +2. Reduce `micro_batch_size` if still facing OOM +3. Enable recomputation as above + +### Training Instability + +1. **Check learning rate**: Reduce if loss is spiking +2. **Increase warmup**: Try `lr_warmup_iters: 100` or higher +3. **Use mixed precision**: Ensure `precision_config: bf16_mixed` +4. **Monitor gradients**: Watch for gradient explosions + +### Slow Training Speed + +1. **Optimize batch size**: Increase `global_batch_size` if possible +2. **Check parallelism**: Ensure optimal TP/PP configuration +3. **Use container mode**: Docker containers can improve performance +4. **Profile execution**: Use profiling tools to identify bottlenecks + +### Configuration Issues + +1. **Verify paths**: Ensure config file paths are correct +2. **Check YAML syntax**: Validate indentation and structure +3. **Environment variables**: Set `PRIMUS_WORKSPACE` if needed +4. **Model checkpoint**: Verify pre-trained checkpoint path (if using) + +--- + +## 🎯 Summary Commands + +**Quick reference for common post-training tasks:** + +```bash +# SFT on MI355X (direct mode) +./runner/primus-cli direct train posttrain \ + --config ./examples/megatron_bridge/configs/MI355X/qwen3_32b_sft_posttrain.yaml + +# LoRA on MI355X (direct mode) +./runner/primus-cli direct train posttrain \ + --config ./examples/megatron_bridge/configs/MI355X/qwen3_32b_lora_posttrain.yaml + +# SFT on MI300X (container mode) +./runner/primus-cli container --image rocm/primus:latest train posttrain \ + --config ./examples/megatron_bridge/configs/MI300X/qwen3_32b_sft_posttrain.yaml +``` + +--- + +**Need help?** Open an issue on [GitHub](https://github.com/AMD-AIG-AIMA/Primus/issues). + +**Start fine-tuning with Primus! 🚀** diff --git a/examples/megatron_bridge/configs/MI300X/qwen3_32b_lora_posttrain.yaml b/examples/megatron_bridge/configs/MI300X/qwen3_32b_lora_posttrain.yaml new file mode 100644 index 000000000..456ee968a --- /dev/null +++ b/examples/megatron_bridge/configs/MI300X/qwen3_32b_lora_posttrain.yaml @@ -0,0 +1,58 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:qwen3_32b_lora_posttrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + post_trainer: + framework: megatron_bridge + config: sft_trainer.yaml + + # Model to run + model: qwen3_32b.yaml + + overrides: + stderr_sink_level: DEBUG + + # Parallelism configuration + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_dtype: null + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: false + use_megatron_fsdp: false + + # Finetuning-specific params + #pretrained_checkpoint: null + peft: lora + packed_sequence: false + + # Training configuration + train_iters: 200 + global_batch_size: 32 + micro_batch_size: 2 + seq_length: 8192 + eval_interval: 30 + save_interval: 50 + + # Optimizer configuration + finetune_lr: 1.0e-4 + min_lr: 0.0 + lr_warmup_iters: 50 + lr_decay_iters: null + + # W&B logging + wandb_project: null + wandb_entity: null + wandb_exp_name: null + + # Precision + precision_config: bf16_mixed + comm_overlap_config: null + + # Recompute configuration (enabled for 32B model) + recompute_granularity: full + recompute_method: uniform + recompute_num_layers: 1 + diff --git a/examples/megatron_bridge/configs/MI300X/qwen3_32b_sft_posttrain.yaml b/examples/megatron_bridge/configs/MI300X/qwen3_32b_sft_posttrain.yaml new file mode 100644 index 000000000..c91a6b0d5 --- /dev/null +++ b/examples/megatron_bridge/configs/MI300X/qwen3_32b_sft_posttrain.yaml @@ -0,0 +1,58 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:qwen3_32b_sft_posttrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + post_trainer: + framework: megatron_bridge + config: sft_trainer.yaml + + # Model to run + model: qwen3_32b.yaml + + overrides: + stderr_sink_level: DEBUG + + # Parallelism configuration + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + pipeline_dtype: null + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: false + use_megatron_fsdp: false + + # Finetuning-specific params + #pretrained_checkpoint: null + peft: "none" + packed_sequence: false + + # Training configuration + train_iters: 200 + global_batch_size: 8 + micro_batch_size: 2 + seq_length: 8192 + eval_interval: 30 + save_interval: 50 + + # Optimizer configuration + finetune_lr: 5.0e-6 + min_lr: 0.0 + lr_warmup_iters: 50 + lr_decay_iters: null + + # W&B logging + wandb_project: null + wandb_entity: null + wandb_exp_name: null + + # Precision + precision_config: bf16_mixed + comm_overlap_config: null + + # Recompute configuration (enabled for 32B model) + recompute_granularity: full + recompute_method: uniform + recompute_num_layers: 1 + diff --git a/examples/megatron_bridge/configs/MI355X/qwen3_32b_lora_posttrain.yaml b/examples/megatron_bridge/configs/MI355X/qwen3_32b_lora_posttrain.yaml index 49ce4f8a0..3dfb3eb39 100755 --- a/examples/megatron_bridge/configs/MI355X/qwen3_32b_lora_posttrain.yaml +++ b/examples/megatron_bridge/configs/MI355X/qwen3_32b_lora_posttrain.yaml @@ -30,9 +30,9 @@ modules: # Training configuration train_iters: 200 - global_batch_size: 128 - micro_batch_size: 1 - seq_length: 2048 + global_batch_size: 32 + micro_batch_size: 4 + seq_length: 8192 eval_interval: 30 save_interval: 50 diff --git a/examples/megatron_bridge/configs/MI355X/qwen3_32b_sft_posttrain.yaml b/examples/megatron_bridge/configs/MI355X/qwen3_32b_sft_posttrain.yaml index 1b9f10aff..d01623990 100755 --- a/examples/megatron_bridge/configs/MI355X/qwen3_32b_sft_posttrain.yaml +++ b/examples/megatron_bridge/configs/MI355X/qwen3_32b_sft_posttrain.yaml @@ -15,7 +15,7 @@ modules: stderr_sink_level: DEBUG # Parallelism configuration - tensor_model_parallel_size: 4 + tensor_model_parallel_size: 1 pipeline_model_parallel_size: 1 pipeline_dtype: null virtual_pipeline_model_parallel_size: null @@ -30,9 +30,9 @@ modules: # Training configuration train_iters: 200 - global_batch_size: 128 + global_batch_size: 8 micro_batch_size: 1 - seq_length: 2048 + seq_length: 8192 eval_interval: 30 save_interval: 50 diff --git a/primus/configs/modules/megatron_bridge/sft_trainer.yaml b/primus/configs/modules/megatron_bridge/sft_trainer.yaml index 230bb22a5..994a1389a 100644 --- a/primus/configs/modules/megatron_bridge/sft_trainer.yaml +++ b/primus/configs/modules/megatron_bridge/sft_trainer.yaml @@ -12,7 +12,7 @@ stage: "sft" # main control flag -enable_primus_turbo: false +enable_primus_turbo: true # feature control flags use_turbo_attention: false From 564fa38e2027283cb12b00487f9d343b5f69e2cc Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Thu, 12 Feb 2026 11:14:55 -0800 Subject: [PATCH 16/23] set grad_accum_fusion=false for triton 3.6.0 compatibility --- .../MI300X/zebra_llama_1B-pretrain.yaml | 2 +- .../MI300X/zebra_llama_3B-pretrain.yaml | 2 +- .../MI300X/zebra_llama_8B-pretrain.yaml | 2 +- .../MI355X/zebra_llama_1B-pretrain.yaml | 70 +++++++++++++++++++ .../MI355X/zebra_llama_3B-pretrain.yaml | 70 +++++++++++++++++++ .../MI355X/zebra_llama_8B-pretrain.yaml | 70 +++++++++++++++++++ 6 files changed, 213 insertions(+), 3 deletions(-) create mode 100644 examples/megatron/configs/MI355X/zebra_llama_1B-pretrain.yaml create mode 100644 examples/megatron/configs/MI355X/zebra_llama_3B-pretrain.yaml create mode 100644 examples/megatron/configs/MI355X/zebra_llama_8B-pretrain.yaml diff --git a/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml index 2fa9d8fe5..d2327bab8 100644 --- a/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml +++ b/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml @@ -52,7 +52,7 @@ modules: expert_model_parallel_size: 1 overlap_grad_reduce: true overlap_param_gather: true - gradient_accumulation_fusion: true + gradient_accumulation_fusion: false # data mock_data: true diff --git a/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml index 05b0290ab..4daeb25e6 100644 --- a/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml +++ b/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml @@ -52,7 +52,7 @@ modules: expert_model_parallel_size: 1 overlap_grad_reduce: true overlap_param_gather: true - gradient_accumulation_fusion: true + gradient_accumulation_fusion: false # data mock_data: true diff --git a/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml index e5c02b91e..a7083c069 100644 --- a/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml +++ b/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml @@ -52,7 +52,7 @@ modules: expert_model_parallel_size: 1 overlap_grad_reduce: true overlap_param_gather: true - gradient_accumulation_fusion: true + gradient_accumulation_fusion: false # data mock_data: true diff --git a/examples/megatron/configs/MI355X/zebra_llama_1B-pretrain.yaml b/examples/megatron/configs/MI355X/zebra_llama_1B-pretrain.yaml new file mode 100644 index 000000000..4eaf5102c --- /dev/null +++ b/examples/megatron/configs/MI355X/zebra_llama_1B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_1B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_1B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_1B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 16 + global_batch_size: 128 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-1B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/examples/megatron/configs/MI355X/zebra_llama_3B-pretrain.yaml b/examples/megatron/configs/MI355X/zebra_llama_3B-pretrain.yaml new file mode 100644 index 000000000..4daeb25e6 --- /dev/null +++ b/examples/megatron/configs/MI355X/zebra_llama_3B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_3B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_3B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_3B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 2 + global_batch_size: 16 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-3B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/examples/megatron/configs/MI355X/zebra_llama_8B-pretrain.yaml b/examples/megatron/configs/MI355X/zebra_llama_8B-pretrain.yaml new file mode 100644 index 000000000..ff274291e --- /dev/null +++ b/examples/megatron/configs/MI355X/zebra_llama_8B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_8B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_8B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_8B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 16 + global_batch_size: 128 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.1-8B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch From 1028e654e9e8ff977f1cc572ad7d8b23c4a3b200 Mon Sep 17 00:00:00 2001 From: HuangWei-95 Date: Thu, 5 Feb 2026 14:39:00 +0800 Subject: [PATCH 17/23] refactor(megatron): add tokenizer override patch, move to new arch (#530) Override Megatron build_tokenizer to support custom tokenizer types with HuggingFace Hub IDs - Fixes Llama2Tokenizer failing with Hub IDs in new architecture - All custom types now work consistently in legacy and new architectures --------- Co-authored-by: HuangWei-95 Co-authored-by: Xiaoming-AMD --- .../megatron/patches/checkpoint_patches.py | 76 +++++++++++++++ .../patches/tokenizer_builder_patches.py | 93 +++++++++++++++++++ primus/cli/subcommands/train.py | 5 +- 3 files changed, 172 insertions(+), 2 deletions(-) create mode 100644 primus/backends/megatron/patches/tokenizer_builder_patches.py diff --git a/primus/backends/megatron/patches/checkpoint_patches.py b/primus/backends/megatron/patches/checkpoint_patches.py index 4a56d7f2a..e5a37f7dd 100644 --- a/primus/backends/megatron/patches/checkpoint_patches.py +++ b/primus/backends/megatron/patches/checkpoint_patches.py @@ -46,3 +46,79 @@ def patch_filesystem_writer_async(ctx: PatchContext): log_rank_0( "[Patch:megatron.checkpoint.filesystem_writer_async] Patch FileSystemWriterAsync successfully." ) + + +@register_patch( + "megatron.checkpoint.save_checkpoint", + backend="megatron", + phase="before_train", + description="Wrap save_checkpoint to skip saving at the last iteration", +) +def patch_save_checkpoint(ctx: PatchContext): + """ + Wrap Megatron's save_checkpoint to skip saving at the last iteration + + This patch monkey-patches the save_checkpoint function in + megatron.training.training module to check if: + 1. disable_last_saving is True + 2. Current iteration equals train_iters (final iteration) + + If both conditions are met, the checkpoint save is skipped. + """ + try: + import megatron.training.training as training_module + except ImportError as e: + log_rank_0(f"[Patch:megatron.checkpoint.save_checkpoint] Skip patch (Megatron not available): {e}") + return + + # Save original function + original_save_checkpoint = training_module.save_checkpoint + + # The following signature is used to match the original Megatron save_checkpoint interface, + # but the wrapper will only use a subset of the arguments as handled below. + def wrapped_save_checkpoint( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context=None, + pipeline_rank=None, + expert_rank=None, + tensor_rank=None, + pipeline_parallel=None, + expert_parallel=None, + non_persistent_ckpt=False, + train_data_iterator=None, + preprocess_common_state_dict_fn=None, + release=False, + ): + args = ctx.extra.get("backend_args", {}) + + if args.disable_last_saving and iteration == args.train_iters: + log_rank_0( + f"[Patch:megatron.checkpoint.save_checkpoint] Skip saving at the last iteration: {iteration}" + ) + return + + # Call the original save_checkpoint function with explicit keyword arguments for clarity. + return original_save_checkpoint( + iteration, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + checkpointing_context=checkpointing_context, + pipeline_rank=pipeline_rank, + expert_rank=expert_rank, + tensor_rank=tensor_rank, + pipeline_parallel=pipeline_parallel, + expert_parallel=expert_parallel, + non_persistent_ckpt=non_persistent_ckpt, + train_data_iterator=train_data_iterator, + preprocess_common_state_dict_fn=preprocess_common_state_dict_fn, + release=release, + ) + + training_module.save_checkpoint = wrapped_save_checkpoint + log_rank_0("[Patch:megatron.checkpoint.save_checkpoint] Patch save_checkpoint successfully.") diff --git a/primus/backends/megatron/patches/tokenizer_builder_patches.py b/primus/backends/megatron/patches/tokenizer_builder_patches.py new file mode 100644 index 000000000..1cb47eb40 --- /dev/null +++ b/primus/backends/megatron/patches/tokenizer_builder_patches.py @@ -0,0 +1,93 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Megatron Tokenizer Builder Patches + +Override Megatron's build_tokenizer to use Primus version which properly +handles custom tokenizer types (Llama2Tokenizer, Llama3Tokenizer, etc.) +with HuggingFace Hub ID support. + +Background: +----------- +Megatron's official _Llama2Tokenizer only supports local SentencePiece files, +while Primus extends it to support HuggingFace Hub IDs (e.g., meta-llama/Llama-2-7b-hf). + +Without this patch, the new architecture (PrimusRuntime) would call Megatron's +official build_tokenizer, causing failures when using custom tokenizer types +with Hub IDs. + +This patch ensures both legacy and new architectures use the same tokenizer +building logic. +""" + +from primus.core.patches import PatchContext, register_patch +from primus.modules.module_utils import log_rank_0 + + +@register_patch( + "megatron.tokenizer.build_tokenizer_override", + backend="megatron", + phase="setup", + description="Override Megatron's build_tokenizer to support Primus custom tokenizer types with HuggingFace Hub IDs", +) +def patch_build_tokenizer_override(ctx: PatchContext): + """ + Monkey-patch Megatron's build_tokenizer with Primus version. + + This ensures that custom tokenizer types (Llama2Tokenizer, Llama3Tokenizer, + DeepSeekV2Tokenizer, etc.) are properly handled: + + - All custom types use _HuggingFaceTokenizer internally + - Support for HuggingFace Hub IDs (e.g., meta-llama/Llama-2-7b-hf) + - Consistent behavior between legacy and new architectures + + Without this patch: + ------------------- + - tokenizer_type: Llama2Tokenizer + tokenizer_model: meta-llama/Llama-2-7b-hf + → Calls Megatron's _Llama2Tokenizer + → Expects local file path + → ❌ FileNotFoundError + + With this patch: + ---------------- + - tokenizer_type: Llama2Tokenizer + tokenizer_model: meta-llama/Llama-2-7b-hf + → Calls Primus build_tokenizer + → Maps to _HuggingFaceTokenizer + → Supports Hub ID + → ✅ Success + """ + try: + import megatron.training.global_vars as megatron_global_vars + import pretrain_gpt + except ImportError as e: + log_rank_0( + f"[Patch:megatron.tokenizer.build_tokenizer_override] " + f"Skip patch (Megatron not available): {e}" + ) + return + + # Import Primus build_tokenizer + from primus.backends.megatron.training.tokenizer.tokenizer import ( + build_tokenizer as primus_build_tokenizer, + ) + + # Save original for reference (optional) + if not hasattr(megatron_global_vars, "_original_build_tokenizer"): + megatron_global_vars._original_build_tokenizer = megatron_global_vars.build_tokenizer + if not hasattr(pretrain_gpt, "_original_build_tokenizer"): + pretrain_gpt._original_build_tokenizer = pretrain_gpt.build_tokenizer + + # Replace Megatron's build_tokenizer with Primus version + megatron_global_vars.build_tokenizer = primus_build_tokenizer + pretrain_gpt.build_tokenizer = primus_build_tokenizer + + log_rank_0( + "[Patch:megatron.tokenizer.build_tokenizer_override] " + "✓ Replaced Megatron build_tokenizer with Primus version" + ) diff --git a/primus/cli/subcommands/train.py b/primus/cli/subcommands/train.py index 1f07fdbd1..049612dca 100644 --- a/primus/cli/subcommands/train.py +++ b/primus/cli/subcommands/train.py @@ -17,7 +17,7 @@ def _resolve_pretrain_runtime(args) -> str: Priority: 1) Explicit env override via PRIMUS_TRAIN_RUNTIME - 2) Auto-detect by backend framework (TorchTitan -> core, others -> legacy) + 2) Auto-detect by backend framework (TorchTitan Megatron -> core, others -> legacy) """ runtime_entry = getenv("PRIMUS_TRAIN_RUNTIME", "").strip().lower() if runtime_entry in ("legacy", "core"): @@ -38,7 +38,8 @@ def _resolve_pretrain_runtime(args) -> str: except Exception: framework = None - return "core" if framework == "torchtitan" else "legacy" + supported_frameworks = ["torchtitan", "megatron"] + return "core" if framework in supported_frameworks else "legacy" def run(args, overrides: List[str]): From 2edff4eb6a088c8adac5695b1f257bdbe1a40602 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Thu, 19 Feb 2026 23:09:06 +0000 Subject: [PATCH 18/23] Fix: Add model_type detection for mamba/hybrid models in core runtime - Update MegatronPretrainTrainer.run_train() to detect model_type from backend_args - Conditionally import pretrain_mamba or pretrain_gpt based on model_type - Pass model_type to get_model_provider() to use correct builder (mamba_builder vs gpt_builder) - Restore core runtime support for megatron as intended by commit cfe8cc07 - Fixes 'specialize for HybridStack' error when using core runtime with hybrid models --- .../megatron/megatron_pretrain_trainer.py | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/primus/backends/megatron/megatron_pretrain_trainer.py b/primus/backends/megatron/megatron_pretrain_trainer.py index 18c445979..f0b706850 100644 --- a/primus/backends/megatron/megatron_pretrain_trainer.py +++ b/primus/backends/megatron/megatron_pretrain_trainer.py @@ -98,13 +98,27 @@ def run_train(self): # Import Megatron components from megatron.core.enums import ModelType from megatron.training import pretrain # type: ignore - from pretrain_gpt import ( # type: ignore - forward_step, - train_valid_test_datasets_provider, - ) from primus.core.utils.import_utils import get_model_provider + # Determine model type (gpt or mamba) from backend_args + model_type = getattr(self.backend_args, "model_type", "gpt") + log_rank_0(f"-detected model_type: {model_type}") + + # Import the appropriate training components based on model_type + if model_type == "mamba": + from pretrain_mamba import ( # type: ignore + forward_step, + train_valid_test_datasets_provider, + ) + log_rank_0("Using Mamba model provider and training components") + else: + from pretrain_gpt import ( # type: ignore + forward_step, + train_valid_test_datasets_provider, + ) + log_rank_0("Using GPT model provider and training components") + # Configure training components if hasattr(train_valid_test_datasets_provider, "is_distributed"): train_valid_test_datasets_provider.is_distributed = True @@ -133,9 +147,13 @@ def run_train(self): if "store" in sig.parameters: kwargs["store"] = store + # Get model provider with correct model_type + model_provider = get_model_provider(model_type=model_type) + log_rank_0(f"-model_provider: {model_provider}") + wrapped_pretrain( train_valid_test_datasets_provider, - get_model_provider(), + model_provider, ModelType.encoder_or_decoder, forward_step, **kwargs, From 3ff8294d938d30a5020a8c012f5150047fc7ccfb Mon Sep 17 00:00:00 2001 From: HuangWei-95 Date: Fri, 6 Feb 2026 09:49:19 +0800 Subject: [PATCH 19/23] ci(deterministic): add env for megatron ci test (#539) add env for TestMegatronTrainerDeterministic ci test Co-authored-by: HuangWei-95 --- tests/trainer/test_megatron_trainer.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/trainer/test_megatron_trainer.py b/tests/trainer/test_megatron_trainer.py index 37200b2ed..24233f72f 100644 --- a/tests/trainer/test_megatron_trainer.py +++ b/tests/trainer/test_megatron_trainer.py @@ -352,6 +352,10 @@ def test_llama3_8B(self): "PRIMUS_NUM_LAYERS": "4", # deterministic vars "PRIMUS_DETERMINISTIC": "1", + "NCCL_ALGO": "Ring", + "TORCH_COMPILE_DISABLE": "1", + "ROCBLAS_DEFAULT_ATOMICS_MODE": "0", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", } stdout, _ = run_script( self.__class__.__name__, @@ -379,6 +383,10 @@ def test_deepseek_v2_lite(self): "PRIMUS_NUM_LAYERS": "4", # deterministic vars "PRIMUS_DETERMINISTIC": "1", + "NCCL_ALGO": "Ring", + "TORCH_COMPILE_DISABLE": "1", + "ROCBLAS_DEFAULT_ATOMICS_MODE": "0", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", } stdout, _ = run_script( self.__class__.__name__, From 50b30e0acc51a33dcc86b2253a40da2408bc32c6 Mon Sep 17 00:00:00 2001 From: WangLingxun Date: Fri, 6 Feb 2026 10:03:07 +0800 Subject: [PATCH 20/23] Update Docker base image from v25.10 to v26.1 (#534) Update all references to the Primus Docker base image across documentation, configuration files, CI/CD workflows, and example scripts to use the latest v26.1 release. --- .github/workflows/ci.yaml | 2 +- .github/workflows/docker/Dockerfile | 2 +- .github/workflows/docker/Dockerfile.ainic | 3 ++- README.md | 4 ++-- docs/cli/PRIMUS-CLI-GUIDE.md | 6 +++--- docs/cli/README.md | 2 +- docs/quickstart.md | 8 ++++---- examples/README.md | 6 +++--- examples/run_k8s_pretrain.sh | 4 ++-- examples/run_local_pretrain.sh | 4 ++-- examples/run_local_pretrain_cli.sh | 2 +- examples/run_slurm_pretrain_cli.sh | 2 +- runner/.primus.yaml | 2 +- runner/README.md | 2 +- runner/primus-cli-container.sh | 2 +- tools/docker/start_container.sh | 2 +- 16 files changed, 27 insertions(+), 26 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 2eca2d5bd..7e24422bd 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -12,7 +12,7 @@ on: env: PRIMUS_TURBO_COMMIT: a4f1ff4935059c249721697c19bd2b22dfd9417d # feat(mxfp4/mxfp8): refine quantization api and optimize kernels perf (#207) ROCSHMEM_COMMIT: 17ff985c026f9f97f85068647e863ab541dd5645 # Update version to 3.2.0 for 7.2.0 rocm release (#351) (#355) - BASE_IMAGE: docker.io/rocm/primus:v25.10 + BASE_IMAGE: docker.io/rocm/primus:v26.1 MAXTEXT_BASE_IMAGE: docker.io/rocm/jax-training:maxtext-v25.9 jobs: diff --git a/.github/workflows/docker/Dockerfile b/.github/workflows/docker/Dockerfile index 539cdd117..444326c3f 100644 --- a/.github/workflows/docker/Dockerfile +++ b/.github/workflows/docker/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=docker.io/rocm/primus:v25.10 +ARG BASE_IMAGE=docker.io/rocm/primus:v26.1 FROM ${BASE_IMAGE} ARG PRIMUS_TURBO_COMMIT diff --git a/.github/workflows/docker/Dockerfile.ainic b/.github/workflows/docker/Dockerfile.ainic index 4e2ddd6eb..c6c951622 100644 --- a/.github/workflows/docker/Dockerfile.ainic +++ b/.github/workflows/docker/Dockerfile.ainic @@ -54,7 +54,8 @@ ENV RCCL_HOME=${WORKDIR}/rccl # Build AMD ANP # --------------------------------------------------------------------------- -RUN cd ${WORKDIR} && git clone https://github.com/rocm/amd-anp.git && \ +RUN apt-get install -y --allow-unauthenticated libionic-dev && \ + cd ${WORKDIR} && git clone https://github.com/rocm/amd-anp.git && \ cd amd-anp && git checkout tags/v1.3.0 && \ make -j 16 RCCL_HOME=${RCCL_HOME} \ MPI_INCLUDE=${MPI_PATH}/include/ \ diff --git a/README.md b/README.md index 18a61e3aa..2e1d69e52 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ Primus leverages AMD’s ROCm Docker images to provide a consistent, ready-to-ru 1. **Pull the latest Docker image** ```bash - docker pull docker.io/rocm/primus:v25.10 + docker pull docker.io/rocm/primus:v26.1 ``` 2. **Clone the repository** @@ -74,7 +74,7 @@ Primus leverages AMD’s ROCm Docker images to provide a consistent, ready-to-ru # Run training in container # NOTE: If your config downloads weights/tokenizer from Hugging Face Hub, # you typically need to pass HF_TOKEN into the container. - ./primus-cli container --image rocm/primus:v25.10 \ + ./primus-cli container --image rocm/primus:v26.1 \ --env HF_TOKEN="hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" \ -- train pretrain --config examples/megatron/configs/MI300X/llama2_7B-BF16-pretrain.yaml ``` diff --git a/docs/cli/PRIMUS-CLI-GUIDE.md b/docs/cli/PRIMUS-CLI-GUIDE.md index 775f7a760..3ec7a2e5b 100644 --- a/docs/cli/PRIMUS-CLI-GUIDE.md +++ b/docs/cli/PRIMUS-CLI-GUIDE.md @@ -97,7 +97,7 @@ Primus CLI supports three execution modes, each suitable for different scenarios **Common Options**: | Option | Description | Example | |--------|-------------|---------| -| `--image IMAGE` | Specify container image | `--image rocm/primus:v25.10` | +| `--image IMAGE` | Specify container image | `--image rocm/primus:v26.1` | | `--volume PATH[:PATH]` | Mount directory | `--volume /data:/data` | | `--cpus N` | Limit CPU cores | `--cpus 16` | | `--memory SIZE` | Limit memory size | `--memory 128G` | @@ -207,7 +207,7 @@ slurm: # Container configuration container: - image: "rocm/primus:v25.10" + image: "rocm/primus:v26.1" options: cpus: "32" memory: "256G" @@ -628,7 +628,7 @@ Step 4: primus-cli-container.sh (on each node) ├─ Load container.* config (image, devices, mounts, etc.) ├─ Parse container params: --image rocm/megatron-lm:v25.8_py310 ├─ Merge config and CLI params - │ Config: image=rocm/primus:v25.10 + │ Config: image=rocm/primus:v26.1 │ CLI: --image rocm/megatron-lm:v25.8_py310 │ Result: image=rocm/megatron-lm:v25.8_py310 ├─ Build container options diff --git a/docs/cli/README.md b/docs/cli/README.md index 347e29761..313fb23a2 100644 --- a/docs/cli/README.md +++ b/docs/cli/README.md @@ -48,7 +48,7 @@ primus-cli direct -- benchmark gemm -M 4096 -N 4096 -K 4096 | Mode | Use Case | Command Example | |------|----------|-----------------| | **Direct** | Local development, quick validation | `primus-cli direct -- train pretrain` | -| **Container** | Environment isolation, dependency management | `primus-cli container --image rocm/primus:v25.10 -- train pretrain` | +| **Container** | Environment isolation, dependency management | `primus-cli container --image rocm/primus:v26.1 -- train pretrain` | | **Slurm** | Multi-node distributed training | `primus-cli slurm srun -N 8 -- train pretrain` | ## 📖 Learn More diff --git a/docs/quickstart.md b/docs/quickstart.md index 92ef5dc7e..7f1685581 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -21,7 +21,7 @@ rocm-smi && docker --version ```bash # Pull Docker image -docker pull docker.io/rocm/primus:v25.10 +docker pull docker.io/rocm/primus:v26.1 # Clone repository git clone --recurse-submodules https://github.com/AMD-AIG-AIMA/Primus.git @@ -32,7 +32,7 @@ cd Primus ```bash # Run a quick benchmark in container -./primus-cli container --image rocm/primus:v25.10 \ +./primus-cli container --image rocm/primus:v26.1 \ -- benchmark gemm -M 4096 -N 4096 -K 4096 ``` @@ -50,7 +50,7 @@ Use the Docker image you just pulled: ```bash # Run training in container (recommended for getting started) -./primus-cli container --image rocm/primus:v25.10 \ +./primus-cli container --image rocm/primus:v26.1 \ -- train pretrain --config examples/megatron/configs/MI300X/llama2_7B-BF16-pretrain.yaml ``` @@ -62,7 +62,7 @@ Use the Docker image you just pulled: --config examples/megatron/configs/MI300X/llama2_7B-BF16-pretrain.yaml # Slurm mode (for multi-node cluster) -./primus-cli slurm srun -N 8 -p gpu -- container --image rocm/primus:v25.10 \ +./primus-cli slurm srun -N 8 -p gpu -- container --image rocm/primus:v26.1 \ -- train pretrain --config examples/megatron/configs/MI300X/llama2_7B-pretrain.yaml ``` diff --git a/examples/README.md b/examples/README.md index f26063f0c..9d1c30853 100644 --- a/examples/README.md +++ b/examples/README.md @@ -49,7 +49,7 @@ We recommend using the official [rocm/megatron-lm Docker image](https://hub.dock ```bash # Pull the latest Docker image -docker pull docker.io/rocm/primus:v25.10 +docker pull docker.io/rocm/primus:v26.1 ``` @@ -126,7 +126,7 @@ Multi-node training is launched via **SLURM**. Specify the number of nodes and the model config: ```bash -export DOCKER_IMAGE="docker.io/rocm/primus:v25.10" +export DOCKER_IMAGE="docker.io/rocm/primus:v26.1" export NNODES=8 # Example for megatron llama3.1_8B @@ -285,7 +285,7 @@ When using the `create` command to start a new training workload, the following | `--gpu` | Number of GPUs | 8 | | `--exp` | Path to experiment (training config) file (required) | — | | `--data_path` | Path to training data | — | -| `--image` | Docker image to use | `docker.io/rocm/primus:v25.10` | +| `--image` | Docker image to use | `docker.io/rocm/primus:v26.1` | | `--hf_token` | HuggingFace token | Read from env var `HF_TOKEN` | | `--workspace` | Workspace name | `primus-safe-pretrain` | | `--nodelist` | Comma-separated list of node hostnames to run on | — | diff --git a/examples/run_k8s_pretrain.sh b/examples/run_k8s_pretrain.sh index 8e07074d7..a74527d16 100644 --- a/examples/run_k8s_pretrain.sh +++ b/examples/run_k8s_pretrain.sh @@ -15,7 +15,7 @@ GPU="8" EXP_PATH="" DATA_PATH="" BACKEND="megatron" -IMAGE="docker.io/rocm/primus:v25.10" +IMAGE="docker.io/rocm/primus:v26.1" HF_TOKEN="${HF_TOKEN:-}" WORKSPACE="primus-safe-pretrain" NODELIST="" @@ -38,7 +38,7 @@ Options for create: --backend Training backend, e.g. megatron | torchtitan(default: megatron) --exp Path to EXP config (optional) --data_path Data path (optional) - --image Docker image to use (default: docker.io/rocm/primus:v25.10) + --image Docker image to use (default: docker.io/rocm/primus:v26.1) --hf_token HuggingFace token (default: from env HF_TOKEN) --workspace Workspace name (default: safe-cluster-dev) --nodelist Comma-separated list of node names to run on (optional) diff --git a/examples/run_local_pretrain.sh b/examples/run_local_pretrain.sh index 04fe9720b..4cb36fdb3 100755 --- a/examples/run_local_pretrain.sh +++ b/examples/run_local_pretrain.sh @@ -16,7 +16,7 @@ Usage: bash run_local_pretrain.sh This script launches a Primus pretraining task inside a Docker/Podman container. Environment Variables: - DOCKER_IMAGE Docker image to use [Default: docker.io/rocm/primus:v25.10] + DOCKER_IMAGE Docker image to use [Default: docker.io/rocm/primus:v26.1] MASTER_ADDR Master node IP or hostname [Default: localhost] MASTER_PORT Master node port [Default: 1234] NNODES Total number of nodes [Default: 1] @@ -44,7 +44,7 @@ EXP=${EXP:-"examples/megatron/exp_pretrain.yaml"} if [ "${BACKEND:-}" = "MaxText" ]; then DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/jax-training:maxtext-v25.9"} else - DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/primus:v25.10"} + DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/primus:v26.1"} fi # Project root diff --git a/examples/run_local_pretrain_cli.sh b/examples/run_local_pretrain_cli.sh index 4b77d9f56..4c3fbd39a 100755 --- a/examples/run_local_pretrain_cli.sh +++ b/examples/run_local_pretrain_cli.sh @@ -15,7 +15,7 @@ EXP=${EXP:-"examples/megatron/exp_pretrain.yaml"} if [ "${BACKEND:-}" = "MaxText" ]; then DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/jax-training:maxtext-v25.9"} else - DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/primus:v25.10"} + DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/primus:v26.1"} fi # ------------------ Cluster Env Defaults ------------------ diff --git a/examples/run_slurm_pretrain_cli.sh b/examples/run_slurm_pretrain_cli.sh index 408a5c142..efa7e8910 100755 --- a/examples/run_slurm_pretrain_cli.sh +++ b/examples/run_slurm_pretrain_cli.sh @@ -35,7 +35,7 @@ mkdir -p "$LOG_DIR" # NOTE: The --env entries below are passed into the container and will be visible # to the Primus training process (and system hooks) inside the container. bash "$PRIMUS_PATH/runner/primus-cli" slurm "${SLURM_ARGS[@]}" \ --- --image "${DOCKER_IMAGE:-rocm/primus:v25.10}" \ +-- --image "${DOCKER_IMAGE:-rocm/primus:v26.1}" \ -- \ --env "USING_AINIC=${USING_AINIC:-0}" \ --env "PATCH_TE_FLASH_ATTN=${PATCH_TE_FLASH_ATTN:-0}" \ diff --git a/runner/.primus.yaml b/runner/.primus.yaml index 56244fb57..989a7d6d1 100644 --- a/runner/.primus.yaml +++ b/runner/.primus.yaml @@ -26,7 +26,7 @@ container: # All keys directly map to CLI arguments (--key value) options: # Container image - image: "rocm/primus:v25.10" + image: "rocm/primus:v26.1" # Single-value options ipc: "host" diff --git a/runner/README.md b/runner/README.md index 7d771ec08..81809bc09 100644 --- a/runner/README.md +++ b/runner/README.md @@ -162,7 +162,7 @@ This is the most common pattern when you need a fixed software stack. bash runner/primus-cli slurm \ -N 4 \ --nodelist "node[01-04]" \ --- --image "rocm/primus:v25.10" \ +-- --image "rocm/primus:v26.1" \ -- --env NCCL_DEBUG=INFO \ -- train pretrain --config examples/megatron/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml ``` diff --git a/runner/primus-cli-container.sh b/runner/primus-cli-container.sh index 987637bad..c73d12878 100755 --- a/runner/primus-cli-container.sh +++ b/runner/primus-cli-container.sh @@ -45,7 +45,7 @@ Docker/Podman Options: --cap-add Add Linux capabilities (e.g., SYS_PTRACE) Container Configuration: - --image Docker image [default: rocm/primus:v25.10] + --image Docker image [default: rocm/primus:v26.1] --name Container name --user Run as specific user (e.g., 1000:1000) --network Network mode (e.g., host, bridge) diff --git a/tools/docker/start_container.sh b/tools/docker/start_container.sh index 686fb30b0..8396ea1c9 100755 --- a/tools/docker/start_container.sh +++ b/tools/docker/start_container.sh @@ -6,7 +6,7 @@ ############################################################################### PRIMUS_PATH=$(realpath "$(dirname "$0")/../..") -DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/primus:v25.10"} +DOCKER_IMAGE=${DOCKER_IMAGE:-"docker.io/rocm/primus:v26.1"} DATA_PATH=${DATA_PATH:-"${PRIMUS_PATH}/data"} SANITIZED_USER=$(echo "${USER:-unknown}" | tr -cd '[:alnum:]_-') if [ -z "$SANITIZED_USER" ]; then From 07e32b90afa3c82684747ea76669b69190487b68 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Fri, 20 Feb 2026 00:12:40 +0000 Subject: [PATCH 21/23] Merge main into clairlee/dev/hybrid: Resolve conflicts and preserve model_type detection --- primus/backends/megatron/megatron_pretrain_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/primus/backends/megatron/megatron_pretrain_trainer.py b/primus/backends/megatron/megatron_pretrain_trainer.py index ae201dc06..d72bef3e2 100644 --- a/primus/backends/megatron/megatron_pretrain_trainer.py +++ b/primus/backends/megatron/megatron_pretrain_trainer.py @@ -32,12 +32,14 @@ def train(self): forward_step, train_valid_test_datasets_provider, ) + log_rank_0("Using Mamba model provider and training components") else: from pretrain_gpt import ( # type: ignore forward_step, train_valid_test_datasets_provider, ) + log_rank_0("Using GPT model provider and training components") # Configure training components From 07a546288b74c128d54b6574ac6b0113bb3e1c51 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Fri, 20 Feb 2026 01:31:42 +0000 Subject: [PATCH 22/23] resolve unit test error --- primus/backends/megatron/megatron_pretrain_trainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/primus/backends/megatron/megatron_pretrain_trainer.py b/primus/backends/megatron/megatron_pretrain_trainer.py index d72bef3e2..97ce5e869 100644 --- a/primus/backends/megatron/megatron_pretrain_trainer.py +++ b/primus/backends/megatron/megatron_pretrain_trainer.py @@ -67,7 +67,11 @@ def train(self): kwargs["store"] = store # Get model provider with correct model_type - model_provider = get_model_provider(model_type=model_type) + # Only pass model_type if it's not the default to maintain compatibility + if model_type != "gpt": + model_provider = get_model_provider(model_type=model_type) + else: + model_provider = get_model_provider() log_rank_0(f"-model_provider: {model_provider}") wrapped_pretrain( From f4e79edfb7b378742b3705b8ad6d439cc54548c2 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Fri, 20 Feb 2026 05:59:59 +0000 Subject: [PATCH 23/23] use gpt model provider by default for compatibility --- primus/modules/trainer/megatron/trainer.py | 23 +++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index d3f19f1b3..8f0c9f70e 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -508,16 +508,25 @@ def update_primus_config( if not hasattr(args, "router_logit_softcapping"): args.router_logit_softcapping = None + # Only pass model_type parameter when it's "mamba" to maintain backward compatibility + # with main branch behavior for "gpt" (default) case if args.final_logit_softcapping is not None and args.final_logit_softcapping > 0.0: log_rank_0(f"-enable final_logit_softcapping: {args.final_logit_softcapping}") - self.model_provider = functools.partial( - primus_model_provider, get_model_provider(model_type=model_type) - ) + if model_type == "mamba": + self.model_provider = functools.partial( + primus_model_provider, get_model_provider(model_type=model_type) + ) + else: + self.model_provider = functools.partial(primus_model_provider, get_model_provider()) else: - log_rank_0(f"-getting model provider for model_type={model_type}") - model_provider = get_model_provider(model_type=model_type) - log_rank_0(f"-model_provider: {model_provider}") - self.model_provider = model_provider + if model_type == "mamba": + log_rank_0(f"-getting model provider for model_type={model_type}") + model_provider = get_model_provider(model_type=model_type) + log_rank_0(f"-model_provider: {model_provider}") + self.model_provider = model_provider + else: + # For "gpt" (default), call without arguments to match main branch behavior + self.model_provider = get_model_provider() if args.router_logit_softcapping is not None and args.router_logit_softcapping > 0.0: log_rank_0(f"-enable router_logit_softcapping: {args.router_logit_softcapping}")