From 3de16a09474a3cb2a1eb1f2a0bf28f7c2e220857 Mon Sep 17 00:00:00 2001
From: alfuyao1986 <81382865+alfuyao1986@users.noreply.github.com>
Date: Thu, 2 Oct 2025 00:40:31 -0700
Subject: [PATCH 01/32] Disable torchtitan activation checkpointing for better
 8B model performance

---
 examples/torchtitan/configs/llama3.1_8B-BF16-pretrain.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/torchtitan/configs/llama3.1_8B-BF16-pretrain.yaml b/examples/torchtitan/configs/llama3.1_8B-BF16-pretrain.yaml
index a879acf86..0959bed8f 100644
--- a/examples/torchtitan/configs/llama3.1_8B-BF16-pretrain.yaml
+++ b/examples/torchtitan/configs/llama3.1_8B-BF16-pretrain.yaml
@@ -24,12 +24,12 @@ modules:
       warmup_steps: 10

     training:
-      batch_size: 19
+      batch_size: 3
       compile: true
       steps: 50

     activation_checkpoint:
-      mode: "full" # ["none", "selective", "full"]
+      mode: "none" # ["none", "selective", "full"]
       selective_ac_option: "op" # "int" = ac every positive int layer or 'op', ac based on ops policy

     primus_turbo:

From 7d6d52b50bad529ef6ad3e86a42e97ef69d60b96 Mon Sep 17 00:00:00 2001
From: alfuyao1986 <81382865+alfuyao1986@users.noreply.github.com>
Date: Thu, 2 Oct 2025 00:41:02 -0700
Subject: [PATCH 02/32] Disable torchtitan activation checkpointing for better
 8B model performance

---
 examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml b/examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml
index d0cf4dfdd..3a6bd91da 100644
--- a/examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml
+++ b/examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml
@@ -23,12 +23,12 @@ modules:
       log_freq: 1

     training:
-      batch_size: 19
+      batch_size: 8
       compile: true
       steps: 50

     activation_checkpoint:
-      mode: "full" # ["none", "selective", "full"]
+      mode: "none" # ["none", "selective", "full"]
       selective_ac_option: "op" # "int" = ac every positive int layer or 'op', ac based on ops policy

     float8:

From 18cd0bea7f35fbd251fba46da3d1a4baa52ce6f9 Mon Sep 17 00:00:00 2001
From: alfuyao1986 <81382865+alfuyao1986@users.noreply.github.com>
Date: Thu, 2 Oct 2025 00:42:36 -0700
Subject: [PATCH 03/32] Set default for best MI300 performance

---
 examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml b/examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml
index 3a6bd91da..03f409781 100644
--- a/examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml
+++ b/examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml
@@ -23,7 +23,7 @@ modules:
       log_freq: 1

     training:
-      batch_size: 8
+      batch_size: 4
       compile: true
       steps: 50

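A quick way to read the batch-size changes in patches 01-03 is in tokens per optimizer step. The sketch below is illustrative arithmetic only, assuming the 8192-token sequence length that later patches in this series set explicitly; the batch sizes come from the diffs above.

# Back-of-envelope reading of the activation-checkpointing trade-off:
# full checkpointing frees memory (batch_size 19 fits) but recomputes the
# forward pass during backward; with mode "none" only smaller batches fit
# (3 for BF16, 8 then 4 for FP8), yet each step avoids the recompute cost.
SEQ_LEN = 8192  # assumption: matches the seq_len these configs set later

def tokens_per_step(local_batch_size: int, seq_len: int = SEQ_LEN) -> int:
    """Tokens processed per optimizer step on one rank."""
    return local_batch_size * seq_len

for label, bs in [("full AC, BF16", 19), ("no AC, BF16", 3), ("no AC, FP8", 4)]:
    print(f"{label}: {tokens_per_step(bs):,} tokens/step")

# Throughput is tokens/step divided by step time; with full checkpointing the
# backward pass roughly repeats the forward, so the larger batch does not
# automatically mean more tokens/sec.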
From 6075971a14e14029dcfe2d77c2c584ee076e0aad Mon Sep 17 00:00:00 2001
From: alfuyao1986 <81382865+alfuyao1986@users.noreply.github.com>
Date: Thu, 2 Oct 2025 00:54:56 -0700
Subject: [PATCH 04/32] Update llama3.1_70B-BF16-pretrain.yaml

Lower the log frequency to reduce logging overhead.
---
 examples/torchtitan/configs/llama3.1_70B-BF16-pretrain.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/torchtitan/configs/llama3.1_70B-BF16-pretrain.yaml b/examples/torchtitan/configs/llama3.1_70B-BF16-pretrain.yaml
index 081e999dc..7b0fd9118 100644
--- a/examples/torchtitan/configs/llama3.1_70B-BF16-pretrain.yaml
+++ b/examples/torchtitan/configs/llama3.1_70B-BF16-pretrain.yaml
@@ -20,7 +20,7 @@ modules:
       warmup_steps: 10

     metrics:
-      log_freq: 1
+      log_freq: 10

     training:
       batch_size: 4

From 464d1ae9be701f4cacdc00b822304f450bb6cec3 Mon Sep 17 00:00:00 2001
From: alfuyao1986 <81382865+alfuyao1986@users.noreply.github.com>
Date: Thu, 2 Oct 2025 00:55:11 -0700
Subject: [PATCH 05/32] Update llama3.1_70B-FP8-pretrain.yaml

Lower the log frequency to reduce logging overhead.
---
 examples/torchtitan/configs/llama3.1_70B-FP8-pretrain.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/torchtitan/configs/llama3.1_70B-FP8-pretrain.yaml b/examples/torchtitan/configs/llama3.1_70B-FP8-pretrain.yaml
index c47508dab..a82e66535 100644
--- a/examples/torchtitan/configs/llama3.1_70B-FP8-pretrain.yaml
+++ b/examples/torchtitan/configs/llama3.1_70B-FP8-pretrain.yaml
@@ -20,7 +20,7 @@ modules:
       warmup_steps: 10

     metrics:
-      log_freq: 1
+      log_freq: 10

     training:
       batch_size: 3

From 05ed2ffb5012b8b2b8f2210e72a3dceec3c2782b Mon Sep 17 00:00:00 2001
From: alfuyao1986 <81382865+alfuyao1986@users.noreply.github.com>
Date: Thu, 2 Oct 2025 00:55:35 -0700
Subject: [PATCH 06/32] Update llama3.1_8B-BF16-pretrain.yaml

Lower the log frequency to reduce logging overhead.
---
 examples/torchtitan/configs/llama3.1_8B-BF16-pretrain.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/torchtitan/configs/llama3.1_8B-BF16-pretrain.yaml b/examples/torchtitan/configs/llama3.1_8B-BF16-pretrain.yaml
index 0959bed8f..66f5a2553 100644
--- a/examples/torchtitan/configs/llama3.1_8B-BF16-pretrain.yaml
+++ b/examples/torchtitan/configs/llama3.1_8B-BF16-pretrain.yaml
@@ -16,7 +16,7 @@ modules:
       stderr_sink_level: INFO

     metrics:
-      log_freq: 1
+      log_freq: 10
       enable_wandb: false

     lr_scheduler:

From aab4234b6d19c98c3c7f5f0337069c2e1462dc36 Mon Sep 17 00:00:00 2001
From: alfuyao1986 <81382865+alfuyao1986@users.noreply.github.com>
Date: Thu, 2 Oct 2025 00:55:55 -0700
Subject: [PATCH 07/32] Update llama3.1_8B-FP8-pretrain.yaml

Lower the log frequency to reduce logging overhead.
---
 examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml b/examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml
index 03f409781..f694fd54e 100644
--- a/examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml
+++ b/examples/torchtitan/configs/llama3.1_8B-FP8-pretrain.yaml
@@ -20,7 +20,7 @@ modules:
       warmup_steps: 10

     metrics:
-      log_freq: 1
+      log_freq: 10

     training:
       batch_size: 4

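Patches 04-07 raise log_freq from 1 to 10, so metrics are emitted every tenth step instead of every step. A minimal sketch of the step-gating pattern such a knob typically controls (an assumed pattern for illustration, not TorchTitan's actual logging code):

import time

def train(steps: int = 50, log_freq: int = 10) -> None:
    for step in range(1, steps + 1):
        # ... forward/backward/optimizer work would run here ...
        if step % log_freq == 0:
            # Metric collection often forces a host/device sync plus I/O;
            # gating it on log_freq amortizes that cost across steps.
            print(f"[{time.strftime('%H:%M:%S')}] step {step}: metrics logged")

train()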
From d1ec7870f68ebdb827eff52bfa47bd38954568fd Mon Sep 17 00:00:00 2001
From: clairesonglee
Date: Mon, 6 Oct 2025 21:39:09 +0000
Subject: [PATCH 08/32] tw script update

---
 examples/run_slurm_pretrain.sh | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/examples/run_slurm_pretrain.sh b/examples/run_slurm_pretrain.sh
index fd2f6068e..59b42cd0f 100755
--- a/examples/run_slurm_pretrain.sh
+++ b/examples/run_slurm_pretrain.sh
@@ -40,7 +40,10 @@ mkdir -p "$LOG_DIR"
 srun -N "${NNODES}" \
     --exclusive \
     --ntasks-per-node=1 \
-    --cpus-per-task="${CPUS_PER_TASK:-256}" \
+    --cpus-per-task="${CPUS_PER_TASK:-128}" \
+    --partition=AIG_Models \
+    --gres=gpu:8 \
+    -t 03:00:00 \
     bash -c "
         readarray -t node_array < <(scontrol show hostnames \"\$SLURM_JOB_NODELIST\")
         if [ \"\$SLURM_NODEID\" = \"0\" ]; then
@@ -68,5 +71,7 @@
         export TORCHTITAN_PATH=\${TORCHTITAN_PATH}
         export BACKEND_PATH=\${BACKEND_PATH}
         export PATH_TO_BNXT_TAR_PACKAGE=\${PATH_TO_BNXT_TAR_PACKAGE}
+        docker ps
+        rocm-smi
         bash ${SCRIPT_DIR}/run_local_pretrain.sh \"\$@\" 2>&1 | tee ${LOG_FILE}
     " bash "$@"

From 5fcbe8562d92ef497886b386212d95fa863a045f Mon Sep 17 00:00:00 2001
From: clairesonglee
Date: Wed, 8 Oct 2025 14:41:46 -0700
Subject: [PATCH 09/32] remove cluster-specific commands

---
 examples/run_slurm_pretrain.sh | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/examples/run_slurm_pretrain.sh b/examples/run_slurm_pretrain.sh
index 59b42cd0f..fd2f6068e 100755
--- a/examples/run_slurm_pretrain.sh
+++ b/examples/run_slurm_pretrain.sh
@@ -40,10 +40,7 @@ mkdir -p "$LOG_DIR"
 srun -N "${NNODES}" \
     --exclusive \
     --ntasks-per-node=1 \
-    --cpus-per-task="${CPUS_PER_TASK:-128}" \
-    --partition=AIG_Models \
-    --gres=gpu:8 \
-    -t 03:00:00 \
+    --cpus-per-task="${CPUS_PER_TASK:-256}" \
     bash -c "
         readarray -t node_array < <(scontrol show hostnames \"\$SLURM_JOB_NODELIST\")
         if [ \"\$SLURM_NODEID\" = \"0\" ]; then
@@ -71,7 +68,5 @@
         export TORCHTITAN_PATH=\${TORCHTITAN_PATH}
         export BACKEND_PATH=\${BACKEND_PATH}
         export PATH_TO_BNXT_TAR_PACKAGE=\${PATH_TO_BNXT_TAR_PACKAGE}
-        docker ps
-        rocm-smi
         bash ${SCRIPT_DIR}/run_local_pretrain.sh \"\$@\" 2>&1 | tee ${LOG_FILE}
     " bash "$@"

From e16b27bf6c1b2798f38848fc574fee60d9a9b902 Mon Sep 17 00:00:00 2001
From: vidushi8
Date: Thu, 9 Oct 2025 11:46:36 -0700
Subject: [PATCH 10/32] update common perf arguments

- ce fusion
- moe gemms
---
 examples/megatron/configs/deepseek_v2_lite-pretrain.yaml  | 8 ++++++--
 examples/megatron/configs/deepseek_v3-pretrain.yaml       | 6 +++++-
 examples/megatron/configs/llama2_70B-pretrain.yaml        | 4 ++++
 examples/megatron/configs/llama2_7B-pretrain.yaml         | 4 ++++
 examples/megatron/configs/llama3.1_70B-pretrain.yaml      | 4 ++++
 examples/megatron/configs/llama3.1_8B-pretrain.yaml       | 4 ++++
 examples/megatron/configs/llama3.3_70B-pretrain.yaml      | 4 ++++
 examples/megatron/configs/llama3_70B-pretrain.yaml        | 4 ++++
 examples/megatron/configs/llama3_8B-pretrain.yaml         | 4 ++++
 .../megatron/configs/mixtral_8x22B_v0.1-pretrain.yaml     | 6 +++++-
 examples/megatron/configs/mixtral_8x7B_v0.1-pretrain.yaml | 6 +++++-
 examples/megatron/configs/qwen2.5_72B-pretrain.yaml       | 4 ++++
 examples/megatron/configs/qwen2.5_7B-pretrain.yaml        | 4 ++++
 13 files changed, 57 insertions(+), 5 deletions(-)

diff --git a/examples/megatron/configs/deepseek_v2_lite-pretrain.yaml b/examples/megatron/configs/deepseek_v2_lite-pretrain.yaml
index a4eb04866..4318cdfbf 100644
--- a/examples/megatron/configs/deepseek_v2_lite-pretrain.yaml
+++ b/examples/megatron/configs/deepseek_v2_lite-pretrain.yaml
@@ -57,7 +57,7 @@ modules:
       # fused wgrad gemm and accumulation
       gradient_accumulation_fusion: false
       # recommend set `false` in fp8
-      moe_use_legacy_grouped_gemm: true
+      moe_use_legacy_grouped_gemm: false
       # fused topk router with aux score
       moe_use_fused_router_with_aux_score: false
       # pad 192/128 for deepseek attention
@@ -82,4 +82,8 @@ modules:
       # Turbo
       enable_primus_turbo: true
       use_turbo_attention: true
-      use_turbo_grouped_mlp: true
\ No newline at end of file
+      use_turbo_grouped_mlp: true
+
+      # Cross entropy flags
+      cross_entropy_fusion_impl: "te"
+      cross_entropy_loss_fusion: true
\ No newline at end of file

diff --git a/examples/megatron/configs/deepseek_v3-pretrain.yaml b/examples/megatron/configs/deepseek_v3-pretrain.yaml
index c203c6791..76ae80972 100644
--- a/examples/megatron/configs/deepseek_v3-pretrain.yaml
+++ b/examples/megatron/configs/deepseek_v3-pretrain.yaml
@@ -57,7 +57,7 @@ modules:
       # fused wgrad gemm and accumulation
       gradient_accumulation_fusion: false
       # recommend set `false` in fp8
-      moe_use_legacy_grouped_gemm: true
+      moe_use_legacy_grouped_gemm: false
      # fused topk router with aux score
       moe_use_fused_router_with_aux_score: false
       # pad 192/128 for deepseek attention
@@ -85,3 +85,7 @@ modules:
       enable_primus_turbo: true
       use_turbo_attention: true
       use_turbo_grouped_mlp: true
+
+      # Cross entropy flags
+      cross_entropy_fusion_impl: "te"
+      cross_entropy_loss_fusion: true

diff --git a/examples/megatron/configs/llama2_70B-pretrain.yaml b/examples/megatron/configs/llama2_70B-pretrain.yaml
index 075816231..8a810d9bd 100755
--- a/examples/megatron/configs/llama2_70B-pretrain.yaml
+++ b/examples/megatron/configs/llama2_70B-pretrain.yaml
@@ -75,3 +75,7 @@ modules:
       enable_primus_turbo: true
       use_turbo_attention: true
       use_turbo_grouped_mlp: true
+
+      # Cross entropy flags
+      cross_entropy_fusion_impl: "te"
+      cross_entropy_loss_fusion: true
\ No newline at end of file

diff --git a/examples/megatron/configs/llama2_7B-pretrain.yaml b/examples/megatron/configs/llama2_7B-pretrain.yaml
index 8e6f0d784..5ad882bb1 100755
--- a/examples/megatron/configs/llama2_7B-pretrain.yaml
+++ b/examples/megatron/configs/llama2_7B-pretrain.yaml
@@ -78,3 +78,7 @@ modules:
       #   overlap_param_gather: false
       #   ckpt_format: torch
       #   sequence_parallel: 1
+
+      # Cross entropy flags
+      cross_entropy_fusion_impl: "te"
+      cross_entropy_loss_fusion: true

diff --git a/examples/megatron/configs/llama3.1_70B-pretrain.yaml b/examples/megatron/configs/llama3.1_70B-pretrain.yaml
index ed7c9f202..593e9b9da 100644
--- a/examples/megatron/configs/llama3.1_70B-pretrain.yaml
+++ b/examples/megatron/configs/llama3.1_70B-pretrain.yaml
@@ -74,3 +74,7 @@ modules:
       # Turbo
       enable_primus_turbo: true
       use_turbo_attention: true
+
+      # Cross entropy flags
+      cross_entropy_fusion_impl: "te"
+      cross_entropy_loss_fusion: true
\ No newline at end of file

diff --git a/examples/megatron/configs/llama3.1_8B-pretrain.yaml b/examples/megatron/configs/llama3.1_8B-pretrain.yaml
index 68821f9e5..dcd831f3d 100644
--- a/examples/megatron/configs/llama3.1_8B-pretrain.yaml
+++ b/examples/megatron/configs/llama3.1_8B-pretrain.yaml
@@ -65,3 +65,7 @@ modules:
       no_save_rng: null
       disable_last_saving: true
       ckpt_format: torch
+
+      # Cross entropy flags
+      cross_entropy_fusion_impl: "te"
+      cross_entropy_loss_fusion: true
\ No newline at end of file

diff --git a/examples/megatron/configs/llama3.3_70B-pretrain.yaml b/examples/megatron/configs/llama3.3_70B-pretrain.yaml
index 5a6f98c28..0e1778c16 100644
--- a/examples/megatron/configs/llama3.3_70B-pretrain.yaml
+++ b/examples/megatron/configs/llama3.3_70B-pretrain.yaml
@@ -75,3 +75,7 @@ modules:
       enable_primus_turbo: true
       use_turbo_attention: true
       use_turbo_grouped_mlp: true
+
+      # Cross entropy flags
+      cross_entropy_fusion_impl: "te"
+      cross_entropy_loss_fusion: true
\ No newline at end of file

diff --git a/examples/megatron/configs/llama3_70B-pretrain.yaml b/examples/megatron/configs/llama3_70B-pretrain.yaml
index 97c29b2dc..fa10e17b9 100755
--- a/examples/megatron/configs/llama3_70B-pretrain.yaml
+++ b/examples/megatron/configs/llama3_70B-pretrain.yaml
@@ -75,3 +75,7 @@ modules:
       enable_primus_turbo: true
       use_turbo_attention: true
       use_turbo_grouped_mlp: true
+
+      # Cross entropy flags
+      cross_entropy_fusion_impl: "te"
+      cross_entropy_loss_fusion: true
\ No newline at end of file

diff --git a/examples/megatron/configs/llama3_8B-pretrain.yaml b/examples/megatron/configs/llama3_8B-pretrain.yaml
index 21168120e..45a0a0de0 100644
--- a/examples/megatron/configs/llama3_8B-pretrain.yaml
+++ b/examples/megatron/configs/llama3_8B-pretrain.yaml
@@ -71,3 +71,7 @@ modules:
       enable_primus_turbo: true
       use_turbo_attention: true
       use_turbo_grouped_mlp: true
+
+      # Cross entropy flags
+      cross_entropy_fusion_impl: "te"
+      cross_entropy_loss_fusion: true
\ No newline at end of file

diff --git a/examples/megatron/configs/mixtral_8x22B_v0.1-pretrain.yaml b/examples/megatron/configs/mixtral_8x22B_v0.1-pretrain.yaml
index 9457394b6..d567d821e 100644
--- a/examples/megatron/configs/mixtral_8x22B_v0.1-pretrain.yaml
+++ b/examples/megatron/configs/mixtral_8x22B_v0.1-pretrain.yaml
@@ -59,7 +59,7 @@ modules:
       # fusion
       moe_permute_fusion: false
       gradient_accumulation_fusion: false
-      moe_use_legacy_grouped_gemm: true
+      moe_use_legacy_grouped_gemm: false

       # ckpt
       finetune: false
@@ -73,3 +73,7 @@ modules:
       no_save_rng: null
       disable_last_saving: true
       ckpt_format: torch
+
+      # Cross entropy flags
+      cross_entropy_fusion_impl: "te"
+      cross_entropy_loss_fusion: true

diff --git a/examples/megatron/configs/mixtral_8x7B_v0.1-pretrain.yaml b/examples/megatron/configs/mixtral_8x7B_v0.1-pretrain.yaml
index c118a8220..75c2dc889 100644
--- a/examples/megatron/configs/mixtral_8x7B_v0.1-pretrain.yaml
+++ b/examples/megatron/configs/mixtral_8x7B_v0.1-pretrain.yaml
@@ -54,7 +54,7 @@ modules:
       # fusion
       moe_permute_fusion: false
       gradient_accumulation_fusion: false
-      moe_use_legacy_grouped_gemm: true
+      moe_use_legacy_grouped_gemm: false

       # ckpt
       finetune: false
@@ -68,3 +68,7 @@ modules:
       no_save_rng: null
       disable_last_saving: true
       ckpt_format: torch
+
+      # Cross entropy flags
+      cross_entropy_fusion_impl: "te"
+      cross_entropy_loss_fusion: true
\ No newline at end of file

diff --git a/examples/megatron/configs/qwen2.5_72B-pretrain.yaml b/examples/megatron/configs/qwen2.5_72B-pretrain.yaml
index bfd27afc9..35f2d39f2 100644
--- a/examples/megatron/configs/qwen2.5_72B-pretrain.yaml
+++ b/examples/megatron/configs/qwen2.5_72B-pretrain.yaml
@@ -79,3 +79,7 @@ modules:
       enable_primus_turbo: true
       use_turbo_attention: true
       use_turbo_grouped_mlp: true
+
+      # Cross entropy flags
+      cross_entropy_fusion_impl: "te"
+      cross_entropy_loss_fusion: true

diff --git a/examples/megatron/configs/qwen2.5_7B-pretrain.yaml b/examples/megatron/configs/qwen2.5_7B-pretrain.yaml
index 644dd2a21..0451bd4f5 100644
--- a/examples/megatron/configs/qwen2.5_7B-pretrain.yaml
+++ b/examples/megatron/configs/qwen2.5_7B-pretrain.yaml
@@ -72,3 +72,7 @@ modules:
       enable_primus_turbo: true
       use_turbo_attention: true
       use_turbo_grouped_mlp: true
+
+      # Cross entropy flags
+      cross_entropy_fusion_impl: "te"
+      cross_entropy_loss_fusion: true

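Patch 10 enables fused cross-entropy (cross_entropy_loss_fusion: true) across the Megatron configs; the "te" value of cross_entropy_fusion_impl presumably selects a Transformer Engine kernel, whose internals are not shown here. The PyTorch sketch below only illustrates the fused-versus-unfused distinction the flag is about:

import torch
import torch.nn.functional as F

logits = torch.randn(256, 32000)          # [tokens, vocab], toy sizes
labels = torch.randint(0, 32000, (256,))

# Unfused: materializes a full [tokens, vocab] log-softmax tensor, then gathers.
log_probs = F.log_softmax(logits, dim=-1)
loss_unfused = F.nll_loss(log_probs, labels)

# Fused-style: a single call that can avoid the explicit intermediate.
loss_fused = F.cross_entropy(logits, labels)

assert torch.allclose(loss_unfused, loss_fused, atol=1e-5)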
From ba187e46be35a8d85c09bd082017b45038b5ec45 Mon Sep 17 00:00:00 2001
From: vidushi8
Date: Tue, 18 Nov 2025 23:51:18 -0600
Subject: [PATCH 11/32] Revert "refactor(torchtitan): rollback Titan to
 99c0cb2(20250907) and stabilize trainer UTs (#262)"

This reverts commit 1e2e1b1acb178855cf512de802d494b09c668431.
---
 .github/workflows/docker/Dockerfile                |   2 +-
 .../MI300X/deepseek_v3_16b-pretrain.yaml           |   4 -
 .../MI300X/deepseek_v3_671b-pretrain.yaml          |   2 +-
 .../torchtitan/models/deepseek_v3/__init__.py      |   0
 .../models/deepseek_v3/model/__init__.py           |   0
 .../models/deepseek_v3/model/model.py              |  63 -------
 .../torchtitan/models/llama3/model/model.py        |  15 +-
 .../primus_turbo_converter.py                      |   7 +-
 .../models/torchtitan/llama3.1_70B-fp8.yaml        |   3 +-
 .../models/torchtitan/llama3.1_8B-fp8.yaml         |   3 +-
 .../models/torchtitan/llama3.3_70B-fp8.yaml        |   3 +-
 .../models/torchtitan/llama3_70B-fp8.yaml          |   3 +-
 .../models/torchtitan/llama3_8B-fp8.yaml           |   3 +-
 .../modules/trainer/torchtitan/patch_utils.py      |  42 +----
 .../modules/trainer/torchtitan/pre_trainer.py      | 115 ++++--------
 .../trainer/torchtitan/test_patch_utils.py         |  41 +----
 tests/trainer/test_torchtitan_trainer.py           | 167 ++++++++----------
 third_party/torchtitan                             |   2 +-
 18 files changed, 144 insertions(+), 331 deletions(-)
 delete mode 100644 primus/backends/torchtitan/models/deepseek_v3/__init__.py
 delete mode 100644 primus/backends/torchtitan/models/deepseek_v3/model/__init__.py
 delete mode 100644 primus/backends/torchtitan/models/deepseek_v3/model/model.py

diff --git a/.github/workflows/docker/Dockerfile b/.github/workflows/docker/Dockerfile
index 6e108880a..0dc80e0d3 100644
--- a/.github/workflows/docker/Dockerfile
+++ b/.github/workflows/docker/Dockerfile
@@ -1,6 +1,6 @@
 # Base image
 # FROM docker.io/rocm/megatron-lm:v25.9_gfx942
-FROM docker.io/rocm/primus:v25.9_gfx942
+FROM docker.io/rocm/pyt-megatron-lm-jax-nightly-private:pytorch_rocm7.0_20251024

 # Specify the commit of Primus-Turbo when building: docker build --build-arg PRIMUS_TURBO_COMMIT=xxx .)
 ARG PRIMUS_TURBO_COMMIT

diff --git a/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml b/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml
index 1aedd9cbc..95e1a80cf 100644
--- a/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml
+++ b/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml
@@ -71,10 +71,6 @@ modules:
       enable: true
       components: ["loss"] # ["model", "loss"]

-    primus_turbo:
-      enable_primus_turbo: true
-      enable_attention_float8: false
-
     # quantize:
     #   linear:
     #     float8:

diff --git a/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml b/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml
index 84d7de558..b340538be 100644
--- a/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml
+++ b/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml
@@ -52,7 +52,7 @@ modules:
       enable_async_tensor_parallel: false
       pipeline_parallel_degree: 1
       pipeline_parallel_schedule: "Interleaved1F1B"
-      expert_parallel_degree: 8
+      expert_parallel_degree: 1
       expert_tensor_parallel_degree: 1

     checkpoint:

diff --git a/primus/backends/torchtitan/models/deepseek_v3/__init__.py b/primus/backends/torchtitan/models/deepseek_v3/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/primus/backends/torchtitan/models/deepseek_v3/model/__init__.py b/primus/backends/torchtitan/models/deepseek_v3/model/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/primus/backends/torchtitan/models/deepseek_v3/model/model.py b/primus/backends/torchtitan/models/deepseek_v3/model/model.py
deleted file mode 100644
index b1f442c23..000000000
--- a/primus/backends/torchtitan/models/deepseek_v3/model/model.py
+++ /dev/null
@@ -1,63 +0,0 @@
-###############################################################################
-# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
-#
-# See LICENSE for license information.
-###############################################################################
-
-import torch
-from torchtitan.models.deepseek_v3.model.model import Attention as TTAttention
-from torchtitan.models.deepseek_v3.model.model import apply_rotary_emb
-
-
-class Attention(TTAttention):
-    def forward(
-        self,
-        x: torch.Tensor,
-        freqs_cis: torch.Tensor,
-    ):
-        """
-        Forward pass for the Multi-Head Latent Attention (MLA) Layer.
-
-        Args:
-            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
-            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
-
-        Returns:
-            torch.Tensor: Output tensor with the same shape as the input.
-        """
-        bsz, seqlen, _ = x.size()
-
-        # Query projection
-        if self.q_lora_rank == 0:
-            q = self.wq(x)  # (bsz, seqlen, n_heads * qk_head_dim)
-        else:
-            q = self.wq_a(x)
-            q = self.wq_b(self.q_norm(q))
-        # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
-        # local heads from sizes of q and kv as TP may have sharded them after
-        # the above linear ops.
-        q = q.view(bsz, seqlen, -1, self.qk_head_dim)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        q_pe = apply_rotary_emb(q_pe, freqs_cis)
-        q = torch.cat([q_nope, q_pe], dim=-1)  # (bsz, seqlen, n_heads, qk_head_dim)
-
-        # Key-value projection
-        kv = self.wkv_a(x)  # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim)
-        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-
-        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)  # (bsz, seqlen, 1, qk_rope_head_dim)
-
-        kv = self.wkv_b(self.kv_norm(kv))  # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim))
-        kv = kv.view(bsz, seqlen, -1, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        k = torch.cat(
-            [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1
-        )  # (bsz, seqlen, n_heads, qk_head_dim)
-
-        q = q.view(bsz, seqlen, -1, self.qk_head_dim)
-        k = k.view(bsz, seqlen, -1, self.qk_head_dim)
-        v = v.view(bsz, seqlen, -1, self.v_head_dim)
-
-        output = self.sdpa(q, k, v)
-        output = output.view(bsz, seqlen, -1)
-        return self.wo(output)
diff --git a/primus/backends/torchtitan/models/llama3/model/model.py b/primus/backends/torchtitan/models/llama3/model/model.py
index 7187f1d9d..5dee6e34f 100644
--- a/primus/backends/torchtitan/models/llama3/model/model.py
+++ b/primus/backends/torchtitan/models/llama3/model/model.py
@@ -5,16 +5,20 @@
 ###############################################################################

 import torch
-
-# from torch.nn.attention.flex_attention import BlockMask
+from torch.nn.attention.flex_attention import BlockMask
 from torchtitan.models.llama3.model.model import Attention as TTAttention
 from torchtitan.models.llama3.model.model import apply_rotary_emb

-# AttentionMasksType = dict[str, BlockMask] | BlockMask
+AttentionMasksType = dict[str, BlockMask] | BlockMask


 class Attention(TTAttention):
-    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cis: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
+    ):
         bs, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

@@ -31,8 +35,7 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
         # xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
         # xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

-        # output = self.inner_attention(xq, xk, xv)
-        output = self.sdpa(xq, xk, xv)
+        output = self.inner_attention(xq, xk, xv)

         output = output.view(bs, seqlen, -1)
         return self.wo(output)

diff --git a/primus/backends/torchtitan/primus_turbo_extensions/primus_turbo_converter.py b/primus/backends/torchtitan/primus_turbo_extensions/primus_turbo_converter.py
index 70ff8a18a..991b4816f 100644
--- a/primus/backends/torchtitan/primus_turbo_extensions/primus_turbo_converter.py
+++ b/primus/backends/torchtitan/primus_turbo_extensions/primus_turbo_converter.py
@@ -7,7 +7,10 @@
 import torch
 from torchtitan.config.job_config import JobConfig
 from torchtitan.distributed import ParallelDims
-from torchtitan.models.attention import FlexAttention, ScaledDotProductAttention
+from torchtitan.models.attention import (
+    FlexAttentionWrapper,
+    ScaledDotProductAttentionWrapper,
+)
 from torchtitan.protocols.model_converter import (
     ModelConverter,
     register_model_converter,
@@ -18,7 +21,7 @@ def replace_turbo_attention_modules(model: torch.nn.Module, backend_type: str, u
     from primus_turbo.pytorch.modules import TurboAttention  # TODO: import Check

     for name, module in model.named_children():
-        if isinstance(module, (FlexAttention, ScaledDotProductAttention)):
+        if isinstance(module, (FlexAttentionWrapper, ScaledDotProductAttentionWrapper)):
             setattr(
                 model,
                 name,
diff --git a/primus/configs/models/torchtitan/llama3.1_70B-fp8.yaml b/primus/configs/models/torchtitan/llama3.1_70B-fp8.yaml
index f663b0fa5..c15403382 100644
--- a/primus/configs/models/torchtitan/llama3.1_70B-fp8.yaml
+++ b/primus/configs/models/torchtitan/llama3.1_70B-fp8.yaml
@@ -7,4 +7,5 @@ model:
   flavor: "70B"
   hf_assets_path: "meta-llama/Llama-3.1-8B"
   converters:
-    - "float8"
+    - quantize.linear.float8
+    - quantize.grouped_mm.float8

diff --git a/primus/configs/models/torchtitan/llama3.1_8B-fp8.yaml b/primus/configs/models/torchtitan/llama3.1_8B-fp8.yaml
index 3a2e1abe1..9058d8c81 100644
--- a/primus/configs/models/torchtitan/llama3.1_8B-fp8.yaml
+++ b/primus/configs/models/torchtitan/llama3.1_8B-fp8.yaml
@@ -7,4 +7,5 @@ model:
   flavor: "8B"
   hf_assets_path: "meta-llama/Llama-3.1-8B"
   converters:
-    - "float8"
+    - quantize.linear.float8
+    - quantize.grouped_mm.float8

diff --git a/primus/configs/models/torchtitan/llama3.3_70B-fp8.yaml b/primus/configs/models/torchtitan/llama3.3_70B-fp8.yaml
index 033085ad0..30ebb9ef0 100644
--- a/primus/configs/models/torchtitan/llama3.3_70B-fp8.yaml
+++ b/primus/configs/models/torchtitan/llama3.3_70B-fp8.yaml
@@ -7,4 +7,5 @@ model:
   flavor: "70B"
   hf_assets_path: "meta-llama/Llama-3.3-70B-Instruct"
   converters:
-    - "float8"
+    - quantize.linear.float8
+    - quantize.grouped_mm.float8

diff --git a/primus/configs/models/torchtitan/llama3_70B-fp8.yaml b/primus/configs/models/torchtitan/llama3_70B-fp8.yaml
index 28d5d4c40..0a8c2b693 100644
--- a/primus/configs/models/torchtitan/llama3_70B-fp8.yaml
+++ b/primus/configs/models/torchtitan/llama3_70B-fp8.yaml
@@ -7,4 +7,5 @@ model:
   flavor: "70B"
   hf_assets_path: "meta-llama/Meta-Llama-3-70B"
   converters:
-    - "float8"
+    - quantize.linear.float8
+    - quantize.grouped_mm.float8

diff --git a/primus/configs/models/torchtitan/llama3_8B-fp8.yaml b/primus/configs/models/torchtitan/llama3_8B-fp8.yaml
index 010555e77..95b9eba97 100644
--- a/primus/configs/models/torchtitan/llama3_8B-fp8.yaml
+++ b/primus/configs/models/torchtitan/llama3_8B-fp8.yaml
@@ -7,4 +7,5 @@ model:
   flavor: "8B"
   hf_assets_path: "meta-llama/Meta-Llama-3-8B"
   converters:
-    - "float8"
+    - quantize.linear.float8
+    - quantize.grouped_mm.float8

diff --git a/primus/modules/trainer/torchtitan/patch_utils.py b/primus/modules/trainer/torchtitan/patch_utils.py
index ce256efc3..dfb57503a 100644
--- a/primus/modules/trainer/torchtitan/patch_utils.py
+++ b/primus/modules/trainer/torchtitan/patch_utils.py
@@ -4,9 +4,6 @@
 # See LICENSE for license information.
 ###############################################################################

-import inspect
-from functools import wraps
-
 import numpy as np
 from datasets import Dataset

@@ -44,18 +41,20 @@ def _create_mock_token_dataset(


 def patch_mock_hf_dataset() -> None:
-    from primus.core.utils.logger import _logger as logger
+    from primus.core.utils import logger

     try:
         import datasets

+        logger.warning("[Primus Mock] Enabling mock HuggingFace dataset mode.")
+
         def mock_load_dataset(path: str, *args, **kwargs) -> Dataset:
             """
             Replacement for datasets.load_dataset().
             Intercepts Titan calls like load_dataset('allenai/c4', ...).
             Returns a fake Dataset of text samples.
             """
-            logger.warning(f"[PrimusPatch][MockDataset] load_dataset('{path}') is mocked.")
+            logger.warning(f"[Primus Mock] load_dataset('{path}') is mocked.")
             # Shorter dataset for validation split
             if "validation" in path.lower():
                 return _create_mock_text_dataset(num_samples=32)
@@ -63,36 +62,7 @@ def mock_load_dataset(path: str, *args, **kwargs) -> Dataset:
             return _create_mock_token_dataset(seq_len=8192, vocab_size=32000, num_samples=256)

         datasets.load_dataset = mock_load_dataset
-        logger.warning("[PrimusPatch][Dataset] Patched datasets.load_dataset successfully.")
-
-    except Exception as e:
-        logger.error(f"[PrimusPatch][Dataset] Failed to patch datasets.load_dataset: {e}")
-
+        logger.warning("[PrimusPath][Dataset] Patched datasets.load_dataset successfully.")

-def apply_patch_checkpoint_wrapper():
-    """
-    Patch torch.distributed.algorithms._checkpoint.checkpoint_wrapper
-    to ignore unsupported kwargs such as `early_stop`.
-    """
-    from primus.core.utils.logger import _logger as logger
-
-    try:
-        import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as ckpt_mod
-
-        orig_fn = ckpt_mod.checkpoint_wrapper
-
-        @wraps(orig_fn)
-        def safe_checkpoint_wrapper(*args, **kwargs):
-            sig = inspect.signature(orig_fn)
-            valid = set(sig.parameters.keys())
-            dropped = []
-            for k in list(kwargs.keys()):
-                if k not in valid:
-                    kwargs.pop(k)
-                    dropped.append(k)
-            return orig_fn(*args, **kwargs)
-
-        ckpt_mod.checkpoint_wrapper = safe_checkpoint_wrapper
-        logger.warning("[PrimusPatch][Checkpoint] checkpoint_wrapper patched successfully")
     except Exception as e:
-        logger.warning(f"[PrimusPatch][Checkpoint] Failed to patch checkpoint_wrapper: {e}")
+        logger.error(f"[PrimusPath][Dataset] Failed to patch datasets.load_dataset: {e}")
""" - logger.warning(f"[PrimusPatch][MockDataset] load_dataset('{path}') is mocked.") + logger.warning(f"[Primus Mock] load_dataset('{path}') is mocked.") # Shorter dataset for validation split if "validation" in path.lower(): return _create_mock_text_dataset(num_samples=32) @@ -63,36 +62,7 @@ def mock_load_dataset(path: str, *args, **kwargs) -> Dataset: return _create_mock_token_dataset(seq_len=8192, vocab_size=32000, num_samples=256) datasets.load_dataset = mock_load_dataset - logger.warning("[PrimusPatch][Dataset] Patched datasets.load_dataset successfully.") - - except Exception as e: - logger.error(f"[PrimusPatch][Dataset] Failed to patch datasets.load_dataset: {e}") - + logger.warning("[PrimusPath][Dataset] Patched datasets.load_dataset successfully.") -def apply_patch_checkpoint_wrapper(): - """ - Patch torch.distributed.algorithms._checkpoint.checkpoint_wrapper - to ignore unsupported kwargs such as `early_stop`. - """ - from primus.core.utils.logger import _logger as logger - - try: - import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as ckpt_mod - - orig_fn = ckpt_mod.checkpoint_wrapper - - @wraps(orig_fn) - def safe_checkpoint_wrapper(*args, **kwargs): - sig = inspect.signature(orig_fn) - valid = set(sig.parameters.keys()) - dropped = [] - for k in list(kwargs.keys()): - if k not in valid: - kwargs.pop(k) - dropped.append(k) - return orig_fn(*args, **kwargs) - - ckpt_mod.checkpoint_wrapper = safe_checkpoint_wrapper - logger.warning("[PrimusPatch][Checkpoint] checkpoint_wrapper patched successfully") except Exception as e: - logger.warning(f"[PrimusPatch][Checkpoint] Failed to patch checkpoint_wrapper: {e}") + logger.error(f"[PrimusPath][Dataset] Failed to patch datasets.load_dataset: {e}") diff --git a/primus/modules/trainer/torchtitan/pre_trainer.py b/primus/modules/trainer/torchtitan/pre_trainer.py index 71c824806..4e29cf98d 100644 --- a/primus/modules/trainer/torchtitan/pre_trainer.py +++ b/primus/modules/trainer/torchtitan/pre_trainer.py @@ -16,6 +16,9 @@ def __init__(self, *args, **kwargs): extra_args = kwargs.pop("extra_args", None) super().__init__(*args, **kwargs) + # important: make sure patch torchtitan logger first + self.patch_torchtitan_logger() + self.primus_cfg = kwargs.pop("primus_config", None) if self.primus_cfg is None: raise ValueError("primus_config is required") @@ -23,8 +26,6 @@ def __init__(self, *args, **kwargs): pre_trainer_cfg = self.primus_cfg.get_module_config("pre_trainer") cfg_dict = nested_namespace_to_dict(pre_trainer_cfg) - self.patch_torchtitan_embedding_amp(cfg_dict["primus_turbo"]["enable_embedding_autocast"]) - patch_mock = getattr(pre_trainer_cfg.training, "mock_data", False) if patch_mock: from primus.modules.trainer.torchtitan.patch_utils import ( @@ -33,6 +34,9 @@ def __init__(self, *args, **kwargs): patch_mock_hf_dataset() + self.patch_torchtitan_embedding_amp(cfg_dict["primus_turbo"]["enable_embedding_autocast"]) + self.patch_titan_train_spec(pre_trainer_cfg.model.name, pre_trainer_cfg.model.flavor, extra_args) + # ensure checkpoint patch applied before import torchtitan # background: consolidate_safetensors_files_on_every_rank is a new DCP # utility introduced in newer torch versions. our current build does not @@ -56,17 +60,6 @@ def __init__(self, *args, **kwargs): # attention or training logic, so this patch does not affect behavior. 
@@ -519,7 +508,7 @@ def patch_torchtitan_embedding_amp(self, enable_patch: bool):
         from primus.core.utils.logger import _logger as primus_logger

         if not enable_patch:
-            primus_logger.warning("[PrimusPatch][AMP] Embedding AMP patch disabled via config.")
+            primus_logger.info("[PrimusPatch][AMP] Embedding AMP patch disabled via config.")
             return

         def _hook(module, inp, out):
@@ -528,9 +517,6 @@ def _hook(module, inp, out):

             if torch.is_autocast_enabled():
                 runtime_dtype = torch.get_autocast_gpu_dtype()
-                primus_logger.warning(
-                    f"[PrimusPatch][AMP] Autocast active, casting Embedding output to runtime dtype {runtime_dtype}."
-                )
                 if out.dtype != runtime_dtype:
                     return out.to(runtime_dtype)
             return out
@@ -542,91 +528,55 @@ def new_init(self, *args, **kwargs):
             self.register_forward_hook(_hook)

         nn.Embedding.__init__ = new_init
-        primus_logger.warning(
+        primus_logger.info(
             "[PrimusPatch][AMP] nn.Embedding.__init__ patched for AMP/mixed precision alignment."
         )

     def patch_titan_train_spec(self, model_name: str, flavor: str, model_overrides: Dict[str, Any]):
         """
         Monkey patch torchtitan.train_spec.get_train_spec to override model args dynamically.
-        Supports nested overrides like:
-            {"model.moe_args.num_experts": 16, "model.moe_args.router.score_func": "softmax"}
-
-        All override keys MUST start with "model.".
+        All override keys MUST start with "model." (e.g., {"model.n_layers": 8}).
         """
         from primus.core.utils.logger import _logger as primus_logger

         if not model_overrides:
-            primus_logger.warning("[PrimusPatch][ModelOverride] No model_overrides provided, skip patch.")
+            primus_logger.info("[PrimusPatch][ModelOverride] No model_overrides provided, skip patch.")
             return

-        primus_logger.warning(f"[PrimusPatch][ModelOverride] Applying model_overrides: {model_overrides}")
+        primus_logger.info(f"[PrimusPatch][ModelOverride] Applying model_overrides: {model_overrides}")
-        # --- Step 1. Flatten any nested dict under 'model'
+        # --- flatten nested form {"model": {"n_layers": 4}} → {"model.n_layers": 4}
         flat_overrides = {}
         for k, v in model_overrides.items():
             if k == "model" and isinstance(v, dict):
-
-                def _flatten(prefix, d):
-                    for subk, subv in d.items():
-                        if isinstance(subv, dict):
-                            _flatten(f"{prefix}.{subk}", subv)
-                        else:
-                            flat_overrides[f"{prefix}.{subk}"] = subv
-
-                _flatten("model", v)
+                for subk, subv in v.items():
+                    flat_overrides[f"model.{subk}"] = subv
             else:
                 flat_overrides[k] = v
         model_overrides = flat_overrides

         # Enforce `model.` prefix strictly
-        bad_keys = [k for k in model_overrides if not k.startswith("model.")]
+        bad_keys = [k for k in model_overrides.keys() if not k.startswith("model.")]
         if bad_keys:
             raise ValueError(
+                # f"[PrimusPatch][ModelOverride] Unsupported override keys (must start with 'model.'): {bad_keys}"
                 f"[PrimusPatch][ModelOverride] Invalid override keys detected: {bad_keys}. "
                 "These parameters belong to the model configuration and must be specified "
-                "with the 'model.' prefix (e.g., 'model.n_layers' or 'model.moe_args.num_experts')."
+                "with the 'model.' prefix (e.g., 'model.n_layers', 'model.dim')."
             )

-        primus_logger.warning(f"[PrimusPatch][ModelOverride] Applying overrides: {model_overrides}")
+        primus_logger.info(
+            f"[PrimusPatch][ModelOverride] model_overrides provided for '{model_name}' (flavor={flavor}): {model_overrides}"
+        )

         import torchtitan.protocols.train_spec as train_spec_module

         orig_get_train_spec = train_spec_module.get_train_spec

-        def _deep_setattr(obj, attr_path: str, value: Any):
-            """
-            Support setting nested attributes like "moe_args.num_experts" on dataclass or dict.
-            """
-            parts = attr_path.split(".")
-            current = obj
-            for p in parts[:-1]:
-                if is_dataclass(current):
-                    current = getattr(current, p)
-                elif isinstance(current, dict):
-                    current = current[p]
-                else:
-                    raise TypeError(
-                        f"[PrimusPatch] Unsupported type in path traversal: {type(current)} at {p}"
-                    )
-            last_key = parts[-1]
-            if is_dataclass(current):
-                if not hasattr(current, last_key):
-                    raise AttributeError(
-                        f"[PrimusPatch] '{type(current).__name__}' has no field '{last_key}'"
-                    )
-                setattr(current, last_key, value)
-            elif isinstance(current, dict):
-                if last_key not in current:
-                    raise KeyError(f"[PrimusPatch] dict has no key '{last_key}'")
-                current[last_key] = value
-            else:
-                raise TypeError(f"[PrimusPatch] Unsupported type for final assignment: {type(current)}")
-
         def patched_get_train_spec(name: str):
             spec = orig_get_train_spec(name)
             if name != model_name:
-                return spec
+                return spec  # only patch targeted model

             assert hasattr(
                 spec, "model_args"
@@ -648,19 +598,22 @@ def patched_get_train_spec(name: str):
             ), f"[PrimusPatch][ModelOverride] Expected dataclass model_args, got {type(target_args)}"

             before = asdict(target_args)
+            for k, v in model_overrides.items():
+                field_name = k[len("model.") :]
+                if not hasattr(target_args, field_name):
+                    raise AttributeError(
+                        f"[PrimusPatch][ModelOverride] '{type(target_args).__name__}' has no field '{field_name}'"
+                    )
+                setattr(target_args, field_name, v)

-            for full_key, new_value in model_overrides.items():
-                field_path = full_key[len("model.") :]
-                _deep_setattr(target_args, field_path, new_value)
-
-            primus_logger.warning(
-                f"[PrimusPatch][ModelOverride] Successfully patched model_args['{flavor}'] for '{name}' with "
-                f"{model_overrides}. Diff(before→after): {before} → {asdict(target_args)}"
+            primus_logger.info(
+                f"[PrimusPatch][ModelOverride] Patched dataclass model_args['{flavor}'] "
+                f"for '{name}' with {model_overrides} (before={before})"
             )
             return spec

         # Apply the patch globally
         train_spec_module.get_train_spec = patched_get_train_spec
-        primus_logger.warning(
+        primus_logger.info(
             f"[PrimusPatch][ModelOverride] get_train_spec for '{model_name}' successfully monkey patched (flavor={flavor})."
         )

diff --git a/tests/modules/trainer/torchtitan/test_patch_utils.py b/tests/modules/trainer/torchtitan/test_patch_utils.py
index 7fe048a84..cf0aa9903 100644
--- a/tests/modules/trainer/torchtitan/test_patch_utils.py
+++ b/tests/modules/trainer/torchtitan/test_patch_utils.py
@@ -5,10 +5,7 @@
 ###############################################################################


-from primus.modules.trainer.torchtitan.patch_utils import (
-    apply_patch_checkpoint_wrapper,
-    patch_mock_hf_dataset,
-)
+from primus.modules.trainer.torchtitan.patch_utils import patch_mock_hf_dataset
 from tests.utils import PrimusUT


@@ -43,39 +40,3 @@ def test_mock_hf_dataset_patch(self):
         sample = ds[0]
         assert isinstance(sample["text"], str)
         assert len(sample["text"].split()) > 0
-
-    def test_patch_checkpoint_wrapper(self):
-        """
-        Verify Primus patch for torch.distributed.algorithms._checkpoint.checkpoint_wrapper
-        correctly ignores unsupported kwargs (e.g., early_stop)
-        without breaking checkpoint functionality.
-        """
-        import torch
-
-        apply_patch_checkpoint_wrapper()
-
-        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-            checkpoint_wrapper,
-        )
-
-        class DummyModule(torch.nn.Module):
-            def forward(self, x):
-                return x + 1
-
-        m = DummyModule()
-
-        # Should NOT raise: TypeError: unexpected keyword argument 'early_stop'
-        try:
-            wrapped = checkpoint_wrapper(m, preserve_rng_state=False, early_stop=True)
-        except TypeError as e:
-            raise AssertionError(f"checkpoint_wrapper should ignore unsupported kwargs but raised: {e}")
-
-        assert isinstance(wrapped, torch.nn.Module)
-
-        # Verify normal forward/backward still works
-        x = torch.tensor([2.0], requires_grad=True)
-        y = wrapped(x)
-        loss = y.sum()
-        loss.backward()
-        assert x.grad is not None
-        assert torch.allclose(x.grad, torch.tensor([1.0]))

diff --git a/tests/trainer/test_torchtitan_trainer.py b/tests/trainer/test_torchtitan_trainer.py
index ba0e3f744..4b5e5d50d 100644
--- a/tests/trainer/test_torchtitan_trainer.py
+++ b/tests/trainer/test_torchtitan_trainer.py
@@ -94,6 +94,14 @@ def test_llama3_1_8B_FP8(self):
             self.__class__.__name__,
             "llama3_8B-FP8",
             exp_path="examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml",
+            extra_args=["--model.n_layers", "4", "--training.steps", "3"],
+        )
+
+    def test_llama3_1_405B(self):
+        run_script(
+            self.__class__.__name__,
+            "llama3.1_405B",
+            "examples/torchtitan/configs/MI300X/llama3.1_405B-pretrain.yaml",
             extra_args=[
                 "--model.n_layers",
                 "4",
@@ -102,11 +110,11 @@ def test_llama3_1_8B_FP8(self):
             ],
         )

-    def test_llama3_1_405B(self):
+    def test_llama3_1_70B_bf16(self):
         run_script(
             self.__class__.__name__,
-            "llama3.1_405B",
-            "examples/torchtitan/configs/MI300X/llama3.1_405B-pretrain.yaml",
+            "llama3.1_70B_bf16",
+            "examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml",
             extra_args=[
                 "--model.n_layers",
                 "4",
@@ -115,11 +123,11 @@ def test_llama3_1_405B(self):
             ],
         )

-    def test_llama3_1_70B_BF16(self):
+    def test_llama3_1_70B_fp8(self):
         run_script(
             self.__class__.__name__,
-            "llama3.1_70B_BF16",
"examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml", + "llama3.1_70B_fp8", + "examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml", extra_args=[ "--model.n_layers", "4", @@ -128,11 +136,11 @@ def test_llama3_1_70B_BF16(self): ], ) - def test_llama3_1_70B_FP8(self): + def test_qwen3_0_6B(self): run_script( self.__class__.__name__, - "llama3.1_70B_FP8", - "examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml", + "qwen3_0.6B", + "examples/torchtitan/configs/MI300X/qwen3_0.6B-pretrain.yaml", extra_args=[ "--model.n_layers", "4", @@ -141,85 +149,62 @@ def test_llama3_1_70B_FP8(self): ], ) - # def test_qwen3_0_6B(self): - # run_script( - # self.__class__.__name__, - # "qwen3_0.6B", - # "examples/torchtitan/configs/MI300X/qwen3_0.6B-pretrain.yaml", - # extra_args=[ - # "--model.n_layers", - # "4", - # "--training.steps", - # "3", - # # "--primus_turbo.enable_primus_turbo", - # # "False", - # ], - # ) - - # def test_qwen3_1_7B(self): - # run_script( - # self.__class__.__name__, - # "qwen3_1.7B", - # "examples/torchtitan/configs/MI300X/qwen3_1.7B-pretrain.yaml", - # extra_args=[ - # "--model.n_layers", - # "4", - # "--training.steps", - # "3", - # # "--primus_turbo.enable_primus_turbo", - # # "False", - # ], - # ) - - # def test_qwen3_32B(self): - # run_script( - # self.__class__.__name__, - # "qwen3_32B", - # "examples/torchtitan/configs/MI300X/qwen3_32B-pretrain.yaml", - # extra_args=[ - # "--model.n_layers", - # "4", - # "--training.steps", - # "3", - # # "--primus_turbo.enable_primus_turbo", - # # "False", - # ], - # ) - - # def test_deepseek_v3_16b(self): - # run_script( - # self.__class__.__name__, - # "deepseek_v3_16b", - # "examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml", - # extra_args=[ - # "--model.n_layers", - # "4", - # "--model.n_dense_layers", - # "1", - # "--training.steps", - # "3", - # # "--primus_turbo.enable_primus_turbo", - # # "False", - # "--model.moe_args.use_grouped_mm", - # "False", - # ], - # ) - - # def test_deepseek_v3_671b(self): - # run_script( - # self.__class__.__name__, - # "deepseek_v3_671b", - # "examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml", - # extra_args=[ - # "--model.n_layers", - # "4", - # "--model.n_dense_layers", - # "1", - # "--training.steps", - # "3", - # # "--primus_turbo.enable_primus_turbo", - # # "False", - # "--model.moe_args.use_grouped_mm", - # "False", - # ], - # ) + def test_qwen3_1_7B(self): + run_script( + self.__class__.__name__, + "qwen3_1.7B", + "examples/torchtitan/configs/MI300X/qwen3_1.7B-pretrain.yaml", + extra_args=[ + "--model.n_layers", + "4", + "--training.steps", + "3", + ], + ) + + def test_qwen3_32B(self): + run_script( + self.__class__.__name__, + "qwen3_32B", + "examples/torchtitan/configs/MI300X/qwen3_32B-pretrain.yaml", + extra_args=[ + "--model.n_layers", + "4", + "--training.steps", + "3", + ], + ) + + def test_deepseek_v3_16b(self): + run_script( + self.__class__.__name__, + "deepseek_v3_16b", + "examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml", + extra_args=[ + "--model.n_layers", + "4", + "--model.n_dense_layers", + "1", + "--training.steps", + "3", + "--primus_turbo.enable_primus_turbo", + "False", + ], + ) + + def test_deepseek_v3_671b(self): + run_script( + self.__class__.__name__, + "deepseek_v3_671b", + "examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml", + extra_args=[ + "--model.n_layers", + "4", + "--model.n_dense_layers", + "1", + "--training.steps", + "3", + "--primus_turbo.enable_primus_turbo", 
+ "False", + ], + ) diff --git a/third_party/torchtitan b/third_party/torchtitan index 99c0cb28f..5fb7cc2e3 160000 --- a/third_party/torchtitan +++ b/third_party/torchtitan @@ -1 +1 @@ -Subproject commit 99c0cb28f615d99290273afa1da01fd72f01f1a5 +Subproject commit 5fb7cc2e3bbb9b9dc0ab7af34ed5cc58b5f32021 From be3b984f1f0b267fa636ef3034fe7679dd1d778f Mon Sep 17 00:00:00 2001 From: Xiaoming-AMD Date: Mon, 17 Nov 2025 23:25:39 -0600 Subject: [PATCH 12/32] torchtitan: tune FP8 configs and share quant settings --- .../MI300X/llama3.1_8B-FP8-pretrain.yaml | 10 +-- .../MI355X/llama3.1_8B-FP8-pretrain.yaml | 10 +-- .../modules/torchtitan/pre_trainer.yaml | 29 ++++---- .../configs/modules/torchtitan/quantize.yaml | 66 +++++++++++++++++++ 4 files changed, 93 insertions(+), 22 deletions(-) create mode 100644 primus/configs/modules/torchtitan/quantize.yaml diff --git a/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml index 76e116795..e3ddcff96 100644 --- a/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml +++ b/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml @@ -30,10 +30,12 @@ modules: mode: "none" # ["none", "selective", "full"] selective_ac_option: "op" # "int" = ac every positive int layer or 'op', ac based on ops policy - float8: - enable_fsdp_float8_all_gather: true - precompute_float8_dynamic_scale_for_fsdp: true - filter_fqns: ["output"] + quantize: + linear: + float8: + enable_fsdp_float8_all_gather: true + precompute_float8_dynamic_scale_for_fsdp: true + filter_fqns: ["output"] primus_turbo: enable_primus_turbo: true diff --git a/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml index 5b29aa254..33206cf80 100644 --- a/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml +++ b/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml @@ -30,10 +30,12 @@ modules: mode: "none" # ["none", "selective", "full"] selective_ac_option: "op" # "int" = ac every positive int layer or 'op', ac based on ops policy - float8: - enable_fsdp_float8_all_gather: true - precompute_float8_dynamic_scale_for_fsdp: true - filter_fqns: ["output"] + quantize: + linear: + float8: + enable_fsdp_float8_all_gather: true + precompute_float8_dynamic_scale_for_fsdp: true + filter_fqns: ["output"] primus_turbo: enable_primus_turbo: true diff --git a/primus/configs/modules/torchtitan/pre_trainer.yaml b/primus/configs/modules/torchtitan/pre_trainer.yaml index 74fb41cf7..873ced2df 100644 --- a/primus/configs/modules/torchtitan/pre_trainer.yaml +++ b/primus/configs/modules/torchtitan/pre_trainer.yaml @@ -1,5 +1,6 @@ includes: - ../module_base.yaml + - quantize.yaml lr_scheduler: decay_ratio: null @@ -102,20 +103,20 @@ compile: - model - loss -float8: - enable_fsdp_float8_all_gather: false - precompute_float8_dynamic_scale_for_fsdp: false - recipe_name: null - filter_fqns: [] - emulate: false - moe_fqns_prototype: [] - -mx: - mxfp8_dim1_cast_kernel_choice: triton - recipe_name: mxfp8_cublas - filter_fqns: - - output - moe_fqns_prototype: [] +# float8: +# enable_fsdp_float8_all_gather: false +# precompute_float8_dynamic_scale_for_fsdp: false +# recipe_name: null +# filter_fqns: [] +# emulate: false +# moe_fqns_prototype: [] + +# mx: +# mxfp8_dim1_cast_kernel_choice: triton +# recipe_name: mxfp8_cublas +# filter_fqns: +# - output +# moe_fqns_prototype: [] comm: init_timeout_seconds: 300 diff --git 
From be3b984f1f0b267fa636ef3034fe7679dd1d778f Mon Sep 17 00:00:00 2001
From: Xiaoming-AMD
Date: Mon, 17 Nov 2025 23:25:39 -0600
Subject: [PATCH 12/32] torchtitan: tune FP8 configs and share quant settings

---
 .../MI300X/llama3.1_8B-FP8-pretrain.yaml      | 10 +--
 .../MI355X/llama3.1_8B-FP8-pretrain.yaml      | 10 +--
 .../modules/torchtitan/pre_trainer.yaml       | 29 ++++----
 .../configs/modules/torchtitan/quantize.yaml  | 66 +++++++++++++++++++
 4 files changed, 93 insertions(+), 22 deletions(-)
 create mode 100644 primus/configs/modules/torchtitan/quantize.yaml

diff --git a/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml
index 76e116795..e3ddcff96 100644
--- a/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml
+++ b/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml
@@ -30,10 +30,12 @@ modules:
       mode: "none" # ["none", "selective", "full"]
       selective_ac_option: "op" # "int" = ac every positive int layer or 'op', ac based on ops policy

-    float8:
-      enable_fsdp_float8_all_gather: true
-      precompute_float8_dynamic_scale_for_fsdp: true
-      filter_fqns: ["output"]
+    quantize:
+      linear:
+        float8:
+          enable_fsdp_float8_all_gather: true
+          precompute_float8_dynamic_scale_for_fsdp: true
+          filter_fqns: ["output"]

     primus_turbo:
       enable_primus_turbo: true

diff --git a/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml
index 5b29aa254..33206cf80 100644
--- a/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml
+++ b/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml
@@ -30,10 +30,12 @@ modules:
       mode: "none" # ["none", "selective", "full"]
       selective_ac_option: "op" # "int" = ac every positive int layer or 'op', ac based on ops policy

-    float8:
-      enable_fsdp_float8_all_gather: true
-      precompute_float8_dynamic_scale_for_fsdp: true
-      filter_fqns: ["output"]
+    quantize:
+      linear:
+        float8:
+          enable_fsdp_float8_all_gather: true
+          precompute_float8_dynamic_scale_for_fsdp: true
+          filter_fqns: ["output"]

     primus_turbo:
       enable_primus_turbo: true

diff --git a/primus/configs/modules/torchtitan/pre_trainer.yaml b/primus/configs/modules/torchtitan/pre_trainer.yaml
index 74fb41cf7..873ced2df 100644
--- a/primus/configs/modules/torchtitan/pre_trainer.yaml
+++ b/primus/configs/modules/torchtitan/pre_trainer.yaml
@@ -1,5 +1,6 @@
 includes:
   - ../module_base.yaml
+  - quantize.yaml

 lr_scheduler:
   decay_ratio: null
@@ -102,20 +103,20 @@ compile:
   - model
   - loss

-float8:
-  enable_fsdp_float8_all_gather: false
-  precompute_float8_dynamic_scale_for_fsdp: false
-  recipe_name: null
-  filter_fqns: []
-  emulate: false
-  moe_fqns_prototype: []
-
-mx:
-  mxfp8_dim1_cast_kernel_choice: triton
-  recipe_name: mxfp8_cublas
-  filter_fqns:
-    - output
-  moe_fqns_prototype: []
+# float8:
+#   enable_fsdp_float8_all_gather: false
+#   precompute_float8_dynamic_scale_for_fsdp: false
+#   recipe_name: null
+#   filter_fqns: []
+#   emulate: false
+#   moe_fqns_prototype: []
+
+# mx:
+#   mxfp8_dim1_cast_kernel_choice: triton
+#   recipe_name: mxfp8_cublas
+#   filter_fqns:
+#     - output
+#   moe_fqns_prototype: []

 comm:
   init_timeout_seconds: 300

diff --git a/primus/configs/modules/torchtitan/quantize.yaml b/primus/configs/modules/torchtitan/quantize.yaml
new file mode 100644
index 000000000..412321495
--- /dev/null
+++ b/primus/configs/modules/torchtitan/quantize.yaml
@@ -0,0 +1,66 @@
+# Quantize Configuration for TorchTitan
+# This configuration controls quantized training for linear layers and grouped GEMMs
+
+quantize:
+  # Configuration for nn.Linear layers
+  linear:
+    # FP8 (Float8) training config for nn.Linear layers
+    float8:
+      # Whether to enable float8 all-gather in FSDP
+      # Recommended for tensorwise scaling
+      enable_fsdp_float8_all_gather: false
+
+      # Whether to precompute float8 scales dynamically for FSDP
+      # Recommended for tensorwise scaling
+      precompute_float8_dynamic_scale_for_fsdp: false
+
+      # Float8 recipe name: "tensorwise", "rowwise", or "rowwise_with_gw_hp"
+      # If specified, creates float8 config from recipe name
+      recipe_name: null
+
+      # List of fully qualified names (FQNs) of modules to skip applying float8 training
+      # nn.Linear modules with any dim size not divisible by 16 are always skipped
+      # Example: ["attention.wq", "attention.wk", "attention.wv", "output"]
+      filter_fqns: []
+
+      # If true, use emulation instead of hardware accelerated gemm
+      # For test purposes only (CI without sm_89 capability)
+      # Not compatible with torch.compile
+      emulate: false
+
+    # MX (Microscaling) training config for nn.Linear layers
+    mx:
+      # Kernel choice for mxfp8 dim1 cast: "triton", "cuda", or "torch"
+      # CUDA is recommended for best performance
+      mxfp8_dim1_cast_kernel_choice: "triton"
+
+      # MX recipe name (default: "mxfp8_cublas")
+      # See: https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats
+      recipe_name: "mxfp8_cublas"
+
+      # List of FQNs to skip applying mxfp8 training
+      # nn.Linear modules with any dim size not divisible by 16 are also skipped
+      # By default, the output layer is always skipped
+      # Example: ["attention.wq", "attention.wk", "attention.wv", "output"]
+      filter_fqns: ["output"]
+
+  # Configuration for grouped GEMMs (typically for MoE models)
+  grouped_mm:
+    # FP8 training config for grouped GEMMs
+    float8:
+      # Prototype feature: List of FQNs of MoE Layers to apply FP8 dynamic quantization
+      # Performance optimization still in progress
+      # Requires torchao nightly build
+      # Example: ["experts"]
+      fqns: []
+
+    # MX training config for grouped GEMMs
+    mx:
+      # Quantization recipe name for grouped GEMMs
+      recipe_name: "mxfp8"
+
+      # Prototype feature: List of FQNs of MoE modules to apply MXFP8 dynamic quantization
+      # Performance optimization still in progress
+      # Requires torchao nightly build
+      # Example: ["experts"]
+      fqns: []

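The shared quantize.yaml above pairs naturally with converter identifiers such as quantize.linear.float8 from the model YAMLs restored in patch 11. One plausible reading, sketched below as an assumption rather than the shipped implementation, is that the dotted converter name is a path into this config tree:

quantize_cfg = {
    "linear": {
        "float8": {
            "enable_fsdp_float8_all_gather": False,
            "precompute_float8_dynamic_scale_for_fsdp": False,
            "filter_fqns": [],
        },
        "mx": {"recipe_name": "mxfp8_cublas", "filter_fqns": ["output"]},
    },
    "grouped_mm": {"float8": {"fqns": []}, "mx": {"recipe_name": "mxfp8", "fqns": []}},
}

def resolve_converter(name: str, root: dict) -> dict:
    """Resolve e.g. 'quantize.linear.float8' to its settings subtree."""
    parts = name.split(".")
    assert parts[0] == "quantize", f"unexpected converter namespace: {parts[0]}"
    node = root
    for part in parts[1:]:
        node = node[part]
    return node

print(resolve_converter("quantize.linear.float8", quantize_cfg))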
From 01f745d1cee544eac875972cf68761636c93ab32 Mon Sep 17 00:00:00 2001
From: vidushi8
Date: Wed, 19 Nov 2025 02:16:16 -0600
Subject: [PATCH 13/32] update torchtitan yaml

---
 .../configs/MI300X/llama3.1_70B-BF16-pretrain.yaml |  1 +
 .../configs/MI300X/llama3.1_70B-FP8-pretrain.yaml  | 11 +++++++----
 .../configs/MI300X/llama3.1_8B-BF16-pretrain.yaml  |  1 +
 .../configs/MI300X/llama3.1_8B-FP8-pretrain.yaml   |  1 +
 4 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml b/examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml
index 2fafea9b8..4122e1d8f 100644
--- a/examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml
+++ b/examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml
@@ -24,6 +24,7 @@ modules:

     training:
       local_batch_size: 4
+      seq_len: 8192
       steps: 50

     optimizer:

diff --git a/examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml
index b8fe637ef..e63149b55 100644
--- a/examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml
+++ b/examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml
@@ -24,6 +24,7 @@ modules:

     training:
       local_batch_size: 3
+      seq_len: 8192
       steps: 50

     optimizer:
@@ -32,10 +33,12 @@ modules:
     activation_checkpoint:
       mode: full

-    float8:
-      enable_fsdp_float8_all_gather: true
-      precompute_float8_dynamic_scale_for_fsdp: true
-      filter_fqns: ["output"]
+    quantize:
+      linear:
+        float8:
+          enable_fsdp_float8_all_gather: true
+          precompute_float8_dynamic_scale_for_fsdp: true
+          filter_fqns: ["output"]

     primus_turbo:
       enable_primus_turbo : true

diff --git a/examples/torchtitan/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml b/examples/torchtitan/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml
index de4e24621..eedf00413 100644
--- a/examples/torchtitan/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml
+++ b/examples/torchtitan/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml
@@ -25,6 +25,7 @@ modules:

     training:
       local_batch_size: 3
+      seq_len: 8192
       steps: 50

     activation_checkpoint:

diff --git a/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml
index e3ddcff96..6b80da075 100644
--- a/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml
+++ b/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml
@@ -24,6 +24,7 @@ modules:

     training:
       local_batch_size: 4
+      seq_len: 8192
       steps: 50

     activation_checkpoint:

From e226b05d308728ef378c304545330794b842cf8e Mon Sep 17 00:00:00 2001
From: vidushi8
Date: Wed, 19 Nov 2025 10:52:48 -0600
Subject: [PATCH 14/32] enable mla configs in DS models

---
 examples/megatron/configs/MI355X/deepseek_v2-pretrain.yaml | 2 ++
 .../megatron/configs/MI355X/deepseek_v2_lite-pretrain.yaml | 4 ++--
 examples/megatron/configs/MI355X/deepseek_v3-pretrain.yaml | 5 +----
 3 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/examples/megatron/configs/MI355X/deepseek_v2-pretrain.yaml b/examples/megatron/configs/MI355X/deepseek_v2-pretrain.yaml
index d0557152a..72b836df0 100644
--- a/examples/megatron/configs/MI355X/deepseek_v2-pretrain.yaml
+++ b/examples/megatron/configs/MI355X/deepseek_v2-pretrain.yaml
@@ -68,6 +68,8 @@ modules:
       # pad 192/128 for deepseek attention
       fused_padded_mla_attention: false

+      multi_latent_attention: true
+
      # ckpt
       finetune: false
       auto_continue_train: false

diff --git a/examples/megatron/configs/MI355X/deepseek_v2_lite-pretrain.yaml b/examples/megatron/configs/MI355X/deepseek_v2_lite-pretrain.yaml
index 70775504b..935f90dba 100644
--- a/examples/megatron/configs/MI355X/deepseek_v2_lite-pretrain.yaml
+++ b/examples/megatron/configs/MI355X/deepseek_v2_lite-pretrain.yaml
@@ -63,7 +63,7 @@ modules:
       # pad 192/128 for deepseek attention
       fused_padded_mla_attention: false

-      multi_latent_attention: false
+      #multi_latent_attention: true

       # ckpt
       finetune: false
@@ -80,7 +80,7 @@ modules:
       eval_iters: 0

       # Turbo
-      enable_primus_turbo: true
+      enable_primus_turbo: false
       use_turbo_attention: false
       use_turbo_grouped_mlp: false

diff --git a/examples/megatron/configs/MI355X/deepseek_v3-pretrain.yaml b/examples/megatron/configs/MI355X/deepseek_v3-pretrain.yaml
index 43a07f6e0..f86629857 100644
--- a/examples/megatron/configs/MI355X/deepseek_v3-pretrain.yaml
+++ b/examples/megatron/configs/MI355X/deepseek_v3-pretrain.yaml
@@ -63,9 +63,6 @@ modules:
       # pad 192/128 for deepseek attention
       fused_padded_mla_attention: false

-      # Performance toggles
-      #multi_latent_attention: false
-      #apply_rope_fusion: true

       # ckpt
       finetune: false
@@ -82,7 +79,7 @@ modules:
       eval_iters: 0

       # Turbo
-      enable_primus_turbo: true
+      enable_primus_turbo: false
       use_turbo_attention: false
       use_turbo_grouped_mlp: false

From 9f94561e2cf0005031329276850b11eda59c43fa Mon Sep 17 00:00:00 2001
From: vidushi8
Date: Wed, 19 Nov 2025 15:09:10 -0600
Subject: [PATCH 15/32] update fp8 llama3 70b tt yaml

---
 .../configs/MI355X/llama3.1_70B-FP8-pretrain.yaml | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml
index 6bb688d8d..94bd3770d 100644
--- a/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml
+++ b/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml
@@ -32,10 +32,12 @@ modules:
     activation_checkpoint:
       mode: full

-    float8:
-      enable_fsdp_float8_all_gather: true
-      precompute_float8_dynamic_scale_for_fsdp: true
-      filter_fqns: ["output"]
+    fquantize:
+      linear:
+        float8:
+          enable_fsdp_float8_all_gather: true
+          precompute_float8_dynamic_scale_for_fsdp: true
+          filter_fqns: ["output"]

     primus_turbo:
       enable_primus_turbo : true

From 299322e48801ecc604b9744239cece93a6b65572 Mon Sep 17 00:00:00 2001
From: vidushi8
Date: Wed, 19 Nov 2025 16:00:48 -0600
Subject: [PATCH 16/32] update torchtitan config to use real dataset

---
 .../torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml | 1 +
 .../torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml  | 1 +
 .../torchtitan/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml  | 1 +
 .../torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml   | 1 +
 .../torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml | 2 ++
 .../torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml  | 4 +++-
 .../torchtitan/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml  | 2 ++
 .../torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml   | 2 ++
 8 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml b/examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml
index 4122e1d8f..68d62016a 100644
--- a/examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml
+++ b/examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml
@@ -25,6 +25,7 @@ modules:
     training:
       local_batch_size: 4
       seq_len: 8192
+      mock_data: false
       steps: 50

     optimizer:

diff --git a/examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml
index e63149b55..8a70d2ce9 100644
--- a/examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml
+++ b/examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml
@@ -25,6 +25,7 @@ modules:
     training:
       local_batch_size: 3
       seq_len: 8192
+      mock_data: false
       steps: 50

     optimizer:

diff --git a/examples/torchtitan/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml b/examples/torchtitan/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml
index eedf00413..19a8296db 100644
--- a/examples/torchtitan/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml
+++ b/examples/torchtitan/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml
@@ -26,6 +26,7 @@ modules:
     training:
       local_batch_size: 3
       seq_len: 8192
+      mock_data: false
       steps: 50

     activation_checkpoint:

diff --git a/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml
index 6b80da075..b564c605a 100644
--- a/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml
+++ b/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml
@@ -25,6 +25,7 @@ modules:
     training:
       local_batch_size: 4
       seq_len: 8192
+      mock_data: false
       steps: 50

     activation_checkpoint:

diff --git a/examples/torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml b/examples/torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml
index 5abf45cc3..394866503 100644
--- a/examples/torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml
+++ b/examples/torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml
@@ -24,6 +24,8 @@ modules:

     training:
       local_batch_size: 4
+      seq_len: 8192
+      mock_data: false
       steps: 50

     optimizer:

diff --git a/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml
index 94bd3770d..ecaf9d392 100644
--- a/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml
+++ b/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml
@@ -24,6 +24,8 @@ modules:

     training:
       local_batch_size: 3
+      seq_len: 8192
+      mock_data: false
       steps: 50

     optimizer:
@@ -32,7 +34,7 @@ modules:
     activation_checkpoint:
       mode: full

-    fquantize:
+    quantize:
       linear:
         float8:

diff --git a/examples/torchtitan/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml b/examples/torchtitan/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml
index 998dcde58..74a4eed3a 100644
--- a/examples/torchtitan/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml
+++ b/examples/torchtitan/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml
@@ -25,6 +25,8 @@ modules:

     training:
       local_batch_size: 3
+      seq_len: 8192
+      mock_data: false
       steps: 50

     activation_checkpoint:

diff --git a/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml
index 33206cf80..b6c8e5020 100644
--- a/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml
+++ b/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml
@@ -24,6 +24,8 @@ modules:

     training:
       local_batch_size: 4
+      seq_len: 8192
+      mock_data: false
       steps: 50

     activation_checkpoint:

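The mock_data: false entries added in patch 16 feed the toggle visible in pre_trainer.py earlier in this series: patch_mock_hf_dataset() is applied only when training.mock_data is true. A minimal sketch of that gate (SimpleNamespace stands in for the parsed config object):

from types import SimpleNamespace

def choose_dataset(training_cfg) -> str:
    # Mirrors: patch_mock = getattr(pre_trainer_cfg.training, "mock_data", False)
    if getattr(training_cfg, "mock_data", False):
        # here patch_mock_hf_dataset() would replace datasets.load_dataset
        return "mock dataset"
    return "real dataset (e.g. allenai/c4)"

print(choose_dataset(SimpleNamespace(mock_data=False)))  # real dataset
print(choose_dataset(SimpleNamespace(mock_data=True)))   # mock dataset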
b/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml index 6b80da075..b564c605a 100644 --- a/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml +++ b/examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml @@ -25,6 +25,7 @@ modules: training: local_batch_size: 4 seq_len: 8192 + mock_data: false steps: 50 activation_checkpoint: diff --git a/examples/torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml b/examples/torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml index 5abf45cc3..394866503 100644 --- a/examples/torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml +++ b/examples/torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml @@ -24,6 +24,8 @@ modules: training: local_batch_size: 4 + seq_len: 8192 + mock_data: false steps: 50 optimizer: diff --git a/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml index 94bd3770d..ecaf9d392 100644 --- a/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml +++ b/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml @@ -24,6 +24,8 @@ modules: training: local_batch_size: 3 + seq_len: 8192 + mock_data: false steps: 50 optimizer: @@ -32,7 +34,7 @@ modules: activation_checkpoint: mode: full - fquantize: + quantize: linear: float8: enable_fsdp_float8_all_gather: true diff --git a/examples/torchtitan/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml b/examples/torchtitan/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml index 998dcde58..74a4eed3a 100644 --- a/examples/torchtitan/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml +++ b/examples/torchtitan/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml @@ -25,6 +25,8 @@ modules: training: local_batch_size: 3 + seq_len: 8192 + mock_data: false steps: 50 activation_checkpoint: diff --git a/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml index 33206cf80..b6c8e5020 100644 --- a/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml +++ b/examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml @@ -24,6 +24,8 @@ modules: training: local_batch_size: 4 + seq_len: 8192 + mock_data: false steps: 50 activation_checkpoint: From eabc2f871a726e104450ca4a7a6501f6ae5c00c9 Mon Sep 17 00:00:00 2001 From: vidushi8 Date: Wed, 19 Nov 2025 16:11:38 -0600 Subject: [PATCH 17/32] update mi300 ds model yamls with mla --- .../configs/MI300X/deepseek_v2_lite-pretrain.yaml | 8 ++++---- .../megatron/configs/MI300X/deepseek_v3-pretrain.yaml | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/megatron/configs/MI300X/deepseek_v2_lite-pretrain.yaml b/examples/megatron/configs/MI300X/deepseek_v2_lite-pretrain.yaml index 472a4ef45..66d628837 100644 --- a/examples/megatron/configs/MI300X/deepseek_v2_lite-pretrain.yaml +++ b/examples/megatron/configs/MI300X/deepseek_v2_lite-pretrain.yaml @@ -63,7 +63,7 @@ modules: # pad 192/128 for deepseek attention fused_padded_mla_attention: false - multi_latent_attention: false + multi_latent_attention: true # ckpt finetune: false @@ -80,9 +80,9 @@ modules: eval_iters: 0 # Turbo - enable_primus_turbo: true - use_turbo_attention: true - use_turbo_grouped_mlp: true + enable_primus_turbo: false + use_turbo_attention: false + use_turbo_grouped_mlp: false # fp8: e4m3 # fp8_recipe: blockwise # tensorwise, blockwise diff --git a/examples/megatron/configs/MI300X/deepseek_v3-pretrain.yaml 
b/examples/megatron/configs/MI300X/deepseek_v3-pretrain.yaml
index 38befa2ea..242b07b09 100644
--- a/examples/megatron/configs/MI300X/deepseek_v3-pretrain.yaml
+++ b/examples/megatron/configs/MI300X/deepseek_v3-pretrain.yaml
@@ -82,9 +82,9 @@ modules:
       eval_iters: 0
 
       # Turbo
-      enable_primus_turbo: true
-      use_turbo_attention: true
-      use_turbo_grouped_mlp: true
+      enable_primus_turbo: false
+      use_turbo_attention: false
+      use_turbo_grouped_mlp: false
 
       # MoE overlap
       # overlap_moe_expert_parallel_comm: true

From 919f9f64e979be261b7728788d43e25cf5b9cd25 Mon Sep 17 00:00:00 2001
From: vidushi8
Date: Wed, 19 Nov 2025 16:54:45 -0600
Subject: [PATCH 18/32] Revert "update mi300 ds model yamls with mla"

This reverts commit eabc2f871a726e104450ca4a7a6501f6ae5c00c9.
---
 .../configs/MI300X/deepseek_v2_lite-pretrain.yaml     | 8 ++++----
 .../megatron/configs/MI300X/deepseek_v3-pretrain.yaml | 6 +++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/examples/megatron/configs/MI300X/deepseek_v2_lite-pretrain.yaml b/examples/megatron/configs/MI300X/deepseek_v2_lite-pretrain.yaml
index 66d628837..472a4ef45 100644
--- a/examples/megatron/configs/MI300X/deepseek_v2_lite-pretrain.yaml
+++ b/examples/megatron/configs/MI300X/deepseek_v2_lite-pretrain.yaml
@@ -63,7 +63,7 @@ modules:
       # pad 192/128 for deepseek attention
       fused_padded_mla_attention: false
 
-      multi_latent_attention: true
+      multi_latent_attention: false
 
       # ckpt
       finetune: false
@@ -80,9 +80,9 @@ modules:
       eval_iters: 0
 
       # Turbo
-      enable_primus_turbo: false
-      use_turbo_attention: false
-      use_turbo_grouped_mlp: false
+      enable_primus_turbo: true
+      use_turbo_attention: true
+      use_turbo_grouped_mlp: true
 
       # fp8: e4m3
       # fp8_recipe: blockwise # tensorwise, blockwise

diff --git a/examples/megatron/configs/MI300X/deepseek_v3-pretrain.yaml b/examples/megatron/configs/MI300X/deepseek_v3-pretrain.yaml
index 242b07b09..38befa2ea 100644
--- a/examples/megatron/configs/MI300X/deepseek_v3-pretrain.yaml
+++ b/examples/megatron/configs/MI300X/deepseek_v3-pretrain.yaml
@@ -82,9 +82,9 @@ modules:
       eval_iters: 0
 
       # Turbo
-      enable_primus_turbo: false
-      use_turbo_attention: false
-      use_turbo_grouped_mlp: false
+      enable_primus_turbo: true
+      use_turbo_attention: true
+      use_turbo_grouped_mlp: true
 
       # MoE overlap
       # overlap_moe_expert_parallel_comm: true

From 554be01a0f7ef4ba2db6d238bb98906be402add2 Mon Sep 17 00:00:00 2001
From: vidushi8
Date: Wed, 19 Nov 2025 19:18:58 -0600
Subject: [PATCH 19/32] fix PRIMUS_TURBO_ATTN_V3_ATOMIC_FP32 typo. It will now default to 0

---
 examples/run_pretrain.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh
index 92a0aa712..1ff42ef79 100755
--- a/examples/run_pretrain.sh
+++ b/examples/run_pretrain.sh
@@ -275,7 +275,7 @@ export NVTE_USE_OPTIMIZED_HIPIFIED_CAST_TRANSPOSE=${NVTE_USE_OPTIMIZED_HIPIFIED_
 export NVTE_CK_USES_BWD_V3=${NVTE_CK_USES_BWD_V3:-0}
 
 # Note: Disable fp32 atomics if you find any accuracy issues.
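# (Editor's note, not part of the original patch: a minimal sketch of the bash
# expansion bug this one-character fix addresses. ${VAR:0} is substring
# expansion: it yields the current value of VAR, or empty if VAR is unset, and
# never supplies a default. ${VAR:-0} substitutes 0 whenever VAR is unset or
# empty, which is what the export below intends.
#   unset X; echo "${X:0}"     # prints an empty line
#   unset X; echo "${X:-0}"    # prints 0
# )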
-export PRIMUS_TURBO_ATTN_V3_ATOMIC_FP32=${PRIMUS_TURBO_ATTN_V3_ATOMIC_FP32:0} +export PRIMUS_TURBO_ATTN_V3_ATOMIC_FP32=${PRIMUS_TURBO_ATTN_V3_ATOMIC_FP32:-0} # nvte debug envs export NVTE_DEBUG=0 # 0, 1 From 653f0ec50d688300e697003df2e4a0d871a83e04 Mon Sep 17 00:00:00 2001 From: vidushi8 Date: Tue, 25 Nov 2025 11:49:10 -0600 Subject: [PATCH 20/32] update torch profiler gzip to false --- examples/megatron/configs/MI300X/llama3.1_8B-pretrain.yaml | 4 ++-- primus/configs/modules/megatron/primus_megatron_module.yaml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/megatron/configs/MI300X/llama3.1_8B-pretrain.yaml b/examples/megatron/configs/MI300X/llama3.1_8B-pretrain.yaml index 6d615af16..4d49c5c09 100644 --- a/examples/megatron/configs/MI300X/llama3.1_8B-pretrain.yaml +++ b/examples/megatron/configs/MI300X/llama3.1_8B-pretrain.yaml @@ -68,8 +68,8 @@ modules: # Turbo enable_primus_turbo: true - use_turbo_attention: true - use_turbo_grouped_mlp: true + use_turbo_attention: false + use_turbo_grouped_mlp: false # Cross entropy flags # cross_entropy_fusion_impl: "te" diff --git a/primus/configs/modules/megatron/primus_megatron_module.yaml b/primus/configs/modules/megatron/primus_megatron_module.yaml index a8cd951c1..0ec3a22b0 100644 --- a/primus/configs/modules/megatron/primus_megatron_module.yaml +++ b/primus/configs/modules/megatron/primus_megatron_module.yaml @@ -18,7 +18,7 @@ use_rocm_mem_info_iters: [1,2] disable_profiler_activity_cpu: false torch_profiler_record_shapes: true torch_profiler_with_stack: true -torch_profiler_use_gzip: true +torch_profiler_use_gzip: false # continue/finetune auto_continue_train: false From 343b5feb4473a83407e709e96a133e9d473f069c Mon Sep 17 00:00:00 2001 From: DCCS-4317 Date: Wed, 19 Nov 2025 23:21:19 +0000 Subject: [PATCH 21/32] support turbo groupgemm in titan --- .../MI300X/deepseek_v3_16b-pretrain.yaml | 19 ++-- primus/backends/torchtitan/models/moe/moe.py | 59 +++++++++++++ .../config_extension.py | 2 + .../modules/torchtitan/pre_trainer.yaml | 2 + .../modules/trainer/torchtitan/pre_trainer.py | 31 +++++++ run_titan_dsv2_lite.sh | 88 +++++++++++++++++++ 6 files changed, 195 insertions(+), 6 deletions(-) create mode 100644 primus/backends/torchtitan/models/moe/moe.py create mode 100755 run_titan_dsv2_lite.sh diff --git a/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml b/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml index 95e1a80cf..9837c5607 100644 --- a/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml +++ b/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml @@ -20,7 +20,7 @@ modules: save_memory_snapshot_folder: "memory_snapshot" metrics: - log_freq: 10 + log_freq: 1 disable_color_printing: false enable_tensorboard: false save_tb_folder: "tb" @@ -41,8 +41,8 @@ modules: local_batch_size: 4 seq_len: 4096 max_norm: 1.0 # grad norm clipping - steps: 1000 - dataset: "c4" # supported datasets: c4_test (2K), c4 (177M) + steps: 10 + dataset: "c4_test" # supported datasets: c4_test (2K), c4 (177M) parallelism: data_parallel_replicate_degree: 1 @@ -64,13 +64,20 @@ modules: async_mode: "disabled" # ["disabled", "async", "async_with_pinned_mem"] activation_checkpoint: - mode: "none" # ["none", "selective", "full"] + mode: "selective" # ["none", "selective", "full"] selective_ac_option: "op" # 'int' = ac every positive int layer or 'op', ac based on ops policy compile: - enable: true - components: ["loss"] # ["model", "loss"] + enable: false + components: ["model", "loss"] # 
["model", "loss"] + primus_turbo: + enable_primus_turbo: false + use_turbo_mx_linear: false + enable_attention_float8: false + use_turbo_grouped_mm: true + use_moe_fp8: true + # quantize: # linear: # float8: diff --git a/primus/backends/torchtitan/models/moe/moe.py b/primus/backends/torchtitan/models/moe/moe.py new file mode 100644 index 000000000..e661e997e --- /dev/null +++ b/primus/backends/torchtitan/models/moe/moe.py @@ -0,0 +1,59 @@ + + +import torch +import torch.nn.functional as F +import primus_turbo.pytorch as turbo +from primus_turbo.pytorch.core.float8 import ( + Float8QuantConfig, + Format, + ScalingGranularity, +) + + +def _run_experts_grouped_mm( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + use_deepep: bool = False, + use_fp8: bool = True, +) -> torch.Tensor: + assert x.dim() == 2 + num_tokens_per_expert = num_tokens_per_expert.to(torch.int64).to(x.device) + if use_fp8: + fp8_cfg = Float8QuantConfig( + format=Format.E4M3, + granularity=ScalingGranularity.TENSORWISE, # or ROWWISE ,TENSORWISE + ) + + h = F.silu( + turbo.ops.grouped_gemm_fp8( + x.bfloat16(), w1.bfloat16(), group_lens=num_tokens_per_expert, trans_b=True, + config=fp8_cfg + ) + ) + h = h * turbo.ops.grouped_gemm_fp8( + x.bfloat16(), w3.bfloat16(), group_lens=num_tokens_per_expert, trans_b=True, + config=fp8_cfg + ) + + out = turbo.ops.grouped_gemm_fp8( + h, w2.bfloat16(), group_lens=num_tokens_per_expert, trans_b=True, + config=fp8_cfg + ).type_as(x) + else: + h = F.silu( + turbo.ops.grouped_gemm( + x.bfloat16(), w1.bfloat16(), group_lens=num_tokens_per_expert, trans_b=True + ) + ) + h = h * turbo.ops.grouped_gemm( + x.bfloat16(), w3.bfloat16(), group_lens=num_tokens_per_expert, trans_b=True + ) + + out = turbo.ops.grouped_gemm( + h, w2.bfloat16(), group_lens=num_tokens_per_expert, trans_b=True + ).type_as(x) + + return out diff --git a/primus/backends/torchtitan/primus_turbo_extensions/config_extension.py b/primus/backends/torchtitan/primus_turbo_extensions/config_extension.py index c554744e0..30dce288c 100644 --- a/primus/backends/torchtitan/primus_turbo_extensions/config_extension.py +++ b/primus/backends/torchtitan/primus_turbo_extensions/config_extension.py @@ -22,6 +22,8 @@ class PrimusTurboConfig: use_turbo_attention: bool = False use_turbo_async_tp: bool = False use_turbo_mx_linear: bool = False + use_turbo_grouped_mm: bool = False + use_moe_fp8: bool = True enable_embedding_autocast: bool = True # float8_config: PrimusTurboFloat8Config = field(default_factory=PrimusTurboFloat8Config) diff --git a/primus/configs/modules/torchtitan/pre_trainer.yaml b/primus/configs/modules/torchtitan/pre_trainer.yaml index 873ced2df..bf849bc6d 100644 --- a/primus/configs/modules/torchtitan/pre_trainer.yaml +++ b/primus/configs/modules/torchtitan/pre_trainer.yaml @@ -156,4 +156,6 @@ primus_turbo: use_turbo_attention: true use_turbo_async_tp: true use_turbo_mx_linear: true + use_turbo_grouped_mm: false + use_moe_fp8: true enable_embedding_autocast: true diff --git a/primus/modules/trainer/torchtitan/pre_trainer.py b/primus/modules/trainer/torchtitan/pre_trainer.py index 4e29cf98d..c7dfefc1d 100644 --- a/primus/modules/trainer/torchtitan/pre_trainer.py +++ b/primus/modules/trainer/torchtitan/pre_trainer.py @@ -67,6 +67,11 @@ def __init__(self, *args, **kwargs): self.JobConfigClass = JobConfig self.titan_config = self.build_job_config(cfg_dict, self.JobConfigClass) + + # patch torchtitan moe + # background: we use turbo grouped mm for moe, so we need to 
patch the torchtitan moe + self.patch_torchtitan_moe() + self.log_config(self.titan_config) self.trainer = None @@ -93,6 +98,32 @@ def patch_torchtitan_logger(self): titan_logging.logger = primus_logger titan_logging.init_logger = lambda: None + + def patch_torchtitan_moe(self): + if not self.titan_config.primus_turbo.use_turbo_grouped_mm: + return + from primus.core.utils.logger import _logger as primus_logger + primus_logger.info("Monkey patch torchtitan moe...") + try: + import functools + import torchtitan.models.moe.moe + from primus.backends.torchtitan.models.moe.moe import _run_experts_grouped_mm + + # Get MoE FP8 configuration and create a partial function + use_moe_fp8 = self.titan_config.primus_turbo.use_moe_fp8 + primus_logger.info(f"Set MoE FP8 mode: {use_moe_fp8}") + + # Patch the grouped_mm function with use_fp8 parameter pre-set + torchtitan.models.moe.moe._run_experts_grouped_mm = functools.partial( + _run_experts_grouped_mm, use_fp8=use_moe_fp8 + ) + primus_logger.info("Successfully patched torchtitan moe with turbo grouped_mm") + except ImportError as e: + raise ImportError( + f"Failed to import primus_turbo for MoE grouped_mm patch. " + f"Please ensure primus_turbo is installed or set use_turbo_grouped_mm=False. " + f"Original error: {e}" + ) from e def patch_torch_dcp_consolidate(self): """ diff --git a/run_titan_dsv2_lite.sh b/run_titan_dsv2_lite.sh new file mode 100755 index 000000000..75421300a --- /dev/null +++ b/run_titan_dsv2_lite.sh @@ -0,0 +1,88 @@ +#!/bin/bash + +export HF_TOKEN=${HF_TOKEN:-"your_hf_token"} +export USE_ROCM_AITER_ROPE_BACKEND=0 +export CLEAN_DOCKER_CONTAINER=0 + +export USING_AINIC=0 +export NCCL_IB_HCA="bnxt_re0,bnxt_re1,bnxt_re2,bnxt_re3,bnxt_re4,bnxt_re5,bnxt_re7,bnxt_re8" +# export AINIC_LIB="/apps/gpuperf/ainic-driver-20251007/lib/" +export ANP_HOME_DIR="/shared/apps/ubuntu/rocm-7.0.1/amd-anp-1.1.0-5" +export RCCL_HOME_DIR="/shared/apps/ubuntu/rocm-7.0.1/rccl-drop-2025-08" +export NCCL_SOCKET_IFNAME="enp49s0f0np0" +export GLOO_SOCKET_IFNAME="enp49s0f0np0" + +export DOCKER_IMAGE="docker.io/rocm/pyt-megatron-lm-jax-nightly-private:primus_rocm7.1_20251117" +export CPUS_PER_TASK=128 +export HSA_NO_SCRATCH_RECLAIM=0 +export NVTE_CK_USES_BWD_V3=1 + +export EXP="examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml" +mkdir -p data +# the real number of nodes to run +export NNODES=1 +MBS=8 +TP=1 +ETP=1 +GBS=$(($NNODES * 512)) +SEQ_LENGTH=4096 +PP=1 +EP=8 +CP=1 +VPP=1 +OPTIMIZER=adam +RECOMPUTE_LAYERS=0 +RECOMPUTE_ID_START=0 +BALANCE=True +LEGACY_GG=False +FP8=False + +CONFIG="titain-DSv2-Lite-FP8-$FP8.GBS$GBS.PP$PP.EP$EP.CP$CP.VPP$VPP.TOPK$TOPK.rc-$RECOMPUTE_LAYERS.rcids-$RECOMPUTE_ID_START.nodes$NNODES.$OPTIMIZER.BALANCE-$BALANCE-legacygg-$LEGACY_GG-noturboattn-noturbogg" +echo "config: $CONFIG" + +if [ $VPP -gt 1 ]; then + export VPP_CONFIG="--num_virtual_stages_per_pipeline_rank $VPP" +fi + +if [ "$FP8" = "True" ]; then + export FP8_CONFIG="--fp8 hybrid" +fi + +export PRIMUS_TEAM="date-new-$(date +%Y%m%d)" +export PRIMUS_USER=john +export PRIMUS_EXP_NAME=$CONFIG + + +LOG_DIR=./output/$PRIMUS_TEAM/$PRIMUS_USER/$PRIMUS_EXP_NAME/ +export DUMP_PP_DIR=$LOG_DIR/pp_dump/ +mkdir -p $LOG_DIR +LOG_FILE=$LOG_DIR/training.log +echo $LOG_FILE + +EXPORT_CONFIG=$LOG_DIR/config.yaml + +bash ./examples/run_slurm_pretrain.sh 2>&1 | tee $LOG_FILE + +# bash ./examples/run_slurm_pretrain.sh --micro_batch_size $MBS \ +# --global_batch_size $GBS \ +# --tensor_model_parallel_size $TP \ +# --expert_tensor_parallel_size $ETP \ +# 
--pipeline_model_parallel_size $PP \
+# --seq_length $SEQ_LENGTH \
+# --expert_model_parallel_size $EP \
+# --context_parallel_size $CP \
+# --moe_router_force_load_balancing $BALANCE \
+# --optimizer $OPTIMIZER \
+# --cp_comm_type a2a \
+# --recompute_num_layers $RECOMPUTE_LAYERS \
+# --moe_use_legacy_grouped_gemm $LEGACY_GG \
+# ${VPP_CONFIG} \
+# ${FP8_CONFIG} \
+# --profile True \
+# --disable_profiler_activity_cpu True \
+# --use_pytorch_profiler True \
+# --profile_step_start 5 \
+# --profile_step_end 6 \
+# --train_iters 10 2>&1 | tee $LOG_FILE
+
+

From 04aa8ac4027a522303c6d174b00976a5d5d690fc Mon Sep 17 00:00:00 2001
From: liyingli
Date: Thu, 20 Nov 2025 09:38:25 +0000
Subject: [PATCH 22/32] add turbo fp8 gemm and attn via Converter

---
 .../MI300X/deepseek_v3_16b-pretrain.yaml      |  3 +-
 .../components/quantization/float8.py         | 77 +++++++++++++++++++
 .../torchtitan/components/quantization/mx.py  |  4 +-
 .../models/deepseek_v3/model/__init__.py      |  0
 .../models/deepseek_v3/model/model.py         | 75 ++++++++++++++++++
 .../config_extension.py                       |  1 +
 .../models/torchtitan/deepseek_v3_16b.yaml    |  1 +
 .../modules/torchtitan/pre_trainer.yaml       |  1 +
 .../modules/trainer/torchtitan/pre_trainer.py | 24 +++++-
 9 files changed, 182 insertions(+), 4 deletions(-)
 create mode 100644 primus/backends/torchtitan/components/quantization/float8.py
 create mode 100644 primus/backends/torchtitan/models/deepseek_v3/model/__init__.py
 create mode 100644 primus/backends/torchtitan/models/deepseek_v3/model/model.py

diff --git a/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml b/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml
index 9837c5607..14b3bb8c6 100644
--- a/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml
+++ b/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml
@@ -72,8 +72,9 @@ modules:
 
   primus_turbo:
-    enable_primus_turbo: false
+    enable_primus_turbo: true
     use_turbo_mx_linear: false
+    use_turbo_float8_linear: true
     enable_attention_float8: false
     use_turbo_grouped_mm: true
     use_moe_fp8: true

diff --git a/primus/backends/torchtitan/components/quantization/float8.py b/primus/backends/torchtitan/components/quantization/float8.py
new file mode 100644
index 000000000..e11e848ed
--- /dev/null
+++ b/primus/backends/torchtitan/components/quantization/float8.py
@@ -0,0 +1,77 @@
+###############################################################################
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+#
+# See LICENSE for license information.
+###############################################################################
+
+import torch
+import torch.nn as nn
+from primus_turbo.pytorch.core.float8 import Float8QuantConfig, ScalingGranularity
+from primus_turbo.pytorch.modules.linear_fp8 import Float8Linear
+from torchtitan.config.job_config import JobConfig
+from torchtitan.distributed import ParallelDims
+from torchtitan.protocols.model_converter import (
+    ModelConverter,
+    register_model_converter,
+)
+from torchtitan.tools.logging import logger
+
+
+import torch.nn as nn
+
+
+def module_filter_fn(mod: nn.Module, fqn: str, filter_fqns: list[str]) -> bool:
+    """
+    Filter function to determine which modules should be converted.
+    For both Float8 and MXFP8, we only convert Linear modules
+    with dimensions divisible by 128 and not matching any filtered FQNs.
+    """
+    if not isinstance(mod, nn.Linear):
+        return False
+
+    # All dims must be divisible by 128 here (stricter than the divisible-by-16
+    # float8 tensorcore hardware requirement).
+    dims_multiples_of_128 = (
+        mod.weight.shape[0] % 128 == 0 and mod.weight.shape[1] % 128 == 0
+    )
+
+    # If the fqn matches any filtered fqn, then we should not convert this module.
+    is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns)
+
+    return dims_multiples_of_128 and not is_filtered_fqn
+
+
+def replace_turbo_fp8linear_modules(model: nn.Module, config: Float8QuantConfig):
+    filter_fqns = ["gate", "output"]
+    for name, module in model.named_children():
+        if isinstance(module, torch.nn.Linear) and not isinstance(module, Float8Linear):
+            if module_filter_fn(module, name, filter_fqns):
+                fp8_linear = Float8Linear.from_float(module, config)
+                logger.info(f"module {name} shape {module.weight.shape}, replaced to FP8Linear")
+                setattr(model, name, fp8_linear)
+            else:
+                logger.info(f"module {name} cannot be replaced to FP8Linear")
+        else:
+            replace_turbo_fp8linear_modules(module, config)
+
+
+class PrimusTubroFP8Converter(ModelConverter):
+    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+        self.enabled = True
+        self.config = Float8QuantConfig(granularity=ScalingGranularity.TENSORWISE)
+
+    def convert(self, model: nn.Module):
+        if not self.enabled:
+            return
+
+        replace_turbo_fp8linear_modules(model, self.config)
+
+        logger.info("Swapped to FP8Linear layers")
+
+    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
+        """
+        FP8 doesn't require any post-optimizer hooks at the moment
+        """
+        return
+
+
+register_model_converter(PrimusTubroFP8Converter, "primus_turbo_fp8")

diff --git a/primus/backends/torchtitan/components/quantization/mx.py b/primus/backends/torchtitan/components/quantization/mx.py
index cf6a83171..705bc47ea 100644
--- a/primus/backends/torchtitan/components/quantization/mx.py
+++ b/primus/backends/torchtitan/components/quantization/mx.py
@@ -7,7 +7,7 @@
 import torch
 import torch.nn as nn
 from primus_turbo.pytorch.core.float8 import Float8QuantConfig, ScalingGranularity
-from primus_turbo.pytorch.modules import Float8Linear
+from primus_turbo.pytorch.modules.linear_fp8 import Float8Linear
 from torchtitan.config.job_config import JobConfig
 from torchtitan.distributed import ParallelDims
 from torchtitan.protocols.model_converter import (
@@ -31,7 +31,7 @@ def replace_turbo_mxlinear_modules(model: nn.Module, config: Float8QuantConfig):
 class PrimusTubroMXConverter(ModelConverter):
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = True
-        self.config = Float8QuantConfig(ScalingGranularity.BLOCKWISE, block_size=SCALING_BLOCK_SIZE)
+        self.config = Float8QuantConfig(granularity=ScalingGranularity.BLOCKWISE, block_size=SCALING_BLOCK_SIZE)
 
     def convert(self, model: nn.Module):
         if not self.enabled:

diff --git a/primus/backends/torchtitan/models/deepseek_v3/model/__init__.py b/primus/backends/torchtitan/models/deepseek_v3/model/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/primus/backends/torchtitan/models/deepseek_v3/model/model.py b/primus/backends/torchtitan/models/deepseek_v3/model/model.py
new file mode 100644
index 000000000..0e9e6ac16
--- /dev/null
+++ b/primus/backends/torchtitan/models/deepseek_v3/model/model.py
@@ -0,0 +1,75 @@
+###############################################################################
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+#
+# See LICENSE for license information.
+############################################################################### + +import torch +from torch.nn.attention.flex_attention import BlockMask +from torchtitan.models.deepseek_v3.model.model import Attention as TTAttention +from torchtitan.models.deepseek_v3.model.model import apply_rotary_emb + +AttentionMasksType = dict[str, BlockMask] | BlockMask + + +class Attention(TTAttention): + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + """ + Forward pass for the Multi-Head Latent Attention (MLA) Layer. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + bsz, seqlen, _ = x.size() + + # Query projection + if self.q_lora_rank == 0: + q = self.wq(x) # (bsz, seqlen, n_heads * qk_head_dim) + else: + q = self.wq_a(x) + q = self.wq_b(self.q_norm(q)) + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of q and kv as TP may have sharded them after + # the above linear ops. + q = q.view(bsz, seqlen, -1, self.qk_head_dim) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + q_pe = apply_rotary_emb(q_pe, freqs_cis) + q = torch.cat([q_nope, q_pe], dim=-1) # (bsz, seqlen, n_heads, qk_head_dim) + + # Key-value projection + kv = self.wkv_a(x) # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim) + kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pe = apply_rotary_emb( + k_pe.unsqueeze(2), freqs_cis + ) # (bsz, seqlen, 1, qk_rope_head_dim) + + kv = self.wkv_b( + self.kv_norm(kv) + ) # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim)) + kv = kv.view(bsz, seqlen, -1, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat( + [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1 + ) # (bsz, seqlen, n_heads, qk_head_dim) + + # q = q.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) + # k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) + # v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim) + + + output = self.inner_attention(q, k, v) + + output = output.view(bsz, seqlen, -1) # (bsz, seqlen, n_heads * v_head_dim) + return self.wo(output) # (bsz, seqlen, dim) diff --git a/primus/backends/torchtitan/primus_turbo_extensions/config_extension.py b/primus/backends/torchtitan/primus_turbo_extensions/config_extension.py index 30dce288c..a326adb8c 100644 --- a/primus/backends/torchtitan/primus_turbo_extensions/config_extension.py +++ b/primus/backends/torchtitan/primus_turbo_extensions/config_extension.py @@ -22,6 +22,7 @@ class PrimusTurboConfig: use_turbo_attention: bool = False use_turbo_async_tp: bool = False use_turbo_mx_linear: bool = False + use_turbo_float8_linear: bool = False use_turbo_grouped_mm: bool = False use_moe_fp8: bool = True enable_embedding_autocast: bool = True diff --git a/primus/configs/models/torchtitan/deepseek_v3_16b.yaml b/primus/configs/models/torchtitan/deepseek_v3_16b.yaml index da6806120..d02d921a5 100644 --- a/primus/configs/models/torchtitan/deepseek_v3_16b.yaml +++ b/primus/configs/models/torchtitan/deepseek_v3_16b.yaml @@ -8,4 +8,5 @@ model: name: "deepseek_v3" flavor: "16B" hf_assets_path: "deepseek-ai/deepseek-moe-16b-base" + converters: ["turbo_fp8_linear"] # converters: 
["float8"] diff --git a/primus/configs/modules/torchtitan/pre_trainer.yaml b/primus/configs/modules/torchtitan/pre_trainer.yaml index bf849bc6d..2fd50fe29 100644 --- a/primus/configs/modules/torchtitan/pre_trainer.yaml +++ b/primus/configs/modules/torchtitan/pre_trainer.yaml @@ -156,6 +156,7 @@ primus_turbo: use_turbo_attention: true use_turbo_async_tp: true use_turbo_mx_linear: true + use_turbo_float8_linear: true use_turbo_grouped_mm: false use_moe_fp8: true enable_embedding_autocast: true diff --git a/primus/modules/trainer/torchtitan/pre_trainer.py b/primus/modules/trainer/torchtitan/pre_trainer.py index c7dfefc1d..a09999cdf 100644 --- a/primus/modules/trainer/torchtitan/pre_trainer.py +++ b/primus/modules/trainer/torchtitan/pre_trainer.py @@ -285,6 +285,13 @@ def enable_primus_turbo_extension(self): from primus.backends.torchtitan.models.llama3.model.model import Attention torchtitan.models.llama3.model.model.Attention = Attention + + # ******* deepseek_v3 Attention Model ******* + import torchtitan.models.deepseek_v3.model.model + + from primus.backends.torchtitan.models.deepseek_v3.model.model import Attention + + torchtitan.models.deepseek_v3.model.model.Attention = Attention logger.warning(f"TorchtitanPretrainTrainer: Patch Turbo Attention") if self.titan_config.primus_turbo.use_turbo_mx_linear: @@ -299,9 +306,24 @@ def enable_primus_turbo_extension(self): ) _registry_model_converter_cls["mx"] = PrimusTubroMXConverter - torchtitan.components.quantization.mx.MXConverter = PrimusTubroMXConverter + torchtitan.components.quantization.mx.MXLinearConverter = PrimusTubroMXConverter logger.warning(f"TorchtitanPretrainTrainer: Patch Turbo MXLinear") + if self.titan_config.primus_turbo.use_turbo_float8_linear: + # ******* FP8Linear ******* + import torchtitan.components.quantization.float8 + from torchtitan.protocols.model_converter import ( + _registry_model_converter_cls, + ) + + from primus.backends.torchtitan.components.quantization.float8 import ( + PrimusTubroFP8Converter, + ) + + _registry_model_converter_cls["turbo_fp8_linear"] = PrimusTubroFP8Converter + torchtitan.components.quantization.float8.Float8LinearConverter = PrimusTubroFP8Converter + logger.warning(f"TorchtitanPretrainTrainer: Patch Turbo FP8Linear") + if self.titan_config.primus_turbo.use_turbo_async_tp: # ******* Async TP ******* self.patch_torch_async_tp() From 80f247a873470a405d9d3d5f1be742e5c47c8aff Mon Sep 17 00:00:00 2001 From: liyingli Date: Fri, 21 Nov 2025 12:57:34 +0000 Subject: [PATCH 23/32] update deepseek_v3 config and load balance config --- .../MI300X/deepseek_v3_16b-pretrain.yaml | 11 +++--- .../MI300X/deepseek_v3_671b-pretrain.yaml | 23 ++++++++---- .../components/quantization/float8.py | 2 +- .../models/torchtitan/deepseek_v3_671b.yaml | 1 + .../modules/torchtitan/pre_trainer.yaml | 1 + run_titan_dsv2_lite.sh | 37 ++++--------------- 6 files changed, 33 insertions(+), 42 deletions(-) diff --git a/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml b/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml index 14b3bb8c6..75e1a370d 100644 --- a/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml +++ b/examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml @@ -13,7 +13,7 @@ modules: model: deepseek_v3_16b.yaml overrides: profiling: - enable_profiling: false + enable_profiling: true save_traces_folder: "profile_trace" profile_freq: 10 enable_memory_snapshot: false @@ -38,10 +38,11 @@ modules: min_lr_factor: 0.1 training: + debug_moe_force_load_balance: true 
local_batch_size: 4 seq_len: 4096 max_norm: 1.0 # grad norm clipping - steps: 10 + steps: 15 dataset: "c4_test" # supported datasets: c4_test (2K), c4 (177M) parallelism: @@ -64,11 +65,11 @@ modules: async_mode: "disabled" # ["disabled", "async", "async_with_pinned_mem"] activation_checkpoint: - mode: "selective" # ["none", "selective", "full"] + mode: "none" # ["none", "selective", "full"] selective_ac_option: "op" # 'int' = ac every positive int layer or 'op', ac based on ops policy compile: - enable: false + enable: true components: ["model", "loss"] # ["model", "loss"] primus_turbo: @@ -77,7 +78,7 @@ modules: use_turbo_float8_linear: true enable_attention_float8: false use_turbo_grouped_mm: true - use_moe_fp8: true + use_moe_fp8: false # quantize: # linear: diff --git a/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml b/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml index b340538be..1d183f3ee 100644 --- a/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml +++ b/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml @@ -13,14 +13,14 @@ modules: model: deepseek_v3_671b.yaml overrides: profiling: - enable_profiling: false + enable_profiling: true save_traces_folder: "profile_trace" profile_freq: 10 enable_memory_snapshot: false save_memory_snapshot_folder: "memory_snapshot" metrics: - log_freq: 10 + log_freq: 1 disable_color_printing: false enable_tensorboard: false save_tb_folder: "tb" @@ -38,11 +38,12 @@ modules: min_lr_factor: 0.1 training: - local_batch_size: 4 + debug_moe_force_load_balance: true + local_batch_size: 14 seq_len: 4096 max_norm: 1.0 # grad norm clipping - steps: 1000 - dataset: "c4" # supported datasets: c4_test (2K), c4 (177M) + steps: 15 + dataset: "c4_test" # supported datasets: c4_test (2K), c4 (177M) parallelism: data_parallel_replicate_degree: 1 @@ -52,7 +53,7 @@ modules: enable_async_tensor_parallel: false pipeline_parallel_degree: 1 pipeline_parallel_schedule: "Interleaved1F1B" - expert_parallel_degree: 1 + expert_parallel_degree: 8 expert_tensor_parallel_degree: 1 checkpoint: @@ -69,7 +70,15 @@ modules: compile: enable: true - components: ["loss"] # ["model", "loss"] + components: ["model", "loss"] # ["model", "loss"] + + primus_turbo: + enable_primus_turbo: true + use_turbo_mx_linear: false + use_turbo_float8_linear: false + enable_attention_float8: true + use_turbo_grouped_mm: true + use_moe_fp8: false # quantize: # linear: diff --git a/primus/backends/torchtitan/components/quantization/float8.py b/primus/backends/torchtitan/components/quantization/float8.py index e11e848ed..28defbaa6 100644 --- a/primus/backends/torchtitan/components/quantization/float8.py +++ b/primus/backends/torchtitan/components/quantization/float8.py @@ -49,7 +49,7 @@ def replace_turbo_fp8linear_modules(model: nn.Module, config: Float8QuantConfig) logger.info(f"module {name} shape {module.weight.shape}, replaced to FP8Linear") setattr(model, name, fp8_linear) else: - logger.info(f"module {name} cannot be replaced to FP8Linear") + logger.info(f"module {name} shape {module.weight.shape}, cannot be replaced to FP8Linear") else: replace_turbo_fp8linear_modules(module, config) diff --git a/primus/configs/models/torchtitan/deepseek_v3_671b.yaml b/primus/configs/models/torchtitan/deepseek_v3_671b.yaml index e96eeb30e..a5fbe2125 100644 --- a/primus/configs/models/torchtitan/deepseek_v3_671b.yaml +++ b/primus/configs/models/torchtitan/deepseek_v3_671b.yaml @@ -7,4 +7,5 @@ model: name: "deepseek_v3" flavor: "671B" hf_assets_path: 
"deepseek-ai/DeepSeek-V3.1-Base" + # converters: ["turbo_fp8_linear"] # converters: ["float8"] diff --git a/primus/configs/modules/torchtitan/pre_trainer.yaml b/primus/configs/modules/torchtitan/pre_trainer.yaml index 2fd50fe29..5c16f4d5a 100644 --- a/primus/configs/modules/torchtitan/pre_trainer.yaml +++ b/primus/configs/modules/torchtitan/pre_trainer.yaml @@ -35,6 +35,7 @@ profiling: training: mock_data: true + debug_moe_force_load_balance: false dataset: c4 dataset_path: null deterministic: false diff --git a/run_titan_dsv2_lite.sh b/run_titan_dsv2_lite.sh index 75421300a..417e32b6a 100755 --- a/run_titan_dsv2_lite.sh +++ b/run_titan_dsv2_lite.sh @@ -9,15 +9,15 @@ export NCCL_IB_HCA="bnxt_re0,bnxt_re1,bnxt_re2,bnxt_re3,bnxt_re4,bnxt_re5,bnxt_r # export AINIC_LIB="/apps/gpuperf/ainic-driver-20251007/lib/" export ANP_HOME_DIR="/shared/apps/ubuntu/rocm-7.0.1/amd-anp-1.1.0-5" export RCCL_HOME_DIR="/shared/apps/ubuntu/rocm-7.0.1/rccl-drop-2025-08" -export NCCL_SOCKET_IFNAME="enp49s0f0np0" -export GLOO_SOCKET_IFNAME="enp49s0f0np0" +export NCCL_SOCKET_IFNAME="lo" +export GLOO_SOCKET_IFNAME="lo" export DOCKER_IMAGE="docker.io/rocm/pyt-megatron-lm-jax-nightly-private:primus_rocm7.1_20251117" export CPUS_PER_TASK=128 -export HSA_NO_SCRATCH_RECLAIM=0 +export HSA_NO_SCRATCH_RECLAIM=1 export NVTE_CK_USES_BWD_V3=1 -export EXP="examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml" +# export EXP="examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml" mkdir -p data # the real number of nodes to run export NNODES=1 @@ -49,7 +49,7 @@ if [ "$FP8" = "True" ]; then fi export PRIMUS_TEAM="date-new-$(date +%Y%m%d)" -export PRIMUS_USER=john +export PRIMUS_USER=liying export PRIMUS_EXP_NAME=$CONFIG @@ -61,28 +61,7 @@ echo $LOG_FILE EXPORT_CONFIG=$LOG_DIR/config.yaml -bash ./examples/run_slurm_pretrain.sh 2>&1 | tee $LOG_FILE - -# bash ./examples/run_slurm_pretrain.sh --micro_batch_size $MBS \ -# --global_batch_size $GBS \ -# --tensor_model_parallel_size $TP \ -# --expert_tensor_parallel_size $ETP \ -# --pipeline_model_parallel_size $PP \ -# --seq_length $SEQ_LENGTH \ -# --expert_model_parallel_size $EP \ -# --context_parallel_size $CP \ -# --moe_router_force_load_balancing $BALANCE \ -# --optimizer $OPTIMIZER \ -# --cp_comm_type a2a \ -# --recompute_num_layers $RECOMPUTE_LAYERS \ -# --moe_use_legacy_grouped_gemm $LEGACY_GG \ -# ${VPP_CONFIG} \ -# ${FP8_CONFIG} \ -# --profile True \ -# --disable_profiler_activity_cpu True \ -# --use_pytorch_profiler True \ -# --profile_step_start 5 \ -# --profile_step_end 6 \ -# --train_iters 10 2>&1 | tee $LOG_FILE - +# bash ./examples/run_pretrain.sh 2>&1 | tee $LOG_FILE +export EXP="examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml" +bash ./examples/run_pretrain.sh --model.n_layers 4 --model.n_dense_layers 1 2>&1 | tee $LOG_FILE From d48cd84af3e3210cb2d09c7bfb597ffa01f5939c Mon Sep 17 00:00:00 2001 From: liyingli Date: Fri, 21 Nov 2025 14:54:27 +0000 Subject: [PATCH 24/32] add classic attn but with issues --- .../MI300X/deepseek_v3_671b-pretrain.yaml | 5 +- .../torchtitan/models/deepseek_v3/__init__.py | 101 ++++++++++++++++++ .../models/deepseek_v3/model/args.py | 20 ++++ .../models/deepseek_v3/model/model.py | 66 ++++++++++-- .../torchtitan/models/llama4/__init__.py | 0 .../models/llama4/model/__init__.py | 0 .../torchtitan/models/llama4/model/model.py | 42 ++++++++ .../config_extension.py | 1 + .../modules/torchtitan/pre_trainer.yaml | 1 + .../modules/trainer/torchtitan/pre_trainer.py | 57 ++++++++-- 10 files changed, 274 
insertions(+), 19 deletions(-) create mode 100644 primus/backends/torchtitan/models/deepseek_v3/__init__.py create mode 100644 primus/backends/torchtitan/models/deepseek_v3/model/args.py create mode 100644 primus/backends/torchtitan/models/llama4/__init__.py create mode 100644 primus/backends/torchtitan/models/llama4/model/__init__.py create mode 100644 primus/backends/torchtitan/models/llama4/model/model.py diff --git a/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml b/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml index 1d183f3ee..2025a01ed 100644 --- a/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml +++ b/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml @@ -69,14 +69,15 @@ modules: selective_ac_option: "op" # 'int' = ac every positive int layer or 'op', ac based on ops policy compile: - enable: true + enable: false components: ["model", "loss"] # ["model", "loss"] primus_turbo: enable_primus_turbo: true use_turbo_mx_linear: false use_turbo_float8_linear: false - enable_attention_float8: true + enable_attention_float8: false + use_classic_attention: true use_turbo_grouped_mm: true use_moe_fp8: false diff --git a/primus/backends/torchtitan/models/deepseek_v3/__init__.py b/primus/backends/torchtitan/models/deepseek_v3/__init__.py new file mode 100644 index 000000000..095917b9d --- /dev/null +++ b/primus/backends/torchtitan/models/deepseek_v3/__init__.py @@ -0,0 +1,101 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +from torchtitan.models.deepseek_v3 import deepseekv3_args +from torchtitan.models.moe import MoEArgs + +from .model.args import DeepSeekV3ClassicModelArgs + +classic_deepseekv3_args = { + **deepseekv3_args, + "16B": DeepSeekV3ClassicModelArgs( + vocab_size=102400, + dim=2048, + inter_dim=10944, + moe_inter_dim=1408, + n_layers=27, + n_dense_layers=1, + n_heads=16, + moe_args=MoEArgs( + num_experts=64, + num_shared_experts=2, + top_k=6, + score_func="softmax", + route_norm=True, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, + attn_mask_type="block_causal", + n_kv_heads=16, + head_dim=2048 // 128, + ), + "236B": DeepSeekV3ClassicModelArgs( + vocab_size=102400, + dim=5120, + inter_dim=12288, + moe_inter_dim=1536, + n_layers=60, + n_dense_layers=1, + n_heads=128, + moe_args=MoEArgs( + num_experts=160, + num_shared_experts=2, + top_k=6, + score_func="softmax", + route_norm=True, + route_scale=16.0, + score_before_experts=False, + ), + n_expert_groups=8, + n_limited_groups=3, + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + use_flex_attn=True, + attn_mask_type="block_causal", + q_head=40, + n_kv_heads=8, + head_dim=128, + ), + "671B": DeepSeekV3ClassicModelArgs( + vocab_size=129280, + dim=7168, + inter_dim=18432, + moe_inter_dim=2048, + n_layers=24, + n_dense_layers=0, + n_heads=128, + moe_args=MoEArgs( + num_experts=256, + num_shared_experts=1, + top_k=8, + score_func="sigmoid", + route_norm=True, + route_scale=2.5, + score_before_experts=False, + ), + n_expert_groups=8, + n_limited_groups=4, + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + 
use_flex_attn=True, + attn_mask_type="block_causal", + q_head=40, + n_kv_heads=128, + head_dim=16, + ), +} diff --git a/primus/backends/torchtitan/models/deepseek_v3/model/args.py b/primus/backends/torchtitan/models/deepseek_v3/model/args.py new file mode 100644 index 000000000..526bcefb8 --- /dev/null +++ b/primus/backends/torchtitan/models/deepseek_v3/model/args.py @@ -0,0 +1,20 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + + +from dataclasses import dataclass + +from torchtitan.models.deepseek_v3 import DeepSeekV3ModelArgs + + +# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +@dataclass +class DeepSeekV3ClassicModelArgs(DeepSeekV3ModelArgs): + # Classical Attention + n_heads: int = 128 + q_head: int = 40 + n_kv_heads: int = 8 + head_dim: int = 2048 // 128 diff --git a/primus/backends/torchtitan/models/deepseek_v3/model/model.py b/primus/backends/torchtitan/models/deepseek_v3/model/model.py index 0e9e6ac16..a60998c67 100644 --- a/primus/backends/torchtitan/models/deepseek_v3/model/model.py +++ b/primus/backends/torchtitan/models/deepseek_v3/model/model.py @@ -9,6 +9,14 @@ from torchtitan.models.deepseek_v3.model.model import Attention as TTAttention from torchtitan.models.deepseek_v3.model.model import apply_rotary_emb +# Import the Attention class from llama4 as new option for DeepSeekV3 +from torchtitan.models.llama4.model.model import Attention as Llama4Attention +from torchtitan.models.llama4.model.model import ( + precompute_freqs_cis as llama4_precompute_freqs_cis, +) + +from .args import DeepSeekV3ClassicModelArgs + AttentionMasksType = dict[str, BlockMask] | BlockMask @@ -41,9 +49,7 @@ def forward( # local heads from sizes of q and kv as TP may have sharded them after # the above linear ops. q = q.view(bsz, seqlen, -1, self.qk_head_dim) - q_nope, q_pe = torch.split( - q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_pe = apply_rotary_emb(q_pe, freqs_cis) q = torch.cat([q_nope, q_pe], dim=-1) # (bsz, seqlen, n_heads, qk_head_dim) @@ -51,13 +57,9 @@ def forward( kv = self.wkv_a(x) # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim) kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pe = apply_rotary_emb( - k_pe.unsqueeze(2), freqs_cis - ) # (bsz, seqlen, 1, qk_rope_head_dim) + k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) # (bsz, seqlen, 1, qk_rope_head_dim) - kv = self.wkv_b( - self.kv_norm(kv) - ) # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim)) + kv = self.wkv_b(self.kv_norm(kv)) # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim)) kv = kv.view(bsz, seqlen, -1, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) k = torch.cat( @@ -68,8 +70,52 @@ def forward( # k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) # v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim) - output = self.inner_attention(q, k, v) output = output.view(bsz, seqlen, -1) # (bsz, seqlen, n_heads * v_head_dim) return self.wo(output) # (bsz, seqlen, dim) + + +class MultiHeadAttention(Llama4Attention): + """ + Multi-head attention module for DeepSeekV3, inheriting from llama4's Attention class. 
+ + This class adapts the llama4 Attention class to work with DeepSeekV3ModelArgs + instead of TransformerModelArgs. + """ + + def __init__( + self, + model_args: DeepSeekV3ClassicModelArgs, + use_rope: bool = True, + fixed_block_size: int | None = None, + ): + # Convert DeepSeekV3ModelArgs to a format compatible with llama4's Attention + # Create a mock TransformerModelArgs-like object + class MockTransformerModelArgs: + def __init__(self, deepseek_args: DeepSeekV3ClassicModelArgs): + self.n_heads = deepseek_args.q_head + self.n_kv_heads = deepseek_args.n_kv_heads + self.dim = deepseek_args.dim + self.head_dim = deepseek_args.head_dim + + self.use_flex_attn = deepseek_args.use_flex_attn + + # Initialize the parent class with the mock args + super().__init__( + MockTransformerModelArgs(model_args), use_rope=use_rope, fixed_block_size=fixed_block_size + ) + self.rope_theta = model_args.rope_theta + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + # Always use llama4-style freqs_cis for this attention, regardless of input + seqlen = x.shape[1] + freqs_llama4 = llama4_precompute_freqs_cis(self.head_dim, seqlen, self.rope_theta) + # Ensure freqs are on the same device as activations + freqs_llama4 = freqs_llama4.to(x.device, dtype=x.dtype) + return super().forward(x, freqs_llama4, None) diff --git a/primus/backends/torchtitan/models/llama4/__init__.py b/primus/backends/torchtitan/models/llama4/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/primus/backends/torchtitan/models/llama4/model/__init__.py b/primus/backends/torchtitan/models/llama4/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/primus/backends/torchtitan/models/llama4/model/model.py b/primus/backends/torchtitan/models/llama4/model/model.py new file mode 100644 index 000000000..8233f13ee --- /dev/null +++ b/primus/backends/torchtitan/models/llama4/model/model.py @@ -0,0 +1,42 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +import torch +from torch.nn.attention.flex_attention import BlockMask +from torchtitan.models.llama4.model.model import Attention as TTAttention +from torchtitan.models.llama4.model.model import apply_rotary_emb + +AttentionMasksType = dict[str, BlockMask] | BlockMask + + +class Attention(TTAttention): + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. 
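        # (Editor's note, not part of the original patch: a concrete shape
        # sketch under assumed values. With the classic-attention 16B
        # arguments defined later in this series (q_head=16, n_kv_heads=16,
        # head_dim=128), no tensor parallelism, local_batch_size=4 and
        # seq_len=4096, the three views below each yield tensors of shape
        # (4, 4096, 16, 128). Under TP, the -1 dimension instead resolves to
        # the per-rank head count.)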
+ xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + if self.use_rope: + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + # xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + # xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + output = self.inner_attention(xq, xk, xv) + + output = output.view(bs, seqlen, -1) + return self.wo(output) diff --git a/primus/backends/torchtitan/primus_turbo_extensions/config_extension.py b/primus/backends/torchtitan/primus_turbo_extensions/config_extension.py index a326adb8c..9ac195dd1 100644 --- a/primus/backends/torchtitan/primus_turbo_extensions/config_extension.py +++ b/primus/backends/torchtitan/primus_turbo_extensions/config_extension.py @@ -26,6 +26,7 @@ class PrimusTurboConfig: use_turbo_grouped_mm: bool = False use_moe_fp8: bool = True enable_embedding_autocast: bool = True + use_classic_attention: bool = False # float8_config: PrimusTurboFloat8Config = field(default_factory=PrimusTurboFloat8Config) diff --git a/primus/configs/modules/torchtitan/pre_trainer.yaml b/primus/configs/modules/torchtitan/pre_trainer.yaml index 5c16f4d5a..693ccc13e 100644 --- a/primus/configs/modules/torchtitan/pre_trainer.yaml +++ b/primus/configs/modules/torchtitan/pre_trainer.yaml @@ -155,6 +155,7 @@ primus_turbo: enable_primus_turbo : true enable_attention_float8 : false use_turbo_attention: true + use_classic_attention: false use_turbo_async_tp: true use_turbo_mx_linear: true use_turbo_float8_linear: true diff --git a/primus/modules/trainer/torchtitan/pre_trainer.py b/primus/modules/trainer/torchtitan/pre_trainer.py index a09999cdf..490a14b35 100644 --- a/primus/modules/trainer/torchtitan/pre_trainer.py +++ b/primus/modules/trainer/torchtitan/pre_trainer.py @@ -67,17 +67,19 @@ def __init__(self, *args, **kwargs): self.JobConfigClass = JobConfig self.titan_config = self.build_job_config(cfg_dict, self.JobConfigClass) - + # patch torchtitan moe # background: we use turbo grouped mm for moe, so we need to patch the torchtitan moe self.patch_torchtitan_moe() - + self.log_config(self.titan_config) self.trainer = None if hasattr(self.titan_config, "primus_turbo") and self.titan_config.primus_turbo.enable_primus_turbo: self.enable_primus_turbo_extension() + self.patch_classic_attention() + def setup(self): pass @@ -98,21 +100,26 @@ def patch_torchtitan_logger(self): titan_logging.logger = primus_logger titan_logging.init_logger = lambda: None - + def patch_torchtitan_moe(self): if not self.titan_config.primus_turbo.use_turbo_grouped_mm: return from primus.core.utils.logger import _logger as primus_logger + primus_logger.info("Monkey patch torchtitan moe...") try: import functools + import torchtitan.models.moe.moe - from primus.backends.torchtitan.models.moe.moe import _run_experts_grouped_mm - + + from primus.backends.torchtitan.models.moe.moe import ( + _run_experts_grouped_mm, + ) + # Get MoE FP8 configuration and create a partial function use_moe_fp8 = self.titan_config.primus_turbo.use_moe_fp8 primus_logger.info(f"Set MoE FP8 mode: {use_moe_fp8}") - + # Patch the grouped_mm function with use_fp8 parameter pre-set torchtitan.models.moe.moe._run_experts_grouped_mm = functools.partial( _run_experts_grouped_mm, use_fp8=use_moe_fp8 @@ -125,6 +132,33 @@ def patch_torchtitan_moe(self): f"Original error: {e}" ) from e + def patch_classic_attention(self): + if not 
self.titan_config.primus_turbo.use_classic_attention: + return + + from primus.core.utils.logger import _logger as primus_logger + + primus_logger.info("Monkey patch classic attention...") + + import torchtitan.models.deepseek_v3 + + from primus.backends.torchtitan.models.deepseek_v3 import ( + classic_deepseekv3_args, + ) + from primus.backends.torchtitan.models.deepseek_v3.model.args import ( + DeepSeekV3ClassicModelArgs, + ) + from primus.backends.torchtitan.models.deepseek_v3.model.model import ( + MultiHeadAttention, + ) + + torchtitan.models.deepseek_v3.deepseekv3_args = classic_deepseekv3_args + torchtitan.models.deepseek_v3.DeepSeekV3ModelArgs = DeepSeekV3ClassicModelArgs + + import torchtitan.models.deepseek_v3.model.model + + torchtitan.models.deepseek_v3.model.model.Attention = MultiHeadAttention + def patch_torch_dcp_consolidate(self): """ Monkey patch for torch.distributed.checkpoint._consolidate_hf_safetensors @@ -286,10 +320,19 @@ def enable_primus_turbo_extension(self): torchtitan.models.llama3.model.model.Attention = Attention + # ******* llama4 Attention Model ******* + import torchtitan.models.llama4.model.model + + from primus.backends.torchtitan.models.llama4.model.model import Attention + + torchtitan.models.llama4.model.model.Attention = Attention + # ******* deepseek_v3 Attention Model ******* import torchtitan.models.deepseek_v3.model.model - from primus.backends.torchtitan.models.deepseek_v3.model.model import Attention + from primus.backends.torchtitan.models.deepseek_v3.model.model import ( + Attention, + ) torchtitan.models.deepseek_v3.model.model.Attention = Attention logger.warning(f"TorchtitanPretrainTrainer: Patch Turbo Attention") From 18f37be17c93b96c3206097b6e1288a943a72704 Mon Sep 17 00:00:00 2001 From: liyingli Date: Sat, 22 Nov 2025 09:29:20 +0000 Subject: [PATCH 25/32] update classic attention args for deepseek_v3 --- .../torchtitan/models/deepseek_v3/__init__.py | 13 +++++++------ .../torchtitan/models/deepseek_v3/model/args.py | 8 ++++---- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/primus/backends/torchtitan/models/deepseek_v3/__init__.py b/primus/backends/torchtitan/models/deepseek_v3/__init__.py index 095917b9d..6ec1c8600 100644 --- a/primus/backends/torchtitan/models/deepseek_v3/__init__.py +++ b/primus/backends/torchtitan/models/deepseek_v3/__init__.py @@ -35,8 +35,9 @@ mscale=0.70, use_flex_attn=True, attn_mask_type="block_causal", + q_head=16, n_kv_heads=16, - head_dim=2048 // 128, + head_dim=128, ), "236B": DeepSeekV3ClassicModelArgs( vocab_size=102400, @@ -73,8 +74,8 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=24, - n_dense_layers=0, + n_layers=61, + n_dense_layers=3, n_heads=128, moe_args=MoEArgs( num_experts=256, @@ -94,8 +95,8 @@ v_head_dim=128, use_flex_attn=True, attn_mask_type="block_causal", - q_head=40, - n_kv_heads=128, - head_dim=16, + q_head=56, + n_kv_heads=8, + head_dim=128, ), } diff --git a/primus/backends/torchtitan/models/deepseek_v3/model/args.py b/primus/backends/torchtitan/models/deepseek_v3/model/args.py index 526bcefb8..501189d54 100644 --- a/primus/backends/torchtitan/models/deepseek_v3/model/args.py +++ b/primus/backends/torchtitan/models/deepseek_v3/model/args.py @@ -14,7 +14,7 @@ @dataclass class DeepSeekV3ClassicModelArgs(DeepSeekV3ModelArgs): # Classical Attention - n_heads: int = 128 - q_head: int = 40 - n_kv_heads: int = 8 - head_dim: int = 2048 // 128 + n_heads: int = 16 + q_head: int = 16 + n_kv_heads: int = 16 + head_dim: int = 128 From 
a835bb41d2c6d08bad974b8299ff86ef97b6b72f Mon Sep 17 00:00:00 2001 From: liyingli Date: Sat, 22 Nov 2025 14:24:53 +0000 Subject: [PATCH 26/32] update config --- .../configs/MI300X/deepseek_v3_671b-pretrain.yaml | 6 +++--- primus/configs/models/torchtitan/deepseek_v3_671b.yaml | 2 +- run_titan_dsv2_lite.sh | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml b/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml index 2025a01ed..6b8d2f435 100644 --- a/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml +++ b/examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml @@ -39,7 +39,7 @@ modules: training: debug_moe_force_load_balance: true - local_batch_size: 14 + local_batch_size: 16 seq_len: 4096 max_norm: 1.0 # grad norm clipping steps: 15 @@ -69,13 +69,13 @@ modules: selective_ac_option: "op" # 'int' = ac every positive int layer or 'op', ac based on ops policy compile: - enable: false + enable: true components: ["model", "loss"] # ["model", "loss"] primus_turbo: enable_primus_turbo: true use_turbo_mx_linear: false - use_turbo_float8_linear: false + use_turbo_float8_linear: true enable_attention_float8: false use_classic_attention: true use_turbo_grouped_mm: true diff --git a/primus/configs/models/torchtitan/deepseek_v3_671b.yaml b/primus/configs/models/torchtitan/deepseek_v3_671b.yaml index a5fbe2125..0bdf4cf7f 100644 --- a/primus/configs/models/torchtitan/deepseek_v3_671b.yaml +++ b/primus/configs/models/torchtitan/deepseek_v3_671b.yaml @@ -7,5 +7,5 @@ model: name: "deepseek_v3" flavor: "671B" hf_assets_path: "deepseek-ai/DeepSeek-V3.1-Base" - # converters: ["turbo_fp8_linear"] + converters: ["turbo_fp8_linear"] # converters: ["float8"] diff --git a/run_titan_dsv2_lite.sh b/run_titan_dsv2_lite.sh index 417e32b6a..41965f691 100755 --- a/run_titan_dsv2_lite.sh +++ b/run_titan_dsv2_lite.sh @@ -9,12 +9,12 @@ export NCCL_IB_HCA="bnxt_re0,bnxt_re1,bnxt_re2,bnxt_re3,bnxt_re4,bnxt_re5,bnxt_r # export AINIC_LIB="/apps/gpuperf/ainic-driver-20251007/lib/" export ANP_HOME_DIR="/shared/apps/ubuntu/rocm-7.0.1/amd-anp-1.1.0-5" export RCCL_HOME_DIR="/shared/apps/ubuntu/rocm-7.0.1/rccl-drop-2025-08" -export NCCL_SOCKET_IFNAME="lo" -export GLOO_SOCKET_IFNAME="lo" +# export NCCL_SOCKET_IFNAME="lo" +# export GLOO_SOCKET_IFNAME="lo" export DOCKER_IMAGE="docker.io/rocm/pyt-megatron-lm-jax-nightly-private:primus_rocm7.1_20251117" export CPUS_PER_TASK=128 -export HSA_NO_SCRATCH_RECLAIM=1 +export HSA_NO_SCRATCH_RECLAIM=1 export NVTE_CK_USES_BWD_V3=1 # export EXP="examples/torchtitan/configs/MI300X/deepseek_v3_16b-pretrain.yaml" @@ -64,4 +64,4 @@ EXPORT_CONFIG=$LOG_DIR/config.yaml # bash ./examples/run_pretrain.sh 2>&1 | tee $LOG_FILE export EXP="examples/torchtitan/configs/MI300X/deepseek_v3_671b-pretrain.yaml" -bash ./examples/run_pretrain.sh --model.n_layers 4 --model.n_dense_layers 1 2>&1 | tee $LOG_FILE +bash ./examples/run_pretrain.sh --model.n_layers 4 --model.n_dense_layers 0 2>&1 | tee $LOG_FILE From 48fe12894c1996daa3b92fb4e07861c0e949dfd8 Mon Sep 17 00:00:00 2001 From: JohnQinAMD Date: Tue, 25 Nov 2025 22:53:23 +0000 Subject: [PATCH 27/32] support new turo fp8 api --- .../megatron/core/extensions/primus_turbo.py | 11 +++++- primus/backends/megatron/core/fp8_utils.py | 35 +++++++++++++++---- .../components/quantization/float8.py | 8 ++++- .../torchtitan/components/quantization/mx.py | 8 ++++- primus/backends/torchtitan/models/moe/moe.py | 10 +++++- 5 files changed, 61 insertions(+), 
11 deletions(-) diff --git a/primus/backends/megatron/core/extensions/primus_turbo.py b/primus/backends/megatron/core/extensions/primus_turbo.py index fa5998417..b52859fb9 100644 --- a/primus/backends/megatron/core/extensions/primus_turbo.py +++ b/primus/backends/megatron/core/extensions/primus_turbo.py @@ -29,7 +29,16 @@ from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from megatron.core.utils import get_tensor_model_parallel_group_if_none from megatron.training.global_vars import get_args -from primus_turbo.pytorch.core.float8 import ( + +try: + from primus_turbo.pytorch.core.float8 import ( + Float8QuantConfig, + ScalingGranularity, + ScalingStrategy, + check_fp8_support, + ) +except ImportError: + from primus_turbo.pytorch.core.low_precision import ( Float8QuantConfig, ScalingGranularity, ScalingStrategy, diff --git a/primus/backends/megatron/core/fp8_utils.py b/primus/backends/megatron/core/fp8_utils.py index ff69c67dd..50ad3349e 100644 --- a/primus/backends/megatron/core/fp8_utils.py +++ b/primus/backends/megatron/core/fp8_utils.py @@ -42,14 +42,22 @@ from megatron.core import parallel_state from megatron.core.enums import Fp8Recipe from megatron.core.extensions.transformer_engine import TEDelayedScaling - from primus_turbo.pytorch.core.float8 import ScalingGranularity + try: + from primus_turbo.pytorch.core.float8 import ScalingGranularity + except ImportError: + from primus_turbo.pytorch.core.low_precision import ScalingGranularity + from primus.backends.megatron.core.extensions.primus_turbo import ( PrimusTurboFloat8QuantConfig, ) def te_fp8_format_mapping(te_format): - from primus_turbo.pytorch.core.float8 import Format as TurboFormat + try: + from primus_turbo.pytorch.core.float8 import Format as TurboFormat + except ImportError: + from primus_turbo.pytorch.core.low_precision import Format as TurboFormat + # noqa: F811 from transformer_engine.common.recipe import Format as TEFormat format_mapping = { @@ -194,7 +202,10 @@ def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool elif HAVE_TURBO: from megatron.core import parallel_state from megatron.core.enums import Fp8Recipe - from primus_turbo.pytorch.core.float8 import ScalingGranularity + try: + from primus_turbo.pytorch.core.float8 import ScalingGranularity + except ImportError: + from primus_turbo.pytorch.core.low_precision import ScalingGranularity from primus.backends.megatron.core.extensions.primus_turbo import ( PrimusTurboFloat8QuantConfig, @@ -234,10 +245,20 @@ def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool # fp8 training and this layer_no is in fp8 import primus_turbo - if config.fp8 == "e4m3": - fp8_format = primus_turbo.pytorch.core.float8.Format.E4M3 - elif config.fp8 == "hybrid": - fp8_format = primus_turbo.pytorch.core.float8.Format.HYBRID + # Pick the right Format enum once + try: + # Older API + from primus_turbo.pytorch.core.float8 import Format as FP8Format + except ImportError: + # Newer API + from primus_turbo.pytorch.core.low_precision import Format as FP8Format + + fp8_str = config.fp8.lower() + + if fp8_str == "e4m3": + fp8_format = FP8Format.E4M3 + elif fp8_str == "hybrid": + fp8_format = FP8Format.HYBRID else: raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") diff --git a/primus/backends/torchtitan/components/quantization/float8.py b/primus/backends/torchtitan/components/quantization/float8.py index 28defbaa6..230447620 100644 --- a/primus/backends/torchtitan/components/quantization/float8.py 
+++ b/primus/backends/torchtitan/components/quantization/float8.py @@ -6,7 +6,13 @@ import torch import torch.nn as nn -from primus_turbo.pytorch.core.float8 import Float8QuantConfig, ScalingGranularity + +# Compatibility for different primus_turbo versions +try: + from primus_turbo.pytorch.core.float8 import Float8QuantConfig, ScalingGranularity +except ImportError: + from primus_turbo.pytorch.core.low_precision import Float8QuantConfig, ScalingGranularity + from primus_turbo.pytorch.modules.linear_fp8 import Float8Linear from torchtitan.config.job_config import JobConfig from torchtitan.distributed import ParallelDims diff --git a/primus/backends/torchtitan/components/quantization/mx.py b/primus/backends/torchtitan/components/quantization/mx.py index 705bc47ea..6b1e54e5d 100644 --- a/primus/backends/torchtitan/components/quantization/mx.py +++ b/primus/backends/torchtitan/components/quantization/mx.py @@ -6,7 +6,13 @@ import torch import torch.nn as nn -from primus_turbo.pytorch.core.float8 import Float8QuantConfig, ScalingGranularity + +# Compatibility for different primus_turbo versions +try: + from primus_turbo.pytorch.core.float8 import Float8QuantConfig, ScalingGranularity +except ImportError: + from primus_turbo.pytorch.core.low_precision import Float8QuantConfig, ScalingGranularity + from primus_turbo.pytorch.modules.linear_fp8 import Float8Linear from torchtitan.config.job_config import JobConfig from torchtitan.distributed import ParallelDims diff --git a/primus/backends/torchtitan/models/moe/moe.py b/primus/backends/torchtitan/models/moe/moe.py index e661e997e..6b79db3cc 100644 --- a/primus/backends/torchtitan/models/moe/moe.py +++ b/primus/backends/torchtitan/models/moe/moe.py @@ -3,7 +3,15 @@ import torch import torch.nn.functional as F import primus_turbo.pytorch as turbo -from primus_turbo.pytorch.core.float8 import ( +# try primus_turbo.pytorch.core.float8 first; fall back to low_precision if it is not found +try: + from primus_turbo.pytorch.core.float8 import ( + Float8QuantConfig, + Format, + ScalingGranularity, + ) +except ImportError: + from primus_turbo.pytorch.core.low_precision import ( Float8QuantConfig, Format, ScalingGranularity, )
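The try/except import fallback introduced in patch 27 is repeated in five modules. A minimal sketch of how it could be centralized in one helper, under the same assumption the patch encodes (older primus_turbo exposes pytorch.core.float8, newer releases move the same symbols to pytorch.core.low_precision; the helper name is hypothetical):

    def load_turbo_fp8_symbols():
        """Return (Float8QuantConfig, Format, ScalingGranularity) across primus_turbo versions."""
        try:
            # Older primus_turbo releases.
            from primus_turbo.pytorch.core.float8 import (
                Float8QuantConfig,
                Format,
                ScalingGranularity,
            )
        except ImportError:
            # Newer releases relocated the FP8 types.
            from primus_turbo.pytorch.core.low_precision import (
                Float8QuantConfig,
                Format,
                ScalingGranularity,
            )
        return Float8QuantConfig, Format, ScalingGranularity

Each call site would then unpack the tuple once at import time instead of carrying its own fallback block.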
From ad430318e35a3b108f261342d1519059d382a937 Mon Sep 17 00:00:00 2001 From: JohnQinAMD Date: Wed, 26 Nov 2025 00:04:52 +0000 Subject: [PATCH 28/32] support installing turbo from source --- examples/run_local_pretrain.sh | 1 + examples/run_pretrain.sh | 24 ++++++++++++++++++++++++ examples/run_slurm_pretrain.sh | 1 + 3 files changed, 26 insertions(+) diff --git a/examples/run_local_pretrain.sh b/examples/run_local_pretrain.sh index 1a2267969..8404cf968 100755 --- a/examples/run_local_pretrain.sh +++ b/examples/run_local_pretrain.sh @@ -143,6 +143,7 @@ docker_podman_proxy run --rm \ --env TORCHTITAN_PATH \ --env MAXTEXT_PATH \ --env BACKEND_PATH \ + --env REBUILD_PRIMUS_TURBO \ "${ENV_ARGS[@]}" \ --ipc=host --network=host \ --device=/dev/kfd --device=/dev/dri \ diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index 1ff42ef79..b3b45a26a 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -277,6 +277,30 @@ export NVTE_CK_USES_BWD_V3=${NVTE_CK_USES_BWD_V3:-0} # Note: Disable fp32 atomic due if you find any accuracy issue. export PRIMUS_TURBO_ATTN_V3_ATOMIC_FP32=${PRIMUS_TURBO_ATTN_V3_ATOMIC_FP32:-0} +# install primus turbo from source +export REBUILD_PRIMUS_TURBO=${REBUILD_PRIMUS_TURBO:-0} +if [ "$REBUILD_PRIMUS_TURBO" == "1" ]; then + LOG_INFO "Rebuilding Primus Turbo from source..." + mkdir -p "/workspace/turbo" + cd "/workspace/turbo" + + # Clean up the old directory if it exists to avoid git clone conflicts + if [ -d "Primus-Turbo" ]; then + LOG_INFO "Removing existing Primus-Turbo directory..." + rm -rf Primus-Turbo + fi + + git clone https://github.com/AMD-AGI/Primus-Turbo.git --recursive + cd Primus-Turbo + pip3 install -r requirements.txt + # Set GPU_ARCHS to compile Turbo for multiple AMD GPU architectures. + GPU_ARCHS="gfx942;gfx950" pip3 install --no-build-isolation . + cd "${PRIMUS_PATH}" + LOG_INFO "Rebuilding Primus Turbo from source done." +else + LOG_INFO "Skip Primus Turbo rebuild. REBUILD_PRIMUS_TURBO=$REBUILD_PRIMUS_TURBO" +fi + # nvte debug envs export NVTE_DEBUG=0 # 0, 1 export NVTE_DEBUG_LEVEL=0 # 0, 1, 2 diff --git a/examples/run_slurm_pretrain.sh b/examples/run_slurm_pretrain.sh index 431c547b2..bb00afd5d 100755 --- a/examples/run_slurm_pretrain.sh +++ b/examples/run_slurm_pretrain.sh @@ -57,5 +57,6 @@ srun -N "${NNODES}" \ export NNODES=\${SLURM_NNODES} export NODE_RANK=\${SLURM_PROCID} export GPUS_PER_NODE=\${SLURM_GPUS_ON_NODE} + export REBUILD_PRIMUS_TURBO=\${REBUILD_PRIMUS_TURBO} bash ${SCRIPT_DIR}/run_local_pretrain.sh \"\$@\" 2>&1 | tee ${LOG_FILE} " bash "$@" From f89348dd248a0d6dedb5550b5f35ca48b98f0f72 Mon Sep 17 00:00:00 2001 From: JohnQinAMD Date: Wed, 26 Nov 2025 00:31:41 +0000 Subject: [PATCH 29/32] add dsv3 config for MI355 --- .../MI355X/deepseek_v3_16b-pretrain.yaml | 19 +++++++++++---- .../MI355X/deepseek_v3_671b-pretrain.yaml | 24 +++++++++++++------ 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/examples/torchtitan/configs/MI355X/deepseek_v3_16b-pretrain.yaml b/examples/torchtitan/configs/MI355X/deepseek_v3_16b-pretrain.yaml index 95e1a80cf..75e1a370d 100644 --- a/examples/torchtitan/configs/MI355X/deepseek_v3_16b-pretrain.yaml +++ b/examples/torchtitan/configs/MI355X/deepseek_v3_16b-pretrain.yaml @@ -13,14 +13,14 @@ modules: model: deepseek_v3_16b.yaml overrides: profiling: - enable_profiling: false + enable_profiling: true save_traces_folder: "profile_trace" profile_freq: 10 enable_memory_snapshot: false save_memory_snapshot_folder: "memory_snapshot" metrics: - log_freq: 10 + log_freq: 1 disable_color_printing: false enable_tensorboard: false save_tb_folder: "tb" @@ -38,11 +38,12 @@ modules: min_lr_factor: 0.1 training: + debug_moe_force_load_balance: true local_batch_size: 4 seq_len: 4096 max_norm: 1.0 # grad norm clipping - steps: 1000 - dataset: "c4" # supported datasets: c4_test (2K), c4 (177M) + steps: 15 + dataset: "c4_test" # supported datasets: c4_test (2K), c4 (177M) parallelism: data_parallel_replicate_degree: 1 @@ -69,8 +70,16 @@ modules: compile: enable: true - components: ["loss"] # ["model", "loss"] + components: ["model", "loss"] # ["model", "loss"] + primus_turbo: + enable_primus_turbo: true + use_turbo_mx_linear: false + use_turbo_float8_linear: true + enable_attention_float8: false + use_turbo_grouped_mm: true + use_moe_fp8: false + # quantize: # linear: # float8: diff --git a/examples/torchtitan/configs/MI355X/deepseek_v3_671b-pretrain.yaml b/examples/torchtitan/configs/MI355X/deepseek_v3_671b-pretrain.yaml index b340538be..6b8d2f435 100644 --- a/examples/torchtitan/configs/MI355X/deepseek_v3_671b-pretrain.yaml +++ b/examples/torchtitan/configs/MI355X/deepseek_v3_671b-pretrain.yaml @@ -13,14 +13,14 @@ modules: model: deepseek_v3_671b.yaml overrides: profiling: - enable_profiling: false + enable_profiling: true save_traces_folder: "profile_trace" profile_freq: 10 enable_memory_snapshot: false
save_memory_snapshot_folder: "memory_snapshot" metrics: - log_freq: 10 + log_freq: 1 disable_color_printing: false enable_tensorboard: false save_tb_folder: "tb" @@ -38,11 +38,12 @@ modules: min_lr_factor: 0.1 training: - local_batch_size: 4 + debug_moe_force_load_balance: true + local_batch_size: 16 seq_len: 4096 max_norm: 1.0 # grad norm clipping - steps: 1000 - dataset: "c4" # supported datasets: c4_test (2K), c4 (177M) + steps: 15 + dataset: "c4_test" # supported datasets: c4_test (2K), c4 (177M) parallelism: data_parallel_replicate_degree: 1 @@ -52,7 +53,7 @@ modules: enable_async_tensor_parallel: false pipeline_parallel_degree: 1 pipeline_parallel_schedule: "Interleaved1F1B" - expert_parallel_degree: 1 + expert_parallel_degree: 8 expert_tensor_parallel_degree: 1 checkpoint: @@ -69,7 +70,16 @@ modules: compile: enable: true - components: ["loss"] # ["model", "loss"] + components: ["model", "loss"] # ["model", "loss"] + + primus_turbo: + enable_primus_turbo: true + use_turbo_mx_linear: false + use_turbo_float8_linear: true + enable_attention_float8: false + use_classic_attention: true + use_turbo_grouped_mm: true + use_moe_fp8: false # quantize: # linear: From c89266090560a93cd14f8e703e44eb4d6e12ad03 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Wed, 10 Dec 2025 00:34:53 +0000 Subject: [PATCH 30/32] initial commit --- .../configs/MI300X/mamba_370M-pretrain.yaml | 94 +++++++++++++++++++ .../MI300X/mamba_hybrid_2.8B-pretrain.yaml | 84 +++++++++++++++++ .../models/megatron/language_model.yaml | 1 + .../configs/models/megatron/mamba_1.4B.yaml | 15 +++ .../configs/models/megatron/mamba_370M.yaml | 15 +++ .../configs/models/megatron/mamba_base.yaml | 41 ++++++++ .../models/megatron/mamba_hybrid_2.8B.yaml | 29 ++++++ primus/core/utils/import_utils.py | 41 +++++--- .../trainer/lightmegatron/pre_trainer.py | 41 ++++++-- primus/modules/trainer/megatron/trainer.py | 8 +- 10 files changed, 344 insertions(+), 25 deletions(-) create mode 100644 examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml create mode 100644 examples/megatron/configs/MI300X/mamba_hybrid_2.8B-pretrain.yaml create mode 100644 primus/configs/models/megatron/mamba_1.4B.yaml create mode 100644 primus/configs/models/megatron/mamba_370M.yaml create mode 100644 primus/configs/models/megatron/mamba_base.yaml create mode 100644 primus/configs/models/megatron/mamba_hybrid_2.8B.yaml diff --git a/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml new file mode 100644 index 000000000..f1a60cde0 --- /dev/null +++ b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml @@ -0,0 +1,94 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:mamba_370M-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: mamba_370M.yaml + overrides: + # log + wandb_project: "Primus_Mamba_Pretrain" + # disable_wandb: false + # disable_tensorboard: false + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 50 + micro_batch_size: 4 + global_batch_size: 256 + + seq_length: 2048 + max_position_embeddings: 2048 + + lr: 3.0e-4 + min_lr: 0.0 + lr_warmup_iters: 100 + lr_decay_iters: null + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.02 + norm_epsilon: 1.0e-5 + + # Mamba-specific: must provide spec + spec: 
['megatron.core.models.mamba.mamba_layer_specs', 'mamba_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-1B + + # Mamba SSM parameters + is_hybrid_model: false + hybrid_attention_ratio: 0.0 + hybrid_mlp_ratio: 0.0 + mamba_state_dim: 16 + mamba_head_dim: 64 + mamba_num_groups: 8 + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + no_load_optim: null + no_load_rng: null + save: null + save_interval: 20000 + no_save_optim: null + no_save_rng: null + disable_last_saving: true + ckpt_format: torch + + # Turbo - may need to disable for Mamba if not supported + enable_primus_turbo: false + use_turbo_attention: false + use_turbo_grouped_mlp: false + + # Cross entropy flags + # cross_entropy_fusion_impl: "native" + # cross_entropy_loss_fusion: false + diff --git a/examples/megatron/configs/MI300X/mamba_hybrid_2.8B-pretrain.yaml b/examples/megatron/configs/MI300X/mamba_hybrid_2.8B-pretrain.yaml new file mode 100644 index 000000000..58d800401 --- /dev/null +++ b/examples/megatron/configs/MI300X/mamba_hybrid_2.8B-pretrain.yaml @@ -0,0 +1,84 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:mamba_hybrid_2.8B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: mamba_hybrid_2.8B.yaml + overrides: + # log + wandb_project: "Primus_Mamba_Hybrid_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 2 + global_batch_size: 128 + + seq_length: 4096 + max_position_embeddings: 4096 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.02 + norm_epsilon: 1.0e-5 + + # Mamba-specific: must provide spec + spec: ['megatron.core.models.mamba.mamba_layer_specs', 'mamba_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-1B + + # Hybrid Mamba+Attention parameters + is_hybrid_model: true + hybrid_attention_ratio: 0.125 + hybrid_mlp_ratio: 0.0 + mamba_state_dim: 16 + mamba_head_dim: 64 + mamba_num_groups: 8 + + # parallel + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: true + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch + + # Turbo - disable for Mamba layers, but attention layers may benefit + enable_primus_turbo: false + use_turbo_attention: false + use_turbo_grouped_mlp: false + diff --git a/primus/configs/models/megatron/language_model.yaml b/primus/configs/models/megatron/language_model.yaml index 3226f89c6..c3fe17579 100755 --- a/primus/configs/models/megatron/language_model.yaml +++ b/primus/configs/models/megatron/language_model.yaml @@ -8,6 +8,7 
@@ includes: # model architecture use_legacy_models: false deprecated_use_mcore_models: false +model_type: gpt # gpt or mamba num_layers: 24 encoder_num_layers: null decoder_num_layers: null diff --git a/primus/configs/models/megatron/mamba_1.4B.yaml b/primus/configs/models/megatron/mamba_1.4B.yaml new file mode 100644 index 000000000..91e255443 --- /dev/null +++ b/primus/configs/models/megatron/mamba_1.4B.yaml @@ -0,0 +1,15 @@ +bases: + - mamba_base.yaml + +# Mamba 1.4B configuration + +tokenizer_type: GPT2BPETokenizer +vocab_size: 50257 + +# Model size parameters +num_layers: 48 +hidden_size: 2048 +ffn_hidden_size: null + +max_position_embeddings: 2048 + diff --git a/primus/configs/models/megatron/mamba_370M.yaml b/primus/configs/models/megatron/mamba_370M.yaml new file mode 100644 index 000000000..b10523b2e --- /dev/null +++ b/primus/configs/models/megatron/mamba_370M.yaml @@ -0,0 +1,15 @@ +bases: + - mamba_base.yaml + +# Mamba 370M configuration + +tokenizer_type: GPT2BPETokenizer +vocab_size: 50257 + +# Model size parameters +num_layers: 48 +hidden_size: 1024 +ffn_hidden_size: null + +max_position_embeddings: 2048 + diff --git a/primus/configs/models/megatron/mamba_base.yaml b/primus/configs/models/megatron/mamba_base.yaml new file mode 100644 index 000000000..f658c8371 --- /dev/null +++ b/primus/configs/models/megatron/mamba_base.yaml @@ -0,0 +1,41 @@ +bases: + - language_model.yaml + +# Mamba-specific configuration +# Note: Mamba-specific parameters (spec, is_hybrid_model, mamba_state_dim, etc.) +# must be set in the pretrain config overrides, not here + +model_type: mamba +use_legacy_models: false + +# Position embeddings - Mamba typically doesn't use position embeddings +position_embedding_type: rope +use_rotary_position_embeddings: false + +# Tokenizer (should be set in specific model configs) +tokenizer_type: HuggingFaceTokenizer +tokenizer_model: null + +# Model architecture +num_layers: 24 +hidden_size: 1024 +ffn_hidden_size: null + +# Standard transformer settings that may be used by hybrid models +num_attention_heads: 16 +attention_dropout: 0.0 +hidden_dropout: 0.0 + +# Embeddings +untie_embeddings_and_output_weights: true + +# Other settings +apply_residual_connection_post_layernorm: false +add_bias_linear: false +swiglu: false + +# Normalization +norm_epsilon: 1.0e-5 + +# Initialization +init_method_std: 0.02 diff --git a/primus/configs/models/megatron/mamba_hybrid_2.8B.yaml b/primus/configs/models/megatron/mamba_hybrid_2.8B.yaml new file mode 100644 index 000000000..f2fd20cba --- /dev/null +++ b/primus/configs/models/megatron/mamba_hybrid_2.8B.yaml @@ -0,0 +1,29 @@ +bases: + - mamba_base.yaml + +# Mamba 2.8B configuration with hybrid attention layers + +tokenizer_type: GPT2BPETokenizer +vocab_size: 50257 + +# Model size parameters +num_layers: 64 +hidden_size: 2560 +ffn_hidden_size: 6827 # ~2.67x hidden_size + +# Attention parameters (for hybrid layers) +num_attention_heads: 32 +group_query_attention: true +num_query_groups: 8 + +# Hybrid configuration: override mamba_base defaults +hybrid_attention_ratio: 0.125 +is_hybrid_model: true + +# For hybrid models, position embeddings may be useful +position_embedding_type: rope +rotary_base: 10000 +rotary_percent: 1.0 + +max_position_embeddings: 4096 + diff --git a/primus/core/utils/import_utils.py b/primus/core/utils/import_utils.py index 2ccd8ebed..6e67de8a4 100644 --- a/primus/core/utils/import_utils.py +++ b/primus/core/utils/import_utils.py @@ -34,25 +34,40 @@ def lazy_import(paths, symbol, log_prefix="[Primus]"): raise 
ImportError(f"{log_prefix} {symbol} not found in any of: {paths}") -def get_model_provider(): +def get_model_provider(model_type="gpt"): """ - Resolve model_provider across Megatron versions. + Resolve model_provider across Megatron versions and model types. - - New: model_provider + gpt_builder + Args: + model_type (str): Type of model - 'gpt' or 'mamba'. Defaults to 'gpt'. + + - New: model_provider + gpt_builder/mamba_builder - Mid: model_provider only - - Old: pretrain_gpt.model_provider + - Old: pretrain_gpt.model_provider / pretrain_mamba.model_provider """ # Try to import model_provider - model_provider = lazy_import( - ["model_provider", "pretrain_gpt"], "model_provider", log_prefix="[Primus][MegatronCompat]" - ) + if model_type == "mamba": + model_provider = lazy_import( + ["model_provider", "pretrain_mamba"], "model_provider", log_prefix="[Primus][MegatronCompat]" + ) + # Try to import mamba_builder (for Mamba models) + try: + mamba_builder = lazy_import(["mamba_builders"], "mamba_builder", log_prefix="[Primus][MegatronCompat]") + return partial(model_provider, mamba_builder) + except ImportError: + return model_provider + else: + # Default GPT behavior + model_provider = lazy_import( + ["model_provider", "pretrain_gpt"], "model_provider", log_prefix="[Primus][MegatronCompat]" + ) - # Try to import gpt_builder (only exists in newer versions) - try: - gpt_builder = lazy_import(["gpt_builders"], "gpt_builder", log_prefix="[Primus][MegatronCompat]") - return partial(model_provider, gpt_builder) - except ImportError: - return model_provider + # Try to import gpt_builder (only exists in newer versions) + try: + gpt_builder = lazy_import(["gpt_builders"], "gpt_builder", log_prefix="[Primus][MegatronCompat]") + return partial(model_provider, gpt_builder) + except ImportError: + return model_provider def get_custom_fsdp(): diff --git a/primus/modules/trainer/lightmegatron/pre_trainer.py b/primus/modules/trainer/lightmegatron/pre_trainer.py index 0a12460fa..973421933 100644 --- a/primus/modules/trainer/lightmegatron/pre_trainer.py +++ b/primus/modules/trainer/lightmegatron/pre_trainer.py @@ -38,15 +38,36 @@ def run(self, *args, **kwargs): from megatron.core.enums import ModelType from megatron.training import inprocess_restart, pretrain - from pretrain_gpt import forward_step, train_valid_test_datasets_provider + from megatron.training import get_args - train_valid_test_datasets_provider.is_distributed = True - wrapped_pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) + # Determine model type from config (gpt or mamba) + megatron_args = get_args() + model_type = getattr(megatron_args, 'model_type', 'gpt') + log_rank_0(f"Detected model_type: {model_type}") - wrapped_pretrain( - train_valid_test_datasets_provider, - get_model_provider(), - ModelType.encoder_or_decoder, - forward_step, - store=store, - ) + if model_type == 'mamba': + # Import from pretrain_mamba + from pretrain_mamba import forward_step, train_valid_test_datasets_provider + train_valid_test_datasets_provider.is_distributed = True + wrapped_pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) + + wrapped_pretrain( + train_valid_test_datasets_provider, + get_model_provider(model_type='mamba'), + ModelType.encoder_or_decoder, + forward_step, + store=store, + ) + else: + # Default to GPT + from pretrain_gpt import forward_step, train_valid_test_datasets_provider + train_valid_test_datasets_provider.is_distributed = True + wrapped_pretrain, store = 
inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) + + wrapped_pretrain( + train_valid_test_datasets_provider, + get_model_provider(model_type='gpt'), + ModelType.encoder_or_decoder, + forward_step, + store=store, + ) diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py index 4eae5e131..7c30007e8 100644 --- a/primus/modules/trainer/megatron/trainer.py +++ b/primus/modules/trainer/megatron/trainer.py @@ -962,11 +962,15 @@ def update_primus_config( args.valid_data_path = None args.test_data_path = None + # Determine model type (gpt or mamba) + model_type = getattr(args, 'model_type', 'gpt') + log_rank_0(f"-detected model_type: {model_type}") + if args.final_logit_softcapping is not None and args.final_logit_softcapping > 0.0: log_rank_0(f"-enable final_logit_softcapping: {args.final_logit_softcapping}") - self.model_provider = functools.partial(primus_model_provider, get_model_provider()) + self.model_provider = functools.partial(primus_model_provider, get_model_provider(model_type=model_type)) else: - self.model_provider = get_model_provider() + self.model_provider = get_model_provider(model_type=model_type) if args.router_logit_softcapping is not None and args.router_logit_softcapping > 0.0: log_rank_0(f"-enable router_logit_softcapping: {args.router_logit_softcapping}")
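Condensing patch 30's provider resolution: the model_type only changes which module paths are tried, and newer Megatron builds additionally expose a builder to bind. A sketch of that control flow reusing the patch's lazy_import contract (a compressed illustration, not a drop-in replacement):

    from functools import partial

    def resolve_model_provider(model_type: str = "gpt"):
        """Resolve model_provider for 'gpt' or 'mamba' across Megatron versions."""
        pretrain_module = "pretrain_mamba" if model_type == "mamba" else "pretrain_gpt"
        builder_module = "mamba_builders" if model_type == "mamba" else "gpt_builders"
        builder_symbol = "mamba_builder" if model_type == "mamba" else "gpt_builder"

        provider = lazy_import(["model_provider", pretrain_module], "model_provider")
        try:
            # Only newer Megatron versions ship a builder module; bind it when present.
            builder = lazy_import([builder_module], builder_symbol)
            return partial(provider, builder)
        except ImportError:
            return provider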
From d8ae27d0af150181c92e81c990d09256f5a93823 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Tue, 16 Dec 2025 14:05:15 +0000 Subject: [PATCH 31/32] set self.lr_warmup_steps < self.lr_decay_steps --- examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml index f1a60cde0..9557e7860 100644 --- a/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml +++ b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml @@ -31,8 +31,8 @@ modules: lr: 3.0e-4 min_lr: 0.0 - lr_warmup_iters: 100 - lr_decay_iters: null + lr_warmup_iters: 50000 + lr_decay_iters: 73192188 lr_decay_style: cosine weight_decay: 0.1 adam_beta1: 0.9 adam_beta2: 0.95 From df0b00e2d360266b22d700421c5b151b396883f7 Mon Sep 17 00:00:00 2001 From: clairesonglee Date: Thu, 18 Dec 2025 15:58:30 +0000 Subject: [PATCH 32/32] unwrap model to remove loss_mask parameter --- .../modules/trainer/megatron/pre_trainer.py | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/primus/modules/trainer/megatron/pre_trainer.py b/primus/modules/trainer/megatron/pre_trainer.py index 7eefaf669..7dc2f3eec 100644 --- a/primus/modules/trainer/megatron/pre_trainer.py +++ b/primus/modules/trainer/megatron/pre_trainer.py @@ -240,6 +240,20 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=Fals assert ( args.overlap_moe_expert_parallel_comm ), "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan" + + # Schedule plan building is only supported for GPT models + # Check if this is a Mamba model + unwrapped_model = model + while hasattr(unwrapped_model, 'module'): + unwrapped_model = unwrapped_model.module + model_class_name = unwrapped_model.__class__.__name__ + + if 'Mamba' in model_class_name: + raise NotImplementedError( + "Schedule plan building is not supported for Mamba models. " + "Please disable overlap_moe_expert_parallel_comm for Mamba." + ) + if args.patch_moe_overlap: assert ( not args.delay_wgrad_compute @@ -265,8 +279,23 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=Fals ) return schedule_plan, partial(self.loss_func, loss_mask) else: - output_tensor = model( - tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask - ) + # Check if model supports loss_mask parameter + # MambaModel doesn't accept loss_mask, but GPTModel does + # Unwrap the model to get the actual model class + unwrapped_model = model + while hasattr(unwrapped_model, 'module'): + unwrapped_model = unwrapped_model.module + model_class_name = unwrapped_model.__class__.__name__ + + if 'Mamba' in model_class_name: + # MambaModel doesn't accept loss_mask parameter + output_tensor = model( + tokens, position_ids, attention_mask, labels=labels + ) + else: + # GPTModel and other models accept loss_mask parameter + output_tensor = model( + tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask + ) return output_tensor, partial(self.loss_func, loss_mask)
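Patch 32's dispatch boils down to unwrapping .module layers and branching on the class name, since MambaModel's forward() takes no loss_mask while GPTModel's does. A standalone sketch of that idea, with a signature-based variant that avoids string matching (helper names are illustrative):

    import inspect

    def unwrap(model):
        """Peel wrapper layers (DDP, Float16Module, ...) that expose .module."""
        while hasattr(model, "module"):
            model = model.module
        return model

    def forward_with_optional_loss_mask(model, tokens, position_ids, attention_mask, labels, loss_mask):
        core = unwrap(model)
        # Name-based check, as in the patch: MambaModel's forward has no loss_mask.
        if "Mamba" in core.__class__.__name__:
            return model(tokens, position_ids, attention_mask, labels=labels)
        return model(tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask)

    # A sturdier alternative inspects the signature instead of the class name:
    # if "loss_mask" in inspect.signature(core.forward).parameters: ...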