Merge upstream 2026-10-02 (#60)

Merged

dmahan93 merged 199 commits into dev-updated-again on Mar 13, 2026

Conversation
We should be able to control which passes run in the compiler. This PR uses the config `compile.passes` to specify a list of graph passes to apply on the captured GraphModule. By default, no pass is applied; users can specify which passes to apply. Currently there are `autobucketing_reordering_pass` and `regional_inductor_pass`.

```
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor
```

Also updated CI to include this new config.
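A config-driven pass list like this is usually wired through a small name-to-pass registry. Below is a minimal, hypothetical sketch of that pattern; the registry, the `apply_passes` helper, and the dict stand-in for the captured GraphModule are illustrative, not torchtitan's actual implementation.

```python
# Hypothetical sketch of a config-driven graph-pass pipeline. The pass
# names mirror the PR's config values; everything else is illustrative.
PASS_REGISTRY = {}

def register_pass(name):
    def deco(fn):
        PASS_REGISTRY[name] = fn
        return fn
    return deco

@register_pass("autobucketing_reordering")
def autobucketing_reordering(gm):
    # Stand-in: a real pass would rewrite collective nodes in the graph.
    gm.setdefault("applied", []).append("autobucketing_reordering")
    return gm

@register_pass("regional_inductor")
def regional_inductor(gm):
    gm.setdefault("applied", []).append("regional_inductor")
    return gm

def apply_passes(gm, pass_names):
    # No passes are applied unless explicitly listed, matching the default.
    for name in pass_names:
        if name not in PASS_REGISTRY:
            raise ValueError(f"unknown pass: {name}")
        gm = PASS_REGISTRY[name](gm)
    return gm
```

The comma-separated `--compile.passes autobucketing_reordering,regional_inductor` flag would then map directly onto the `pass_names` list, applied in order.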
After some offline discussion, we've concluded that life would be easier if we put SimpleFSDP's checkpoint logic for `reshard_after_forward` into the compiler. The AC annotation part is borrowed from AP: [LINK](https://github.com/meta-pytorch/autoparallel/blob/main/autoparallel/activation_checkpointing.py#L69).

**Trace and Loss Check** (all with torch.compile enabled)

`reshard_after_fwd = False`

1. SAC + llama3 ([trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-05-06_rank0_trace.json)) <img width="768" height="115" alt="Screenshot 2025-10-30 at 4 28 59 PM" src="https://github.com/user-attachments/assets/e4e22335-2e3f-46c8-8def-a60d592fee0a" /> <img width="689" height="512" alt="Screenshot 2025-11-05 at 9 02 30 PM" src="https://github.com/user-attachments/assets/40a71316-a457-4e72-9002-cc8beea8f32c" />
2. Full AC + llama3 [(trace)]() <img width="729" height="105" alt="Screenshot 2025-10-30 at 4 30 53 PM" src="https://github.com/user-attachments/assets/e8d63460-579b-4f0a-8504-851480e5b548" /> <img width="789" height="763" alt="Screenshot 2025-11-05 at 9 11 34 PM" src="https://github.com/user-attachments/assets/1a13d09e-04c4-4db9-99fe-cf10d24bf7f5" />
3. No AC + llama3 ([trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-03-50_rank0_trace.json)) <img width="748" height="115" alt="Screenshot 2025-10-30 at 4 32 05 PM" src="https://github.com/user-attachments/assets/20104d24-9d45-4eba-b694-815e133b88d0" /> <img width="800" height="764" alt="Screenshot 2025-11-05 at 9 07 46 PM" src="https://github.com/user-attachments/assets/55b104ce-8ec1-4ed6-95e7-300e96ad55af" />

`reshard_after_fwd = True`

1. SAC + llama3 ([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-34-24_rank0_trace.json)) <img width="795" height="108" alt="Screenshot 2025-10-31 at 11 34 47 AM" src="https://github.com/user-attachments/assets/a3988f72-7e87-4e52-90f9-8bee840cd6f4" />
2. Full AC + llama3 ([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-36-27_rank0_trace.json)) <img width="593" height="110" alt="Screenshot 2025-10-31 at 11 38 02 AM" src="https://github.com/user-attachments/assets/5ee61b2b-9600-4af8-9a24-61b3564f93ca" />
3. No AC + llama3 ([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-02-44_rank0_trace.json)) <img width="701" height="109" alt="Screenshot 2025-10-31 at 11 43 04 AM" src="https://github.com/user-attachments/assets/576b28f6-dae4-4ff7-b005-57b0cf9ad7cc" />
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * pytorch#2002 * __->__ pytorch#2001 Add typing, credit to Claude.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * __->__ pytorch#2013 When `full_dtensor` is True, the compute placement will be preserved. This means that `to_local()` won't be called in the FSDP-only case. The nD parallelism case (FSDP + TP) will error out, as we have not implemented it yet. This argument doesn't affect the current `simple_fsdp`. We have verified the `full_dtensor=True` case with the full-DTensor skeleton PR, which will be published once it is ready. **This is a reland PR of pytorch#2002. The previous one was broken during rebase.**
Summary:
- we need to pass the global rank information to pytorch so that the pg name can include the pg information
- this is necessary to differentiate the default pgs on different replicas
- these need to be different because flight recorder matches collectives based on pg name as well
- add ft training to the experiments folder; we'll move the remaining pieces of ft there gradually, but make new features only available through this folder

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1986).
* pytorch#1988
* pytorch#1987
* __->__ pytorch#1986

Co-authored-by: Tushar Jain <tushar00jain@users.noreply.github.com>
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#2012 * __->__ pytorch#2011 It is not correct as JobConfig has changed.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * __->__ pytorch#2012 * pytorch#2011 Summary: The current configuration validation requires torchx and GPUs. It can waste time, resources, and energy. Polar bears are crying. Let's fix this by providing a dry run mode. This PR doesn't verify everything; in theory, we should be able to verify parallelism settings as well. This PR is just a start, but at least it lets us catch typos quickly.
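A dry-run gate can be as simple as validating the parsed config and exiting before any GPU or process-group initialization. A hedged sketch of that idea; the `DRY_RUN` environment variable, the `validate` helper, and the checked fields are assumptions for illustration, not the exact torchtitan interface:

```python
# Sketch of a dry-run mode: validate the job config, then stop before
# any expensive GPU/torchx setup. Field names are illustrative.
import os

def validate(config: dict) -> list:
    errors = []
    if config.get("training", {}).get("steps", 0) <= 0:
        errors.append("training.steps must be positive")
    if "model" not in config:
        errors.append("missing [model] section")
    return errors

def maybe_dry_run(config: dict) -> bool:
    """Return True if this was a dry run (caller should exit cleanly)."""
    if os.environ.get("DRY_RUN") != "1":
        return False
    errors = validate(config)
    if errors:
        raise SystemExit("config errors: " + "; ".join(errors))
    print("dry run OK, exiting before GPU init")
    return True
```

The trainer entry point would call `maybe_dry_run(config)` first and return early when it is True, which is what makes typos cheap to catch.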
As titled. `_clear_traced_params_buffers` is no longer being used as we have switched the dynamo graph capture API.
…nly (pytorch#2016) Addressing the following issues in this PR: - Running the Torchtitan ROCm workflow on a cron schedule & only on pushes to the main branch; the CUDA workflow will run as is. - Refactoring the Torchtitan test run to address an older PR comment pytorch#1786 (comment)
…& push to Main branch only" (pytorch#2017) Reverts PR: pytorch#2016. That PR addressed the following issues: - Running the Torchtitan ROCm workflow on a cron schedule & only on pushes to the main branch; the CUDA workflow will run as is. - Refactoring the Torchtitan test run to address an older PR comment pytorch#1786 (comment) Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
…2015) This PR adds utils to automatically check the training numerics (losses, grad norms) of two runs to verify that they are bitwise equivalent. The added script triggers two runs with user-defined configs, then loads the metrics saved during training and compares the numerics to verify bitwise equivalence. Currently we check losses and grad norms during training steps.

For example, to compare the numerics between the compiler toolkit with the aot_eager backend and eager on llama3-8B:

```
python torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py --ngpu 4 --config-file torchtitan/models/llama3/train_configs/llama3_8b.toml --dp-shard-degree 2 --tp-degree 2
```

It runs the `simple_fsdp` experiment without `torch.compile` as the eager baseline, and the `compiler_toolkit` experiment as the compiled run, then compares the training numerics of the two runs to verify bitwise equivalence. When they are bitwise equivalent, we see the following output:

```
Starting training: simple_fsdp.llama3
✓ Training completed: simple_fsdp.llama3
Starting training: compiler_toolkit.llama3
✓ Training completed: compiler_toolkit.llama3
✓ PASS: All 11 steps match exactly (bitwise equivalent)
✓ PASS: All 11 steps match exactly (bitwise equivalent)
✓ SUCCESS: All metrics are bitwise equivalent
```

Also added unit tests in `compiler_toolkit/tests/test_numerics.py` so that we can guard working parallelism combinations that already have bitwise equivalence in CI.
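The core of such a check is an exact (no-tolerance) comparison of per-step metrics from the two runs. A minimal sketch, assuming metrics are loaded as per-step dicts; the field names are illustrative, not the script's actual schema:

```python
# Minimal sketch of a bitwise-equivalence check: compare per-step
# metrics from two runs with exact float equality (no tolerance).
def check_bitwise(baseline, compiled, keys=("loss", "grad_norm")):
    """Return a list of mismatch descriptions; empty means bitwise equal."""
    mismatches = []
    for step, (a, b) in enumerate(zip(baseline, compiled), start=1):
        for k in keys:
            # Exact comparison is intentional: bitwise equivalence,
            # not closeness within a tolerance.
            if a[k] != b[k]:
                mismatches.append(f"step {step}: {k} {a[k]!r} != {b[k]!r}")
    return mismatches
```

An empty result corresponds to the "All N steps match exactly" output above; any entry would be reported as a failure.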
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#2029 * pytorch#2030 * pytorch#2028 * pytorch#2027 * __->__ pytorch#2026 As title
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#2029 * pytorch#2030 * pytorch#2028 * __->__ pytorch#2027 * pytorch#2026 Dry run mode works but it doesn't exit gracefully for all cases. This PR fixes it ``` DRY_RUN=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.steps=10 --activation_checkpoint.mode="none" --debug.deterministic --debug.seed=42 ```
…h#2030) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#2029 * __->__ pytorch#2030 The current CompileModule will result in an "inner" prefix for everything. This PR fixes it by overloading the methods. Also merge pytorch#2028 to this PR. Something wrong with ghstack.
Fixes the badge in the `README.md` file
Before: <img width="978" height="93" alt="image" src="https://github.com/user-attachments/assets/48dc39d9-e897-4396-ac62-025574303403" /> After: <img width="1318" height="82" alt="image" src="https://github.com/user-attachments/assets/47b4771a-aaf9-4f61-80bc-757f3a08c1d2" />
This PR adds support for aten-level manual bucketing in the SimpleFSDP + `aot_eager` backend. Depends on PyTorch [PR](pytorch/pytorch#165487).

TODO list:
- [ ] We should have a better way of handling region info than a list of str FQNs in the current `manual_bucketed_modules`; it would be very easy to miss some model modules. (cc. @xmfan @SherlockNoMad)
- [ ] Currently, the reordering happens under the hood and overlaps with the last/next compute. We should allow users to specify which modules they want to reorder.
- [ ] Loss difference on multi-node training
- [ ] DSV3 manual bucketing

I'll address the TODO items in follow-up PRs. Let's start with this simple FSDP+TP+llama3 PR.

1. Performance (FSDP2 under eager mode; SimpleFSDP uses the `aot_eager` backend)

**Llama 3-8B** (all batch_size = 1). The slower TPS on a single node is more or less expected, since FSDP2 handles copy-in/out in two different streams, whereas SimpleFSDP handles copy-in/out in the same stream.

|Node| Method | Parallelism | Memory | TPS | Trace|
|---------|---------|-----------|----------|------|------|
|1-Node (8H100)|SimpleFSDP | FSDP=8| 40.96GiB (43.12%) | 7,227| [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-16-10-48-48_rank0_trace.json)|
|1-Node (8H100)|FSDP2-eager| FSDP=8| 47.82GiB (50.35%) | 7,380 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-16-10-54-14_rank0_trace.json)|
|8-Node (64H100)|SimpleFSDP| FSDP=64 | 29.37GiB | 4,984| |
|8-Node (64H100)|FSDP2| FSDP=64 | 31.41GiB | 5,097 | |
|1-Node (8H100)|SimpleFSDP| FSDP=4 TP=2 | 28.28GiB (29.77%) | 5,881 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-26-18-00-18_rank0_trace.json) |
|1-Node (8H100)|FSDP2| FSDP=4 TP=2 | 35.33GiB (37.20%) | 5,898 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-26-15-35-47_rank0_trace.json) |
|8-Node (64H100)|SimpleFSDP| FSDP=8 TP=8 | | | |
|8-Node (64H100)|FSDP2| FSDP=8 TP=8 | | | |

Example SimpleFSDP 1D overlapping trace: <img width="1127" height="127" alt="Screenshot 2025-10-16 at 10 49 55 AM" src="https://github.com/user-attachments/assets/2d9e3ff8-8e9b-40a7-a666-3c0a0975186e" />

Example SimpleFSDP 2D overlapping trace: <img width="1162" height="166" alt="Screenshot 2025-10-26 at 6 00 51 PM" src="https://github.com/user-attachments/assets/bc5cc031-5b6c-4e4d-a9da-70c43114f49a" />

2. Bitwise loss

FSDP-only: <img width="1266" height="837" alt="Screenshot 2025-10-17 at 10 41 56 AM" src="https://github.com/user-attachments/assets/30f83d95-1eca-4f10-9e7e-47c45278cd8d" />

FSDP+TP: <img width="1259" height="808" alt="Screenshot 2025-10-26 at 9 03 58 PM" src="https://github.com/user-attachments/assets/b75b452b-adb9-4078-9412-ee9e584ffe15" />
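The `manual_bucketed_modules` list of string FQNs amounts to assigning each parameter to the transformer-block prefix it falls under, so one bucketed collective can serve the whole block. A hypothetical sketch of that FQN-matching step only (not the aten-level pass itself):

```python
# Hypothetical sketch of FQN-based manual bucketing: group parameter
# FQNs under the module prefixes listed by the user. Parameters whose
# owner module is not listed end up unbucketed -- which is exactly the
# "easy to miss some model modules" pitfall noted in the TODO list.
def bucket_by_module(param_fqns, bucketed_modules):
    buckets = {m: [] for m in bucketed_modules}
    buckets["<unbucketed>"] = []
    for fqn in param_fqns:
        owner = next((m for m in bucketed_modules
                      if fqn == m or fqn.startswith(m + ".")),
                     "<unbucketed>")
        buckets[owner].append(fqn)
    return buckets
```

Anything landing in `<unbucketed>` would fall back to per-parameter collectives, which is why a less error-prone region specification is on the TODO list.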
The current `convert_to_hf.py` does not support `export_dtype`, which makes it `float32` by default. This PR adds support for export dtypes of `["float16", "bfloat16", "float32"]`.
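The fix reduces to validating the requested dtype against the supported set, with `float32` preserved as the default. A hedged sketch; the helper name and signature are illustrative, not `convert_to_hf.py`'s actual code:

```python
# Sketch of export_dtype handling: accept only the supported dtypes and
# keep float32 as the default (the previous unconditional behavior).
SUPPORTED_EXPORT_DTYPES = {"float16", "bfloat16", "float32"}

def resolve_export_dtype(requested=None):
    if requested is None:
        return "float32"  # old behavior: always export in float32
    if requested not in SUPPORTED_EXPORT_DTYPES:
        raise ValueError(f"unsupported export_dtype: {requested}")
    return requested
```

The resolved name would then be mapped to the corresponding torch dtype before saving the HF checkpoint.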
This PR integrates the changes in pytorch#1970 into the compiler toolkit (applying `joint_ac_pass` on the joint graph to tag nodes based on the `reshard_after_forward` flag). Also did some refactoring of how graph passes are applied in the compiler toolkit experiments. We will have two kinds of passes: 1. joint_custom_passes: passes applied on the captured joint graph before the partitioner. By default we apply `validate_flex_attn_annotation_pass` and `fsdp_reshard_after_fwd_pass`. 2. compiler_passes: passes applied on the partitioned fwd and bwd graphs as backend optimizations. By default there is none; `autobucketing_reordering_pass` and `regional_inductor_pass` can be enabled via configs.
…ytorch#2056) This PR integrates the manual bucketing pass (transformer block bucketing) added in SimpleFSDP experiment (pytorch#1881) to compiler toolkit So now compiler toolkit can also run manual bucketing pass by specifying the config ``` NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing ``` Also updated README and integration test to include the newly ported pass
…h only (pytorch#2018) Addressing the following issues in this PR: Running the Torchtitan ROCm workflow on a cron schedule & only on pushes to the main branch; the CUDA workflow will run as is. Refactoring the Torchtitan test run to address an older PR comment pytorch#1786 (comment)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#2049 * __->__ pytorch#2029

## Summary

This PR adds `scripts/loss_compare.py` for comparing training losses between different git commits and/or training configurations.

## Key Features

- Commit comparison: compare losses between two different git commits with deterministic training
- Configuration comparison: compare different training configurations on the same commit
- Reproducibility: automatically enables deterministic mode and seed checkpointing for reproducible comparisons
- Real-time output: streams training output to both console and log files during execution
- Statistical analysis: generates step-by-step loss comparisons and summary statistics
- CI testing: includes an `--assert-equal` flag for automated testing to verify identical losses

## Usage Examples

#### Compare two commits
```
python3 ./scripts/loss_compare.py main my_branch
```

#### Compare two commits with a custom configuration
```
python3 ./scripts/loss_compare.py main my_branch \
    --baseline-config="./custom.toml" --baseline-options="--parallelism.tensor_parallel_degree=2"
```

#### Compare different parallelization strategies on the same commit
```
python3 ./scripts/loss_compare.py . . \
    --baseline-config="./llama3_8b.toml" --baseline-options="--parallelism.tensor_parallel_degree=2" \
    --test-options="--parallelism.tensor_parallel_degree=1"
```

#### Assert equality for CI testing
```
python3 ./scripts/loss_compare.py main my_branch --assert-equal
```

## Real Use Cases

Compare full-dtensor SimpleFSDP with FSDP2:
```
python3 scripts/loss_compare.py . . \
    --baseline-options='--activation_checkpoint.mode="none"' \
    --test-train-file='torchtitan.experiments.full_dtensor.train' \
    --test-options='--model.name full_dtensor.llama3 --activation_checkpoint.mode="none"' \
    --assert-equal --no-seed-checkpoint

[LOSS_COMPARE]
[LOSS_COMPARE] Asserting losses are equal...
[LOSS_COMPARE] Baseline log: /tmp/baseline_training.log
[LOSS_COMPARE] Test log: /tmp/test_training.log
[LOSS_COMPARE] Extracted 100 steps from baseline log
[LOSS_COMPARE] Extracted 100 steps from test log
test_losses_equal (__main__.assert_losses_equal.<locals>.LossEqualityTest.test_losses_equal) ... ok
```
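The "Extracted 100 steps from baseline log" stage boils down to scraping per-step losses out of the training logs with a regex. A sketch of that step; the `step: N ... loss: X` line format is an assumption about torchtitan's metric lines, not a verified schema:

```python
# Sketch of the log-scraping step a loss-comparison script relies on:
# pull (step, loss) pairs out of a training log, skipping other lines.
import re

STEP_RE = re.compile(r"step:\s*(\d+).*?loss:\s*([0-9.]+)")

def extract_losses(log_text):
    """Map step number -> loss for every matching line in the log."""
    losses = {}
    for line in log_text.splitlines():
        m = STEP_RE.search(line)
        if m:
            losses[int(m.group(1))] = float(m.group(2))
    return losses
```

With both logs reduced to `{step: loss}` dicts, asserting equality (for `--assert-equal`) is a plain dict comparison.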
All tests in experiments are broken due to the `gpu_arch_type` field added in pytorch#2018.
…#2064) Adding CudaGraph pass (pytorch#2050) would require some custom logic in Trainer's close() method. So we create a Trainer subclass in compiler toolkit
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#2063 * __->__ pytorch#2062 This will prevent errors when later doing git checkout
## Features - [x] Support SimpleFSDP and TP - [x] Support static input indices to reduce copy - [x] Support memory reuse to reduce memory consumption - [x] Cleanup cudagraph when training finishes to avoid nccl hang from destroy_process_group Command: ``` NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes cudagraph ``` Note: we use `NCCL_GRAPH_REGISTER=0` due to a known issue that nccl + cudagraphs + expandable segments result in IMA. pytorch/pytorch#158029 [trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces%2Ftree%2Fshared_trace%2Fboyuan_e1ef464b-ee61-4c61-82e5-f7a485e561bf_rank0_trace.json) ## Result **Numerics:** Achieved bitwise equivalence w/ and w/o cudagraph pass on llama3.1-8B AND llama3.1-70B. **Performance:** <img width="560" height="90" alt="image" src="https://github.com/user-attachments/assets/9d54c461-0eb1-4f7e-9652-3d52043ad74f" /> Raw log: [llama3-8b](https://www.internalfb.com/phabricator/paste/view/P2045444190), [llama3-70b](https://www.internalfb.com/phabricator/paste/view/P2045567416) **Memory:** On llama3.1-70b, cudagraph takes 6% more memory consumption (143 GiB vs 153 GiB). A few tricks to reduce memory consumption (use llama3.1-70b w/ cudagraph as an example): - Start: 161 GiB - \+ use the same stream for warmup and graph capture of both fwd and bwd: 160 GiB - \+ warmup in cudagraph memory pool instead of eager memory pool: 153 GiB **static input copy:** On llama3.1-70B, for forward, we copy 1 tensor of 128 bytes; for backward, we copy 1 tensor of 0.98 GB. 
This shows static input indices are handled correctly. ## Follow-up PR In the follow-up PR, I will enable fx graph partition for DeepSeek V3 pytorch/pytorch#165945.
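The static-input bookkeeping above comes down to partitioning graph inputs: parameters and buffers keep fixed addresses across replays, while data batches and activations must be copied into graph-owned tensors before each replay. A simplified sketch of that index partitioning only (the real pass operates on fx placeholder nodes, and the `kind` labels here are illustrative):

```python
# Sketch of cudagraph static-input handling: split input indices into
# those with stable addresses (no copy on replay) and those that must
# be copied into graph-owned buffers before each replay.
def split_static_inputs(input_kinds):
    static, copied = [], []
    for i, kind in enumerate(input_kinds):
        (static if kind in ("param", "buffer") else copied).append(i)
    return static, copied
```

Maximizing the `static` set is what keeps the per-step copy tiny (e.g. the single 128-byte forward copy reported above).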
This PR fixes access to args; it's an attribute, not a variable in the scope. The method itself though would not be used because `should_check_address` seems to be always `False` and there doesn't seem to be a command line argument for it. Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
# Context

Reference PR: huggingface#1

This PR enables:

- Llama-like HF models to work with 4D parallelism: FSDP, CP, TP, PP (and combinations between them). The following models were tested:
  - `meta-llama/Llama-3.2-1B`
  - `microsoft/phi-2`
  - `Qwen/Qwen2.5-7B`
  - `mistralai/Mistral-7B-v0.1`
  - `ByteDance-Seed/Seed-Coder-8B-Instruct`
  - `Qwen/Qwen3-4B-Instruct-2507`
  - `arcee-ai/AFM-4.5B`
  - `ibm-granite/granite-3b-code-base-2k`
  - `baidu/ERNIE-4.5-0.3B-Base-PT`
  - `kyutai/helium-1-preview-2b`
  - `allenai/OLMo-7B-hf`
  - `mistralai/Ministral-8B-Instruct-2410`
- Patching HF model weights initialisation. Without this, the `loss` and `grad_norm` start very high.

# Usage

- Requirements: `transformers==4.57.1`
- Config: `torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3.toml`

```diff
...
[model]
- name = "llama3"
+ name = "transformers_backend"
flavor = "debugmodel"
hf_assets_path = "./tests/assets/tokenizer"

+[hf_transformers]
+model = "Qwen/Qwen3-4B-Instruct-2507"
...
```

- Train: `LOG_RANK=7 CONFIG_FILE=<YOUR_PATH>/torchtitan/experiments/transformers_backend/configs/qwen3.toml ./run_train.sh --job.custom_config_module=torchtitan.experiments.transformers_backend.job_config --compile.enable`

<img width="1334" height="453" alt="image" src="https://github.com/user-attachments/assets/da459448-027b-4af9-8176-6a3e433a272c" />

# Testing methodology

<img width="2672" height="2018" alt="image" src="https://github.com/user-attachments/assets/66d8689d-7ede-47e3-b389-d4fc1bdd70f7" />

- Following the [converging.md](https://github.com/pytorch/torchtitan/blob/main/docs/converging.md) guidelines, I am comparing the baseline `FSDP=2` vs `FSDP=2 & <other //-ism>`.
- More precisely, `test_hf_integration.py` is going to produce:

```bash
results/
|_ meta-llama
   |_ Llama-3.2-1B
      |_ debugmodel/
         |_ seed_checkpoint/
            |_ config.toml
            |_ seed.slurm
            |_ step-0/
            |_ ....
         |_ fsdp2_tp1_cp1_pp1/
            |_ config.toml
            |_ nd_parallelism.slurm
            |_ nd_parallelism.log
         |_ fsdp2_tp2_cp1_pp1/
            |_ config.toml
            |_ nd_parallelism.slurm
            |_ nd_parallelism.log
            |_ diff_baseline_vs_nd_parallelism.log
         |_ fsdp2_tp1_cp1_pp2/
            |_ config.toml
            |_ nd_parallelism.slurm
            |_ nd_parallelism.log
            |_ diff_baseline_vs_nd_parallelism.log
         |_ fsdp2_tp1_cp2_pp1/
            |_ config.toml
            |_ nd_parallelism.slurm
            |_ nd_parallelism.log
            |_ diff_baseline_vs_nd_parallelism.log
         |_ fsdp2_tp1_cp2_pp2/
            |_ config.toml
            |_ nd_parallelism.slurm
            |_ nd_parallelism.log
            |_ diff_baseline_vs_nd_parallelism.log
|_ full/
...
```

- Here is the grid search used to test the HF modelling:

```shell
#!/usr/bin/bash
model_names=(
    "meta-llama/Llama-3.2-1B"
    "microsoft/phi-2"
    "Qwen/Qwen2.5-7B"
    "mistralai/Mistral-7B-v0.1"
    "ByteDance-Seed/Seed-Coder-8B-Instruct"
    "Qwen/Qwen3-4B-Instruct-2507"
    "arcee-ai/AFM-4.5B"
    "ibm-granite/granite-3b-code-base-2k"
    "baidu/ERNIE-4.5-0.3B-Base-PT"
    "kyutai/helium-1-preview-2b"
    "allenai/OLMo-7B-hf"
    "mistralai/Ministral-8B-Instruct-2410"
)

for model_name in "${model_names[@]}"; do
    rm -rf slurm_results/${model_name}
    python test_hf_integration.py create_configs --model_name "$model_name" --out_dir slurm_results --flavor debugmodel
    python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel/seed_checkpoint --qos high
    while [ ! -f slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt ] || [ "$(cat slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt)" != "completed" ]; do
        echo "Waiting for seed checkpoint from ${model_name} to complete ..."
        sleep 1
    done
    python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel --qos high
    echo "================"
done
```

# Further tasks

- MoE (handled in PR huggingface#3)
  - Missing `build_optimizers_with_moe_load_balancing` support for MoE
  - Missing TP/PP/EP support for MoE
- When using the HF modeling, in the `FSDP=2 vs FSDP=2 + PP=2` test, the `loss` and `grad_norm` are not bitwise matching (but converging), while they are with the Torchtitan modeling. (Issue tracked in huggingface#4.)
- Add convergence tests to CI using a tiny model + gloo backend (once PP is bitwise matching).
- The HF modeling has lower MFU than the Torchtitan modeling.
- NOTE: use `import torch._dynamo.config; torch._dynamo.config.cache_size_limit = 128` to avoid graph recompilation when using `torch.compile` with activation checkpointing.
**Summary** This PR adds variable length attention (varlen) support to the Llama 3 8b model in torchtitan. We replace `use_flex_attn` with `attn_type` (either "sdpa", "varlen", "flex"). If `attn_type = "varlen"`, the attention module calls a compiled `varlen_attn` defined [here](https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/varlen.py). **Testing** Ran loss and performance tests against flex attention. Loss is on par. <img width="947" height="505" alt="Screenshot 2025-11-19 at 3 24 26 PM" src="https://github.com/user-attachments/assets/d85dfc09-4f5e-4f82-abc9-49b870b34990" /> Varlen is slightly slower than Flex due to the cuda kernel speeds (varlen calls into `flash_attention_forward`/`flash_attention_backward` today). | | Varlen | Flex | | :---: | :------ | :---: | | Forward | 774us 357ns | 722us 317ns | | Backward | 1ms 955us 916ns | 1ms 558us 747ns |
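Varlen attention consumes packed sequences plus cumulative sequence lengths (`cu_seqlens`) instead of a padded batch. A minimal pure-Python sketch of building `cu_seqlens` from per-document lengths (the real path builds a torch tensor and hands it to the compiled `varlen_attn`):

```python
# Sketch of cu_seqlens construction for varlen attention: entry i is
# the start offset of document i in the packed token buffer, and the
# last entry is the total token count.
def cu_seqlens(doc_lens):
    out = [0]
    for n in doc_lens:
        out.append(out[-1] + n)
    return out
```

For example, three packed documents of lengths 3, 5, and 2 yield offsets `[0, 3, 8, 10]`, which is how the kernel knows where each document's attention mask begins and ends without padding.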
Rebased LLEP (Least-Loaded Expert Parallelism) from the old phuc/kimi_k2_with_autotune_llep_optimized_llep branch onto the latest upstream-2026-10-02 to resolve 86 merge conflicts caused by 1000+ upstream commits since the original branch point. New LLEP core files: - torchtitan/distributed/llep.py: dispatch/combine with LPT routing - torchtitan/distributed/llep_autotune.py: hyperparameter autotuning - torchtitan/distributed/llep_kernels.py: Triton kernels Integration points (surgical changes to upstream files): - moe.py: LLEPConfig, fast_init_*, llep_state in GroupedExperts.forward - expert_parallel.py: ExpertParallelLLEP class - job_config.py: LLEP config dataclass - llama4/parallelize.py: LLEP EP selection logic - deepseek_v3/__init__.py: LLEP model variants - train.py: LLEP autotune at startup
The [llep] TOML section (enabled, max_tokens_factor, etc.) was not being applied to moe_args in DeepSeekV3ModelArgs.update_from_config(), so LLEP was never actually activated. This caused OOM on the imbalanced GPU since standard EP doesn't balance memory across ranks.
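The fix amounts to copying the parsed `[llep]` section onto the MoE args inside `update_from_config`. An illustrative sketch with stand-in dataclasses; the field and section names follow the commit message, but the dataclasses are not the actual `DeepSeekV3ModelArgs` internals:

```python
# Sketch of the fix: apply the [llep] TOML section to moe_args during
# update_from_config, so llep.enabled actually takes effect.
from dataclasses import dataclass, field

@dataclass
class LLEPConfig:
    enabled: bool = False
    max_tokens_factor: float = 1.0

@dataclass
class MoEArgs:
    llep: LLEPConfig = field(default_factory=LLEPConfig)

def update_from_config(moe_args, job_config):
    # Previously the [llep] section was parsed but never copied over,
    # leaving moe_args.llep at its defaults (LLEP never activated).
    for k, v in job_config.get("llep", {}).items():
        setattr(moe_args.llep, k, v)
    return moe_args
```

With this in place, `llep.enabled = true` in the TOML flows through to the model args instead of being silently dropped.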
- Add debugmodel_ep8_llep_3b flavor (1.75B params, 64 experts, EP=8) and debug_model_ep8_llep_3b.toml for benchmarking on new upstream (the 9.5B config OOMs due to upstream memory regression) - Copy 3 missing multinode Kimi K2 LLEP+Muon TOML configs from old branch
Keep only essential files: - docs/llep.md (main documentation) - debug_model_llep.toml (2-GPU smoke test) - debug_model_ep8_llep_3b.toml (8-GPU benchmark) - test_llep_toml_override.toml (unit test config) - debugmodel_llep, debugmodel_ep8_llep_3b model flavors Removed: optimization report, pr008 cleanup doc, loss comparison scripts, baseline/stresstest/mini_kimi/kimi_k2/multinode TOMLs, and their corresponding model flavor definitions.
- Get EP process group from experts' DTensor device_mesh["ep"] instead of non-existent MoE._ep_group attribute - Use moe_module.use_llep and moe_module._llep_config instead of non-existent _llep_enabled/_llep_max_tokens_factor attributes
When both expert_parallel_comm_backend="deepep" and llep.enabled=true, uses DeepEP for balanced steps and falls back to LLEP for imbalanced steps based on adaptive_threshold. New classes: - DeepEPLLEPExpertParallel: per-step dispatch/combine hook that checks imbalance ratio and routes to DeepEP or LLEP path accordingly - DeepEPLLEPMoE: MoE module that passes 5-tuple routing info to experts and handles async combine overlap with shared_experts Wiring: args.py derives moe_impl="deepep_llep" when both flags set, build_moe() creates DeepEPLLEPMoE, apply_moe_ep_tp() installs the adaptive expert parallel hooks.
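The per-step backend choice reduces to measuring token imbalance across ranks and comparing it to `adaptive_threshold`. A hedged sketch of that decision; the max-over-mean ratio definition is an assumption about how the hook measures imbalance:

```python
# Sketch of the adaptive dispatch decision: route a step through LLEP
# only when per-rank token counts are imbalanced enough, otherwise use
# DeepEP. Imbalance here is max(tokens)/mean(tokens), an assumed metric.
def choose_backend(tokens_per_rank, adaptive_threshold):
    mean = sum(tokens_per_rank) / len(tokens_per_rank)
    ratio = max(tokens_per_rank) / mean if mean > 0 else 1.0
    return "llep" if ratio > adaptive_threshold else "deepep"
```

A balanced step (ratio near 1.0) keeps DeepEP's fast path, while a hot-expert step exceeding the threshold falls back to LLEP's least-loaded routing.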
Replace the 3B debugmodel_ep8_llep_3b (too small for LLEP to help) with the 9.5B debugmodel_ep8_llep (dim=2048, moe_inter_dim=1536, 16 layers, lbs=8) that stresses GPU memory and shows LLEP's benefit. Benchmark results on 8xB200 (steps 5-20): - Speed: +10.9% mean TPS (26,370 vs 23,780) - Memory: 7 GiB spread vs 42 GiB, max 82% vs 97% - Without LLEP at lbs=10: OOM. With LLEP: runs fine. - Loss correctness: <0.001 diff by step 130
DeepEPLLEPMoE.forward() was identical to DeepEPMoE.forward(). The behavioral difference comes entirely from which ExpertParallel hooks get installed (DeepEPExpertParallel vs DeepEPLLEPExpertParallel), not from the MoE class itself.
4-config comparison on 9.5B model, 8xB200: - LLEP (standard): 26,250 TPS, 6 GiB spread - DeepEP+LLEP (adaptive): 25,820 TPS, 7 GiB spread - Standard EP: 21,940 TPS, 41 GiB spread - DeepEP only: 19,640 TPS, 52 GiB spread
The fused_silu_gate Triton kernel (fwd/bwd), FusedSiLUGate autograd class, and silu_gate_reference were not wired into the training path (llep.py never imported them). Delete the dead code from llep_kernels.py, remove test_llep_correctness.py (all tests were for the deleted kernel), and update docs/llep.md to reflect current state.
Move llep.py, llep_kernels.py, llep_autotune.py into torchtitan/distributed/llep/ and docs/llep.md as its README.md. All external imports (from torchtitan.distributed.llep import ...) remain unchanged since llep/__init__.py preserves the public API. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…rmal_ Delete fast_init_trunc_normal_ and fast_init_normal_ from moe.py since upstream utils.py already provides equivalent trunc_normal_ and normal_. Wrap init_weights() call in test with torch.no_grad() to match how the training pipeline invokes it. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
feat: Least-Loaded Expert Parallelism with new upstream