diff --git a/.gitmodules b/.gitmodules index e3276c3d6..4a0a36fd6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -15,6 +15,10 @@ [submodule "third_party/Megatron-Bridge"] path = third_party/Megatron-Bridge url = https://github.com/NVIDIA-NeMo/Megatron-Bridge.git +[submodule "third_party/mamba"] + path = third_party/mamba + url = https://github.com/AndreasKaratzas/mamba.git + branch = enable-primus-hybrid-models [submodule "third_party/HummingbirdXT"] path = third_party/HummingbirdXT url = https://github.com/AMD-AGI/HummingbirdXT.git diff --git a/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml new file mode 100644 index 000000000..469913761 --- /dev/null +++ b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml @@ -0,0 +1,85 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:mamba_370M-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: mamba_370M.yaml + overrides: + # log + wandb_project: "Primus_Mamba_Pretrain" + # disable_wandb: false + # disable_tensorboard: false + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 50 + micro_batch_size: 4 + global_batch_size: 256 + + seq_length: 2048 + max_position_embeddings: 2048 + + lr: 3.0e-4 + min_lr: 0.0 + lr_warmup_iters: 50000 + lr_decay_iters: 73192188 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.02 + norm_epsilon: 1.0e-5 + + # Mamba-specific: must provide spec + spec: ['megatron.core.models.mamba.mamba_layer_specs', 'mamba_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: EleutherAI/gpt-neox-20b + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + 
overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + no_load_optim: null + no_load_rng: null + save: null + save_interval: 20000 + no_save_optim: null + no_save_rng: null + disable_last_saving: true + ckpt_format: torch + + # Turbo - may need to disable for Mamba if not supported + enable_primus_turbo: false + use_turbo_attention: false + use_turbo_grouped_mlp: false + + # Cross entropy flags + # cross_entropy_fusion_impl: "native" + # cross_entropy_loss_fusion: false diff --git a/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml new file mode 100644 index 000000000..d2327bab8 --- /dev/null +++ b/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_1B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_1B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_1B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 2 + global_batch_size: 16 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: 
HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-1B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml new file mode 100644 index 000000000..4daeb25e6 --- /dev/null +++ b/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_3B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_3B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_3B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 2 + global_batch_size: 16 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-3B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + 
expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml new file mode 100644 index 000000000..a7083c069 --- /dev/null +++ b/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_8B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_8B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_8B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 2 + global_batch_size: 16 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.1-8B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + 
train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/examples/megatron/configs/MI355X/zebra_llama_1B-pretrain.yaml b/examples/megatron/configs/MI355X/zebra_llama_1B-pretrain.yaml new file mode 100644 index 000000000..4eaf5102c --- /dev/null +++ b/examples/megatron/configs/MI355X/zebra_llama_1B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_1B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_1B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_1B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 16 + global_batch_size: 128 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-1B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + 
save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/examples/megatron/configs/MI355X/zebra_llama_3B-pretrain.yaml b/examples/megatron/configs/MI355X/zebra_llama_3B-pretrain.yaml new file mode 100644 index 000000000..4daeb25e6 --- /dev/null +++ b/examples/megatron/configs/MI355X/zebra_llama_3B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_3B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_3B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_3B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 2 + global_batch_size: 16 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-3B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/examples/megatron/configs/MI355X/zebra_llama_8B-pretrain.yaml 
b/examples/megatron/configs/MI355X/zebra_llama_8B-pretrain.yaml new file mode 100644 index 000000000..ff274291e --- /dev/null +++ b/examples/megatron/configs/MI355X/zebra_llama_8B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_8B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_8B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_8B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 16 + global_batch_size: 128 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.1-8B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/primus/backends/megatron/core/models/hybrid/__init__.py b/primus/backends/megatron/core/models/hybrid/__init__.py new file mode 100644 index 000000000..05d2d673a --- /dev/null +++ 
b/primus/backends/megatron/core/models/hybrid/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +"""Hybrid Mamba+MLA layer specifications for Megatron-LM.""" + +from .hybrid_mamba_mla_layer_specs import hybrid_stack_spec + +__all__ = ["hybrid_stack_spec"] diff --git a/primus/backends/megatron/core/models/hybrid/hybrid_block.py b/primus/backends/megatron/core/models/hybrid/hybrid_block.py new file mode 100644 index 000000000..3c4d7def8 --- /dev/null +++ b/primus/backends/megatron/core/models/hybrid/hybrid_block.py @@ -0,0 +1,403 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024, Tri Dao, Albert Gu. + +# Some of this code was adopted from https://github.com/state-spaces/mamba/ +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. 
+ +from contextlib import nullcontext +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding +from megatron.core.enums import Fp8Recipe +from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.fp8_utils import get_fp8_context +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols as LayerSymbols +from megatron.core.transformer import TransformerConfig +from torch import Tensor, nn + +# CudaGraphScope is not available in older Megatron versions +try: + from megatron.core.transformer.enums import CudaGraphScope + + HAS_CUDA_GRAPH_SCOPE = True +except ImportError: + CudaGraphScope = None + HAS_CUDA_GRAPH_SCOPE = False + +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_layer import TransformerLayer +from megatron.core.transformer.utils import sharded_state_dict_default +from megatron.core.utils import ( + WrappedTensor, + deprecate_inference_params, + make_viewless_tensor, +) + + +@dataclass +class HybridStackSubmodules: + """ + A class for the module specs for the MambaStack. + """ + + mamba_layer: Union[ModuleSpec, type] = IdentityOp + attention_layer: Union[ModuleSpec, type] = IdentityOp + mlp_layer: Union[ModuleSpec, type] = IdentityOp + moe_layer: Union[ModuleSpec, type] = IdentityOp + + +class HybridStack(MegatronModule): + """ + Constructor for the HybridStack class. 
+
+    Args:
+        config (TransformerConfig): the model configuration
+        submodules (HybridStackSubmodules): the submodules for the stack
+        residual_in_fp32 (bool, optional): whether to do residual connections
+            in fp32. Defaults to False.
+        pre_process (bool, optional): whether to include an embedding layer.
+            Defaults to True.
+        hybrid_attention_ratio (float, optional): the target ratio of attention layers to
+            total layers. Defaults to 0.0.
+        hybrid_mlp_ratio (float, optional): the target ratio of mlp layers to total
+            layers. Defaults to 0.0.
+        hybrid_override_pattern (str, optional): the hybrid layer pattern to override
+            with. Defaults to None.
+        post_layer_norm (bool, optional): whether to include a final layer norm.
+            Defaults to True.
+        post_process (bool, optional): whether to include an output layer.
+            Defaults to True.
+        device (optional): the device to use. Defaults to None.
+        dtype (optional): the data type to use. Defaults to None.
+        pg_collection (ProcessGroupCollection): the required model communication
+            process groups to use.
+ """ + + def __init__( + self, + config: TransformerConfig, + submodules: HybridStackSubmodules, + residual_in_fp32=False, + pre_process: bool = True, + hybrid_attention_ratio: float = 0.0, + hybrid_mlp_ratio: float = 0.0, + hybrid_override_pattern: str = None, + post_layer_norm: bool = True, + post_process: bool = True, + device=None, + dtype=None, + pg_collection: ProcessGroupCollection = None, + ) -> None: + super().__init__(config=config) + self.residual_in_fp32 = residual_in_fp32 + self.pre_process = pre_process + self.post_layer_norm = post_layer_norm + self.post_process = post_process + + assert pg_collection is not None, "pg_collection must be provided for MambaStack" + + self.pp_group = pg_collection.pp + self.tp_group = pg_collection.tp + + # Required for pipeline parallel schedules + self.input_tensor = None + + self.hybrid_attention_ratio = hybrid_attention_ratio + self.hybrid_mlp_ratio = hybrid_mlp_ratio + self.hybrid_override_pattern = hybrid_override_pattern + + # Customized layer allocation + # hybrid_mlp_ratio is not used in this hybrid stack. + # It is by default to be always followed by mamba or mla (i.e., mamba + MLP or MLA + MLP) + # By setting hybrid_attention_ratio, attention layers are by default to be distributed uniformly. 
+ self.layer_type_list = self.allocate_layers( + self.config.num_layers, + self.hybrid_attention_ratio, + ) + + pp_layer_offset = 0 + if self.pp_group.size() > 1: + pp_layer_offset, self.layer_type_list = self._select_layers_for_pipeline_parallel( + self.layer_type_list + ) + + print(f"layer_type_list: {self.layer_type_list}") + + self.layers = nn.ModuleList() + for i, layer_type in enumerate(self.layer_type_list): + fp8_init_context = get_fp8_context(self.config, i + pp_layer_offset, is_init=True) + with fp8_init_context: + if layer_type == LayerSymbols.MAMBA: + layer = build_module( + submodules.mamba_layer, + config=self.config, + residual_in_fp32=residual_in_fp32, + layer_number=i + 1, + pg_collection=pg_collection, + ) + elif layer_type == LayerSymbols.ATTENTION: + # Transformer layers apply their own pp_layer_offset + layer = build_module( + submodules.attention_layer, + config=self.config, + layer_number=i + 1, + pg_collection=pg_collection, + ) + elif layer_type == LayerSymbols.MLP: + # Transformer layers apply their own pp_layer_offset + layer = build_module( + submodules.mlp_layer, + config=self.config, + layer_number=i + 1, + pg_collection=pg_collection, + ) + elif layer_type == LayerSymbols.MOE: + # Transformer layers apply their own pp_layer_offset + layer = build_module(submodules.moe_layer, config=self.config, layer_number=i + 1) + else: + assert False, "unexpected layer_type" + self.layers.append(layer) + + # Required for activation recomputation + self.num_layers_per_pipeline_rank = len(self.layers) + + if self.post_process and self.post_layer_norm: + # Final layer norm before output. 
+            self.final_norm = TENorm(
+                config=self.config,
+                hidden_size=self.config.hidden_size,
+                eps=self.config.layernorm_epsilon,
+            )
+
+    def allocate_layers(self, num_layers, hybrid_attention_ratio):
+        layer_type_list = []
+        num_attention_layers = int(num_layers // 2 * hybrid_attention_ratio)
+        num_mamba_layers = num_layers // 2 - num_attention_layers
+        num_mamba_per_attention_layer = num_mamba_layers // max(num_attention_layers, 1)  # guard: ratio 0.0 (default) yields zero attention layers
+
+        if hybrid_attention_ratio <= 0.5:
+            base_block = [LayerSymbols.ATTENTION, LayerSymbols.MLP] + [
+                LayerSymbols.MAMBA,
+                LayerSymbols.MLP,
+            ] * num_mamba_per_attention_layer
+            layer_type_list += base_block * num_attention_layers
+            layer_type_list += [LayerSymbols.MAMBA, LayerSymbols.MLP] * (
+                num_mamba_layers - num_mamba_per_attention_layer * num_attention_layers  # == n_mamba % n_att when n_att > 0; all-Mamba stack when n_att == 0
+            )
+        else:
+            base_block = [LayerSymbols.ATTENTION, LayerSymbols.MLP] + [LayerSymbols.MAMBA, LayerSymbols.MLP]
+            layer_type_list += [LayerSymbols.ATTENTION, LayerSymbols.MLP] * (
+                num_attention_layers - num_mamba_layers
+            )
+            layer_type_list += base_block * num_mamba_layers
+        return layer_type_list
+
+    def _select_layers_for_pipeline_parallel(self, layer_type_list):
+        num_layers_per_pipeline_rank = self.config.num_layers // self.pp_group.size()
+
+        assert self.config.virtual_pipeline_model_parallel_size is None, (
+            "The Mamba hybrid model does not currently support " "virtual/interleaved pipeline parallelism"
+        )
+
+        offset = self.pp_group.rank() * num_layers_per_pipeline_rank
+        selected_list = layer_type_list[offset : offset + num_layers_per_pipeline_rank]
+
+        return offset, selected_list
+
+    def set_input_tensor(self, input_tensor: Tensor):
+        """Set input tensor to be used instead of forward()'s input.
+
+        When doing pipeline parallelism the input from the previous
+        stage comes from communication, not from the input, so the
+        model's forward_step_func won't have it.
This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def mamba_state_shapes_per_request(self) -> Optional[Tuple[Tuple[int], Tuple[int]]]: + """ + Returns the Mamba conv and ssm states shapes per input sequence + if this block contains Mamba layers (this may not be the case with PP > 1). + """ + for layer_type, layer in zip(self.layer_type_list, self.layers): + if layer_type == LayerSymbols.MAMBA: + return layer.mamba_state_shapes_per_request() + return None + + def forward( + self, + hidden_states: Union[Tensor, WrappedTensor], + attention_mask: Tensor, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Tensor] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ): + """ + Forward function of the MambaStack class. + + It either returns the Loss values if labels are given or the + final hidden units + + Args: + hidden_states (Union[Tensor, WrappedTensor]): the input tensor. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): the attention mask. + inference_context (BaseInferenceContext): the inference parameters. + rotary_pos_emb (Tensor, optional): the rotary positional embeddings. + Defaults to None. + Returns: + Tensor: the output tensor. 
+ """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if inference_context and inference_context.is_static_batching(): + # NOTE(bnorick): match BaseInferenceContext attributes for + # mamba_ssm.utils.generation.BaseInferenceContext, + # this hack supports eval + inference_context.max_seqlen = inference_context.max_sequence_length + inference_context.seqlen_offset = inference_context.sequence_len_offset + + if ( + ( + ( + HAS_CUDA_GRAPH_SCOPE + and self.config.cuda_graph_impl == "local" + and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope + ) + or self.config.flash_decode + ) + and inference_context + and inference_context.is_static_batching() + and not self.training + ): + current_batch_size = hidden_states.shape[1] + sequence_len_offset = torch.tensor( + [inference_context.sequence_len_offset] * current_batch_size, + dtype=torch.int32, + device="cuda", + ) + else: + sequence_len_offset = None + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed + outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() + + with outer_fp8_context: + for layer in self.layers: + inner_fp8_context = ( + get_fp8_context(self.config, layer.layer_number - 1) + if 
use_inner_fp8_context + else nullcontext() + ) + with inner_fp8_context: + if isinstance(layer, TransformerLayer): + hidden_states, _ = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + sequence_len_offset=sequence_len_offset, + ) + else: # MambaLayer + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_context=inference_context, + ) + + # The attention layer (currently a simplified transformer layer) + # outputs a tuple of (hidden_states, context). Context is intended + # for cross-attention, and is not needed in our model. + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + # Final layer norm. + if self.post_process and self.post_layer_norm: + hidden_states = self.final_norm(hidden_states) + + # Ensure that the tensor passed between pipeline parallel stages is + # viewless. See related notes in TransformerBlock and TransformerLayer + return make_viewless_tensor( + inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True + ) + + def sharded_state_dict( + self, + prefix: str = "", + sharded_offsets: Optional[tuple] = None, + metadata: Optional[dict] = None, + ) -> ShardedStateDict: + """ + Returns a sharded state dictionary for the current object. + + This function constructs a sharded state dictionary by iterating over the layers + in the current object, computing the sharded state dictionary for each layer, + and combining the results into a single dictionary. + + Parameters: + prefix (str): The prefix to use for the state dictionary keys. + sharded_offsets (tuple): The sharded offsets to use for the state dictionary. + metadata (dict): Additional metadata to use when computing the sharded state dictionary. + + Returns: + dict: The sharded state dictionary for the current object. + """ + + sharded_state_dict = {} + layer_prefix = f"{prefix}layers." 
+ + for local_layer_idx, layer in enumerate(self.layers): + + global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1 + state_dict_prefix = f"{layer_prefix}{local_layer_idx}." # module list index in MambaBlock + + sharded_prefix = f"{layer_prefix}{global_layer_offset}." + sharded_pp_offset = [] + + layer_sharded_state_dict = layer.sharded_state_dict( + state_dict_prefix, sharded_pp_offset, metadata + ) + + replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) + + sharded_state_dict.update(layer_sharded_state_dict) + + # Add modules other than self.layers + for name, module in self.named_children(): + if not module is self.layers: + sharded_state_dict.update( + sharded_state_dict_default( + module, + f"{prefix}{name}.", + sharded_offsets, + metadata, + tp_group=self.tp_group, + ) + ) + + return sharded_state_dict diff --git a/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py b/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py new file mode 100644 index 000000000..cb801809d --- /dev/null +++ b/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py @@ -0,0 +1,119 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+ +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TELinear, + TENorm, + TERowParallelLinear, +) +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec +from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules +from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules +from megatron.core.ssm.mlp_layer import MLPLayer +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.multi_latent_attention import ( + MLASelfAttention, + MLASelfAttentionSubmodules, +) + +# Import HybridStack from relative path +from primus.backends.megatron.core.models.hybrid.hybrid_block import ( + HybridStack, + HybridStackSubmodules, +) + +# Inference layers may not be available in older Megatron versions +# They're only used in hybrid_inference_stack_spec, not the training spec +try: + from megatron.core.tensor_parallel import ( + InferenceLayerNormColumnParallelLinear, + InferenceRowParallelLinear, + ) + + HAS_INFERENCE_LAYERS = True +except ImportError: + # Fallback to regular layers for inference spec + InferenceLayerNormColumnParallelLinear = TELayerNormColumnParallelLinear + InferenceRowParallelLinear = TERowParallelLinear + HAS_INFERENCE_LAYERS = False + +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import ( + TransformerLayer, + TransformerLayerSubmodules, +) + +moe = get_moe_module_spec( + use_te=True, + num_experts=8, # Can be any positive integer (must not be None). 
+ moe_grouped_gemm=True, + moe_use_legacy_grouped_gemm=False, +) + +hybrid_stack_spec = ModuleSpec( + module=HybridStack, + submodules=HybridStackSubmodules( + mamba_layer=ModuleSpec( + module=MambaLayer, + submodules=MambaLayerSubmodules( + mixer=ModuleSpec( + module=MambaMixer, + params={ + "expand": 1, + "d_conv": 4, + }, + submodules=MambaMixerSubmodules( + in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear + ), + ), + mamba_bda=get_bias_dropout_add, + ), + ), + attention_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=MLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=TEColumnParallelLinear, + linear_q_down_proj=TELinear, + linear_q_up_proj=TELayerNormColumnParallelLinear, + linear_kv_down_proj=TELinear, + linear_kv_up_proj=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + kv_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + ), + ), + mlp_layer=ModuleSpec( + module=MLPLayer, + submodules=TransformerLayerSubmodules( + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + moe_layer=ModuleSpec( + # TODO (rwaleffe): change this to be an "MoELayer" to work with CudaGraphs? 
+ module=TransformerLayer, + submodules=TransformerLayerSubmodules( + pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add + ), + ), + ), +) diff --git a/primus/backends/megatron/megatron_pretrain_trainer.py b/primus/backends/megatron/megatron_pretrain_trainer.py index 02f97e626..97ce5e869 100644 --- a/primus/backends/megatron/megatron_pretrain_trainer.py +++ b/primus/backends/megatron/megatron_pretrain_trainer.py @@ -19,13 +19,30 @@ def train(self): from megatron.core.enums import ModelType from megatron.training import pretrain # type: ignore - from pretrain_gpt import ( # type: ignore - forward_step, - train_valid_test_datasets_provider, - ) from primus.core.utils.import_utils import get_model_provider + # Determine model type (gpt or mamba) from backend_args + model_type = getattr(self.backend_args, "model_type", "gpt") + log_rank_0(f"-detected model_type: {model_type}") + + # Import the appropriate training components based on model_type + if model_type == "mamba": + from pretrain_mamba import ( # type: ignore + forward_step, + train_valid_test_datasets_provider, + ) + + log_rank_0("Using Mamba model provider and training components") + else: + from pretrain_gpt import ( # type: ignore + forward_step, + train_valid_test_datasets_provider, + ) + + log_rank_0("Using GPT model provider and training components") + + # Configure training components if hasattr(train_valid_test_datasets_provider, "is_distributed"): train_valid_test_datasets_provider.is_distributed = True @@ -49,9 +66,17 @@ def train(self): if "store" in sig.parameters: kwargs["store"] = store + # Get model provider with correct model_type + # Only pass model_type if it's not the default to maintain compatibility + if model_type != "gpt": + model_provider = get_model_provider(model_type=model_type) + else: + model_provider = get_model_provider() + log_rank_0(f"-model_provider: {model_provider}") + wrapped_pretrain( train_valid_test_datasets_provider, - get_model_provider(), + 
model_provider, ModelType.encoder_or_decoder, forward_step, **kwargs, diff --git a/primus/configs/models/megatron/language_model.yaml b/primus/configs/models/megatron/language_model.yaml index 8c6f24e48..50296107e 100755 --- a/primus/configs/models/megatron/language_model.yaml +++ b/primus/configs/models/megatron/language_model.yaml @@ -8,6 +8,7 @@ extends: # model architecture use_legacy_models: false deprecated_use_mcore_models: false +model_type: gpt # gpt or mamba num_layers: 24 encoder_num_layers: null decoder_num_layers: null @@ -21,6 +22,7 @@ num_query_groups: null add_position_embedding: false position_embedding_type: learned_absolute max_position_embeddings: null +original_max_position_embeddings: null untie_embeddings_and_output_weights: true ffn_hidden_size: null @@ -99,6 +101,21 @@ rotary_scaling_factor: 1.0 # float mscale: 1.0 # float mscale_all_dim: 1.0 # float +# Mamba layer configuration +mamba_state_dim: 128 +mamba_head_dim: 64 +mamba_num_groups: 8 +mamba_num_heads: null +mamba_expand: 2 +mamba_d_conv: 4 +disable_mamba_mem_eff_path: false + +# Hybrid model configuration +is_hybrid_model: false # bool +hybrid_attention_ratio: 0.0 # float range [0.0, 1.0] +hybrid_mlp_ratio: 0.0 # float range [0.0, 1.0] +hybrid_override_pattern: null # str + +# MTP +mtp_num_layers: null # int +mtp_loss_scaling_factor: 0.1 # float diff --git a/primus/configs/models/megatron/mamba_370M.yaml b/primus/configs/models/megatron/mamba_370M.yaml new file mode 100644 index 000000000..6665da3af --- /dev/null +++ b/primus/configs/models/megatron/mamba_370M.yaml @@ -0,0 +1,16 @@ +bases: + - mamba_base.yaml + +# Mamba 370M configuration +model_type: mamba # CRITICAL: Mamba models must use mamba model type +tokenizer_type: GPT2BPETokenizer +vocab_size: 50257 + +# Model size parameters +num_layers: 48 +hidden_size: 1024 +num_attention_heads: 16 # Required by Megatron validation, even for pure Mamba models +ffn_hidden_size: null +mamba_state_dim: 16 +mamba_head_dim: 64 +mamba_num_groups:
8 diff --git a/primus/configs/models/megatron/mamba_base.yaml b/primus/configs/models/megatron/mamba_base.yaml new file mode 100644 index 000000000..d52fe6db2 --- /dev/null +++ b/primus/configs/models/megatron/mamba_base.yaml @@ -0,0 +1,36 @@ +bases: + - language_model.yaml + +# Mamba-specific configuration +# Note: Mamba-specific parameters (spec, is_hybrid_model, mamba_state_dim, etc.) +# must be set in the pretrain config overrides, not here + +model_type: mamba +use_legacy_models: false + +# Position embeddings - Mamba typically doesn't use position embeddings +position_embedding_type: rope +use_rotary_position_embeddings: false + +# Tokenizer (should be set in specific model configs) +tokenizer_type: HuggingFaceTokenizer +tokenizer_model: null + +# Standard transformer settings that may be used by hybrid models +is_hybrid_model: false +attention_dropout: 0.0 +hidden_dropout: 0.0 + +# Embeddings +untie_embeddings_and_output_weights: false + +# Other settings +apply_residual_connection_post_layernorm: false +add_bias_linear: false +swiglu: false + +# Normalization +norm_epsilon: 1.0e-5 + +# Initialization +init_method_std: 0.02 diff --git a/primus/configs/models/megatron/zebra_llama_1B.yaml b/primus/configs/models/megatron/zebra_llama_1B.yaml new file mode 100644 index 000000000..d1afa9531 --- /dev/null +++ b/primus/configs/models/megatron/zebra_llama_1B.yaml @@ -0,0 +1,42 @@ +bases: + - mamba_base.yaml + +# Zebra Llama 1B configuration +model_type: mamba # CRITICAL: Hybrid models must use mamba model type +tokenizer_type: HuggingFaceTokenizer +tokenizer_model: meta-llama/Llama-3.2-1B + +# Model size parameters +num_layers: 32 +hidden_size: 2048 +ffn_hidden_size: 8192 + +# Mamba parameters +is_hybrid_model: true +hybrid_attention_ratio: 0.25 +mamba_state_dim: 64 +mamba_head_dim: 64 +mamba_num_groups: 8 + +# MLA parameters +# Disable standard GQA - MLA uses its own compression via LoRA +group_query_attention: false +swiglu: true +num_query_groups: null
+multi_latent_attention: true +num_attention_heads: 32 +q_lora_rank: 1344 # Query LoRA rank +kv_lora_rank: 128 # Key-Value LoRA rank +qk_head_dim: 32 # Query-Key head dimension +qk_pos_emb_head_dim: 32 # Positional embedding head dimension +v_head_dim: 64 # Value head dimension +rotary_scaling_factor: 1.0 +mscale: 1.0 +mscale_all_dim: 1.0 + +# MLA uses its own internal positional encoding +rotary_base: 500000 +position_embedding_type: none +add_position_embedding: true +use_rotary_position_embeddings: false +max_position_embeddings: 131072 diff --git a/primus/configs/models/megatron/zebra_llama_3B.yaml b/primus/configs/models/megatron/zebra_llama_3B.yaml new file mode 100644 index 000000000..23090841f --- /dev/null +++ b/primus/configs/models/megatron/zebra_llama_3B.yaml @@ -0,0 +1,43 @@ +bases: + - mamba_base.yaml + +# Zebra Llama 3B configuration +model_type: mamba # CRITICAL: Hybrid models must use mamba model type +tokenizer_type: HuggingFaceTokenizer +tokenizer_model: meta-llama/Llama-3.2-1B + +# Model size parameters +num_layers: 56 +hidden_size: 3072 +ffn_hidden_size: 8192 +normalization: "RMSNorm" + +# Mamba parameters +is_hybrid_model: true +hybrid_attention_ratio: 0.25 +mamba_state_dim: 128 +mamba_head_dim: 128 +mamba_num_groups: 8 + +# MLA parameters +# Disable standard GQA - MLA uses its own compression via LoRA +group_query_attention: false +swiglu: true +num_query_groups: null +multi_latent_attention: true +num_attention_heads: 24 +q_lora_rank: 1536 # Query LoRA rank +kv_lora_rank: 128 # Key-Value LoRA rank +qk_head_dim: 64 # Query-Key head dimension +qk_pos_emb_head_dim: 64 # Positional embedding head dimension +v_head_dim: 128 # Value head dimension +rotary_scaling_factor: 1.0 +mscale: 1.0 +mscale_all_dim: 1.0 + +# MLA uses its own internal positional encoding +rotary_base: 500000 +position_embedding_type: none +add_position_embedding: true +use_rotary_position_embeddings: false +original_max_position_embeddings: 131072 diff --git
a/primus/configs/models/megatron/zebra_llama_8B.yaml b/primus/configs/models/megatron/zebra_llama_8B.yaml new file mode 100644 index 000000000..0237a652d --- /dev/null +++ b/primus/configs/models/megatron/zebra_llama_8B.yaml @@ -0,0 +1,43 @@ +bases: + - mamba_base.yaml + +# Zebra Llama 8B configuration +model_type: mamba # CRITICAL: Hybrid models must use mamba model type +tokenizer_type: HuggingFaceTokenizer +tokenizer_model: meta-llama/Llama-3.2-1B + +# Model size parameters +num_layers: 64 +hidden_size: 4096 +ffn_hidden_size: 14436 +normalization: "RMSNorm" + +# Mamba parameters +is_hybrid_model: true +hybrid_attention_ratio: 0.25 +mamba_state_dim: 128 +mamba_head_dim: 128 +mamba_num_groups: 8 + +# MLA parameters +# Disable standard GQA - MLA uses its own compression via LoRA +group_query_attention: false +swiglu: true +num_query_groups: null +multi_latent_attention: true +num_attention_heads: 32 +q_lora_rank: 2048 # Query LoRA rank +kv_lora_rank: 160 # Key-Value LoRA rank +qk_head_dim: 64 # Query-Key head dimension +qk_pos_emb_head_dim: 64 # Positional embedding head dimension +v_head_dim: 128 # Value head dimension +rotary_scaling_factor: 1.0 +mscale: 1.0 +mscale_all_dim: 1.0 + +# MLA uses its own internal positional encoding +rotary_base: 500000 +position_embedding_type: none +add_position_embedding: true +use_rotary_position_embeddings: false +original_max_position_embeddings: 131072 diff --git a/primus/configs/modules/megatron/trainer_base.yaml b/primus/configs/modules/megatron/trainer_base.yaml index 76be0c7ba..ba89bd7d1 100755 --- a/primus/configs/modules/megatron/trainer_base.yaml +++ b/primus/configs/modules/megatron/trainer_base.yaml @@ -305,17 +305,17 @@ rerun_mode: disabled # str: 'disabled', 'validate_results', 'report_stats' # Experimental features enable_experimental: false -# Hybrid model configuration -hybrid_attention_ratio: 0.0 # float range [0,0, 1.0] -hybrid_mlp_ratio: 0.0 # float range [0,0, 1.0] -hybrid_override_pattern: null # str - -# 
Mamba layer configuration -mamba_state_dim: 128 -mamba_head_dim: 64 -mamba_num_groups: 8 -mamba_num_heads: null -disable_mamba_mem_eff_path: false +# # Hybrid model configuration +# hybrid_attention_ratio: 0.0 # float range [0,0, 1.0] +# hybrid_mlp_ratio: 0.0 # float range [0,0, 1.0] +# hybrid_override_pattern: null # str + +# # Mamba layer configuration +# mamba_state_dim: 128 +# mamba_head_dim: 64 +# mamba_num_groups: 8 +# mamba_num_heads: null +# disable_mamba_mem_eff_path: false # Args of precision-aware optimizer use_precision_aware_optimizer: false @@ -405,7 +405,6 @@ indexer_log_interval: 1000 enable_ft_package: false calc_ft_timeouts: false run_workload_inspector_server: false -is_hybrid_model: false heterogeneous_layers_config_path: null heterogeneous_layers_config_encoded_json: null diff --git a/primus/core/utils/import_utils.py b/primus/core/utils/import_utils.py index 2ccd8ebed..34c2de16d 100644 --- a/primus/core/utils/import_utils.py +++ b/primus/core/utils/import_utils.py @@ -34,25 +34,42 @@ def lazy_import(paths, symbol, log_prefix="[Primus]"): raise ImportError(f"{log_prefix} {symbol} not found in any of: {paths}") -def get_model_provider(): +def get_model_provider(model_type="gpt"): """ - Resolve model_provider across Megatron versions. + Resolve model_provider across Megatron versions and model types. - - New: model_provider + gpt_builder + Args: + model_type (str): Type of model - 'gpt' or 'mamba'. Defaults to 'gpt'. 
+ + - New: model_provider + gpt_builder/mamba_builder - Mid: model_provider only - - Old: pretrain_gpt.model_provider + - Old: pretrain_gpt.model_provider / pretrain_mamba.model_provider """ # Try to import model_provider - model_provider = lazy_import( - ["model_provider", "pretrain_gpt"], "model_provider", log_prefix="[Primus][MegatronCompat]" - ) + if model_type == "mamba": + model_provider = lazy_import( + ["model_provider", "pretrain_mamba"], "model_provider", log_prefix="[Primus][MegatronCompat]" + ) + # Try to import mamba_builder (for Mamba models) + try: + mamba_builder = lazy_import( + ["mamba_builders"], "mamba_builder", log_prefix="[Primus][MegatronCompat]" + ) + return partial(model_provider, mamba_builder) + except ImportError: + return model_provider + else: + # Default GPT behavior + model_provider = lazy_import( + ["model_provider", "pretrain_gpt"], "model_provider", log_prefix="[Primus][MegatronCompat]" + ) - # Try to import gpt_builder (only exists in newer versions) - try: - gpt_builder = lazy_import(["gpt_builders"], "gpt_builder", log_prefix="[Primus][MegatronCompat]") - return partial(model_provider, gpt_builder) - except ImportError: - return model_provider + # Try to import gpt_builder (only exists in newer versions) + try: + gpt_builder = lazy_import(["gpt_builders"], "gpt_builder", log_prefix="[Primus][MegatronCompat]") + return partial(model_provider, gpt_builder) + except ImportError: + return model_provider def get_custom_fsdp(): diff --git a/primus/modules/trainer/lightmegatron/pre_trainer.py b/primus/modules/trainer/lightmegatron/pre_trainer.py index 0a12460fa..eafa75d5a 100644 --- a/primus/modules/trainer/lightmegatron/pre_trainer.py +++ b/primus/modules/trainer/lightmegatron/pre_trainer.py @@ -37,16 +37,38 @@ def run(self, *args, **kwargs): log_rank_0("run light-megatron") from megatron.core.enums import ModelType - from megatron.training import inprocess_restart, pretrain - from pretrain_gpt import forward_step, 
train_valid_test_datasets_provider - - train_valid_test_datasets_provider.is_distributed = True - wrapped_pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) - - wrapped_pretrain( - train_valid_test_datasets_provider, - get_model_provider(), - ModelType.encoder_or_decoder, - forward_step, - store=store, - ) + from megatron.training import get_args, inprocess_restart, pretrain + + # Determine model type from config (gpt or mamba) + megatron_args = get_args() + model_type = getattr(megatron_args, "model_type", "gpt") + log_rank_0(f"Detected model_type: {model_type}") + + if model_type == "mamba": + # Import from pretrain_mamba + from pretrain_mamba import forward_step, train_valid_test_datasets_provider + + train_valid_test_datasets_provider.is_distributed = True + wrapped_pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) + + wrapped_pretrain( + train_valid_test_datasets_provider, + get_model_provider(model_type="mamba"), + ModelType.encoder_or_decoder, + forward_step, + store=store, + ) + else: + # Default to GPT + from pretrain_gpt import forward_step, train_valid_test_datasets_provider + + train_valid_test_datasets_provider.is_distributed = True + wrapped_pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) + + wrapped_pretrain( + train_valid_test_datasets_provider, + get_model_provider(model_type="gpt"), + ModelType.encoder_or_decoder, + forward_step, + store=store, + ) diff --git a/primus/modules/trainer/megatron/pre_trainer.py b/primus/modules/trainer/megatron/pre_trainer.py index 1c8539dd5..d34788b04 100644 --- a/primus/modules/trainer/megatron/pre_trainer.py +++ b/primus/modules/trainer/megatron/pre_trainer.py @@ -242,6 +242,20 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=Fals assert ( args.overlap_moe_expert_parallel_comm ), "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan" + + # Schedule plan building is only 
supported for GPT models + # Check if this is a Mamba model + unwrapped_model = model + while hasattr(unwrapped_model, "module"): + unwrapped_model = unwrapped_model.module + model_class_name = unwrapped_model.__class__.__name__ + + if "Mamba" in model_class_name: + raise NotImplementedError( + "Schedule plan building is not supported for Mamba models. " + "Please disable overlap_moe_expert_parallel_comm for Mamba." + ) + if args.patch_moe_overlap: assert ( not args.delay_wgrad_compute @@ -267,8 +281,21 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=Fals ) return schedule_plan, partial(self.loss_func, loss_mask) else: - output_tensor = model( - tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask - ) + # Check if model supports loss_mask parameter + # MambaModel doesn't accept loss_mask, but GPTModel does + # Unwrap the model to get the actual model class + unwrapped_model = model + while hasattr(unwrapped_model, "module"): + unwrapped_model = unwrapped_model.module + model_class_name = unwrapped_model.__class__.__name__ + + if "Mamba" in model_class_name: + # MambaModel doesn't accept loss_mask parameter + output_tensor = model(tokens, position_ids, attention_mask, labels=labels) + else: + # GPTModel and other models accept loss_mask parameter + output_tensor = model( + tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask + ) return output_tensor, partial(self.loss_func, loss_mask) diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index 5db59188b..8f0c9f70e 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -483,8 +483,10 @@ def update_primus_config( if args.iterations_to_skip is None: args.iterations_to_skip = [] - # support moe_freq_type - if isinstance(args.moe_layer_freq, str): + # support moe_freq_type - ensure moe_layer_freq has a default value + if not hasattr(args, 
"moe_layer_freq"): + args.moe_layer_freq = 1 + elif isinstance(args.moe_layer_freq, str): try: args.moe_layer_freq = eval(args.moe_layer_freq) except Exception: @@ -496,11 +498,35 @@ def update_primus_config( args.valid_data_path = None args.test_data_path = None + # Determine model type (gpt or mamba) + model_type = getattr(args, "model_type", "gpt") + log_rank_0(f"-detected model_type: {model_type}") + + # Ensure required attributes have safe defaults if missing from config + if not hasattr(args, "final_logit_softcapping"): + args.final_logit_softcapping = None + if not hasattr(args, "router_logit_softcapping"): + args.router_logit_softcapping = None + + # Only pass model_type parameter when it's "mamba" to maintain backward compatibility + # with main branch behavior for "gpt" (default) case if args.final_logit_softcapping is not None and args.final_logit_softcapping > 0.0: log_rank_0(f"-enable final_logit_softcapping: {args.final_logit_softcapping}") - self.model_provider = functools.partial(primus_model_provider, get_model_provider()) + if model_type == "mamba": + self.model_provider = functools.partial( + primus_model_provider, get_model_provider(model_type=model_type) + ) + else: + self.model_provider = functools.partial(primus_model_provider, get_model_provider()) else: - self.model_provider = get_model_provider() + if model_type == "mamba": + log_rank_0(f"-getting model provider for model_type={model_type}") + model_provider = get_model_provider(model_type=model_type) + log_rank_0(f"-model_provider: {model_provider}") + self.model_provider = model_provider + else: + # For "gpt" (default), call without arguments to match main branch behavior + self.model_provider = get_model_provider() if args.router_logit_softcapping is not None and args.router_logit_softcapping > 0.0: log_rank_0(f"-enable router_logit_softcapping: {args.router_logit_softcapping}") @@ -867,6 +893,8 @@ def setup_model_and_optimizer( log_rank_0(f"use te backend...") log_rank_0(f"-run 
get_model") + log_rank_0(f"-model_provider_func: {model_provider_func}") + log_rank_0(f"-model_type: {model_type}") model = get_model(model_provider_func, model_type) log_rank_0(model) # get_megatron_optimizer will use the ddp_config diff --git a/third_party/mamba b/third_party/mamba new file mode 160000 index 000000000..4d67c534d --- /dev/null +++ b/third_party/mamba @@ -0,0 +1 @@ +Subproject commit 4d67c534d10b7299194ede55b36718cc7a1dd472