We should be able to control which passes run in the compiler. This PR uses the `compile.passes` config to specify a list of graph passes to apply to the captured gm. By default, no pass is applied. Users can specify which passes to apply; currently there are `autobucketing_reordering_pass` and `regional_inductor_pass`. ``` NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor ``` Also updated CI to include this new config.
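The name-to-pass lookup described above can be sketched as a small registry (the pass names mirror this PR; the registry, the placeholder pass bodies, and the `apply_passes` helper are illustrative, not the actual torchtitan code):

```python
# Illustrative sketch of a config-driven graph-pass registry.
# The pass names come from this PR; the bodies are placeholders.

def autobucketing_reordering_pass(gm):
    # placeholder: would bucket and reorder collectives on the captured graph
    return gm

def regional_inductor_pass(gm):
    # placeholder: would compile selected regions with inductor
    return gm

PASS_REGISTRY = {
    "autobucketing_reordering": autobucketing_reordering_pass,
    "regional_inductor": regional_inductor_pass,
}

def apply_passes(gm, pass_names):
    """Apply the configured passes in order; an empty list (the default) is a no-op."""
    for name in pass_names:
        if name not in PASS_REGISTRY:
            raise ValueError(f"unknown compile pass: {name!r}")
        gm = PASS_REGISTRY[name](gm)
    return gm
```

A comma-separated `--compile.passes` value would then map directly onto the `pass_names` argument.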
After some offline discussion, we've concluded that life would be easier if we put SimpleFSDP's checkpoint logic for `reshard_after_forward` into the compiler. The AC annotation part is borrowed from AP: [LINK](https://github.com/meta-pytorch/autoparallel/blob/main/autoparallel/activation_checkpointing.py#L69). **Trace and Loss Check** (all with torch.compile enabled) reshard_after_fwd = False 1. SAC + llama3 ([trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-05-06_rank0_trace.json)) <img width="768" height="115" alt="Screenshot 2025-10-30 at 4 28 59 PM" src="https://github.com/user-attachments/assets/e4e22335-2e3f-46c8-8def-a60d592fee0a" /> <img width="689" height="512" alt="Screenshot 2025-11-05 at 9 02 30 PM" src="https://github.com/user-attachments/assets/40a71316-a457-4e72-9002-cc8beea8f32c" /> 2. Full AC + llama3 [(trace)]() <img width="729" height="105" alt="Screenshot 2025-10-30 at 4 30 53 PM" src="https://github.com/user-attachments/assets/e8d63460-579b-4f0a-8504-851480e5b548" /> <img width="789" height="763" alt="Screenshot 2025-11-05 at 9 11 34 PM" src="https://github.com/user-attachments/assets/1a13d09e-04c4-4db9-99fe-cf10d24bf7f5" /> 3. No AC + llama3 [[trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-03-50_rank0_trace.json)] <img width="748" height="115" alt="Screenshot 2025-10-30 at 4 32 05 PM" src="https://github.com/user-attachments/assets/20104d24-9d45-4eba-b694-815e133b88d0" /> <img width="800" height="764" alt="Screenshot 2025-11-05 at 9 07 46 PM" src="https://github.com/user-attachments/assets/55b104ce-8ec1-4ed6-95e7-300e96ad55af" /> reshard_after_fwd = True 1. 
SAC + llama3 ([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-34-24_rank0_trace.json)) <img width="795" height="108" alt="Screenshot 2025-10-31 at 11 34 47 AM" src="https://github.com/user-attachments/assets/a3988f72-7e87-4e52-90f9-8bee840cd6f4" /> 2. Full AC + llama3 ([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-36-27_rank0_trace.json)) <img width="593" height="110" alt="Screenshot 2025-10-31 at 11 38 02 AM" src="https://github.com/user-attachments/assets/5ee61b2b-9600-4af8-9a24-61b3564f93ca" /> 3. No AC + llama3 ([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-02-44_rank0_trace.json)) <img width="701" height="109" alt="Screenshot 2025-10-31 at 11 43 04 AM" src="https://github.com/user-attachments/assets/576b28f6-dae4-4ff7-b005-57b0cf9ad7cc" />
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * pytorch#2002 * __->__ pytorch#2001 Add typing, credit to Claude.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * __->__ pytorch#2013 When full_dtensor is True, the compute_placement will be preserved. This means that `to_local()` won't be called for the FSDP-only case. The nD parallelism case (FSDP + TP) will error out, as we have not implemented this case. This argument doesn't affect the current simple_fsdp. We have verified the `full_dtensor=True` case with the full dtensor skeleton PR, which will be published once it is ready. **This is a reland PR of pytorch#2002. The previous one was broken during rebase.**
Summary: - we need to pass the global rank information to pytorch so that the pg name can include the pg information - this is necessary to differentiate the default pg's on different replicas - these need to be different because flight recorder matches collectives based on pg name as well - add ft training to the experiments folder; we'll move the remaining pieces of ft there gradually, but make new features available only through this folder --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1986). * pytorch#1988 * pytorch#1987 * __->__ pytorch#1986 Co-authored-by: Tushar Jain <tushar00jain@users.noreply.github.com>
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#2012 * __->__ pytorch#2011 It is not correct as JobConfig has changed.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * __->__ pytorch#2012 * pytorch#2011 Summary: The current configuration validation requires torchx and GPUs. It can waste time, resources, and energy. Polar bears are crying. Let's fix this by providing a dry run mode. This PR doesn't verify everything; in theory, we should be able to verify parallelism settings as well. This PR is just a start, but it at least lets us catch typos quickly.
As titled. `_clear_traced_params_buffers` is no longer being used as we have switched the dynamo graph capture API.
…nly (pytorch#2016) Addressing the following issues in this PR: - Running the Torchtitan ROCm workflow on a cron schedule & only on pushes to the main branch. The CUDA workflow will run as is. - Refactor the Torchtitan test run to address the older PR comment pytorch#1786 (comment)
…& push to Main branch only" (pytorch#2017) Reverts PR: pytorch#2016 Addressing the following issues in this PR: - Running the Torchtitan ROCm workflow on a cron schedule & only on pushes to the main branch. The CUDA workflow will run as is. - Refactor the Torchtitan test run to address the older PR comment pytorch#1786 (comment) Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
…2015) This PR adds utils to automatically check the training numerics (losses, grad norms) of two runs to verify that they are bitwise equivalent. The added script triggers two runs with user-defined configs, then loads the metrics saved during training and compares the numerics to verify bitwise equivalence. Currently we check losses and grad norms during training steps. For example, say we want to compare the numerics between the compiler toolkit with the aot_eager backend and eager on llama3-8B. ``` python torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py --ngpu 4 --config-file torchtitan/models/llama3/train_configs/llama3_8b.toml --dp-shard-degree 2 --tp-degree 2 ``` It'll run the `simple_fsdp` experiment without `torch.compile` as the eager baseline, and the `compiler_toolkit` experiment as the compiled run, then compare the training numerics of the two runs to verify bitwise equivalence. When they are bitwise equivalent, we'll see the following output ``` Starting training: simple_fsdp.llama3 ✓ Training completed: simple_fsdp.llama3 Starting training: compiler_toolkit.llama3 ✓ Training completed: compiler_toolkit.llama3 ✓ PASS: All 11 steps match exactly (bitwise equivalent) ✓ PASS: All 11 steps match exactly (bitwise equivalent) ✓ SUCCESS: All metrics are bitwise equivalent ``` Also added unit tests in `compiler_toolkit/tests/test_numerics.py` so that we can guard working parallelism combinations that already have bitwise equivalence in CI.
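The core bitwise check can be sketched as follows (the function name and metric layout are illustrative; the real logic lives in `compiler_toolkit/scripts/check_numerics.py`):

```python
def assert_bitwise_equal(baseline, compiled):
    """Compare per-step metrics (e.g. losses) from two runs for exact equality.

    Bitwise equivalence means the floats must match exactly, so we compare
    with ==, not with a tolerance like math.isclose.  Returns the number of
    matching steps on success.
    """
    if len(baseline) != len(compiled):
        raise AssertionError("runs have different step counts")
    mismatches = [
        (step, b, c)
        for step, (b, c) in enumerate(zip(baseline, compiled), start=1)
        if b != c
    ]
    if mismatches:
        raise AssertionError(
            f"{len(mismatches)} steps differ, first mismatch: {mismatches[0]}"
        )
    return len(baseline)
```

The same check is applied once per metric (losses, then grad norms), which is why the PASS line appears twice in the output above.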
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#2029 * pytorch#2030 * pytorch#2028 * pytorch#2027 * __->__ pytorch#2026 As title
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#2029 * pytorch#2030 * pytorch#2028 * __->__ pytorch#2027 * pytorch#2026 Dry run mode works but it doesn't exit gracefully for all cases. This PR fixes it ``` DRY_RUN=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.steps=10 --activation_checkpoint.mode="none" --debug.deterministic --debug.seed=42 ```
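The graceful-exit pattern for a dry-run mode can be sketched like this (the env-var name is from the command above; the function and the `validate` hook are hypothetical, not the actual torchtitan code):

```python
import os

def maybe_dry_run(validate, argv):
    """If DRY_RUN is set, only validate the config, then signal the caller
    to return without launching training.

    `validate` is a hypothetical callable that parses and validates the job
    config (raising on a bad config).  Returning cleanly instead of calling
    sys.exit() lets every code path unwind gracefully.
    """
    if os.environ.get("DRY_RUN", "0") not in ("", "0"):
        validate(argv)   # catches typos without needing GPUs
        return True      # caller should skip training
    return False
```

The trainer entry point would call `maybe_dry_run` first and return early when it yields `True`.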
…h#2030) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#2029 * __->__ pytorch#2030 The current CompileModule results in an "inner" prefix for everything. This PR fixes it by overloading the methods. Also merged pytorch#2028 into this PR; something went wrong with ghstack.
Fixes the badge in the `README.md` file
Before: <img width="978" height="93" alt="image" src="https://github.com/user-attachments/assets/48dc39d9-e897-4396-ac62-025574303403" /> After: <img width="1318" height="82" alt="image" src="https://github.com/user-attachments/assets/47b4771a-aaf9-4f61-80bc-757f3a08c1d2" />
This PR adds support for aten-level manual bucketing in the SimpleFSDP+`aot_eager` backend. Dependent on PyTorch [PR](pytorch/pytorch#165487) TODO List: - [ ] We should have a better way of handling region info than a list of str FQNs in the current `manual_bucketed_modules`; it would be very easy to miss some of the model modules. (cc. @xmfan @SherlockNoMad ) - [ ] Currently, the reordering happens under the hood and overlaps with the last/next compute. We should allow users to specify which module they want to reorder. - [ ] Loss difference on multi-node training - [ ] DSV3 manual bucketing I'll address the TODO items in follow-up PRs. Let's start with this simple FSDP+TP+llama3 PR. 1. Performance (FSDP2 under eager mode, SimpleFSDP uses the `aot_eager` backend) **Llama 3-8B** * Performance (All Batch_size = 1). (The slower TPS on a single node is sort of expected, since FSDP2 handles copy-in/out in two different streams, whereas SimpleFSDP handles copy-in/out in the same stream) |Node| Method | Parallelism | Memory | TPS | Trace| |---------|---------|-----------|----------|------|------| |1-Node (8H100)|SimpleFSDP | FSDP=8| 40.96GiB(43.12%) | 7,227| [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-16-10-48-48_rank0_trace.json)| |1-Node (8H100)|FSDP2-eager| FSDP=8| 47.82GiB(50.35%) | 7,380 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-16-10-54-14_rank0_trace.json)| |8-Node (64H100)|SimpleFSDP| FSDP=64 | 29.37GiB | 4,984| | |8-Node (64H100)|FSDP2| FSDP=64 | 31.41GiB |5,097 | | |1-Node (8H100)|SimpleFSDP| FSDP=4 TP=2 | 28.28GiB(29.77%) | 5,881 | 
[LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-26-18-00-18_rank0_trace.json) | |1-Node (8H100)|FSDP2| FSDP=4 TP=2 | 35.33GiB(37.20%) | 5,898 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-26-15-35-47_rank0_trace.json) | |8-Node (64H100)|SimpleFSDP| FSDP=8 TP=8 | ||| |8-Node (64H100)|FSDP2| FSDP=8 TP=8 | ||| Example SimpleFSDP 1D overlapping trace: <img width="1127" height="127" alt="Screenshot 2025-10-16 at 10 49 55 AM" src="https://github.com/user-attachments/assets/2d9e3ff8-8e9b-40a7-a666-3c0a0975186e" /> Example SimpleFSDP 2D overlapping trace: <img width="1162" height="166" alt="Screenshot 2025-10-26 at 6 00 51 PM" src="https://github.com/user-attachments/assets/bc5cc031-5b6c-4e4d-a9da-70c43114f49a" /> - Bitwise Loss: FSDP-only: <img width="1266" height="837" alt="Screenshot 2025-10-17 at 10 41 56 AM" src="https://github.com/user-attachments/assets/30f83d95-1eca-4f10-9e7e-47c45278cd8d" /> FSDP+TP: <img width="1259" height="808" alt="Screenshot 2025-10-26 at 9 03 58 PM" src="https://github.com/user-attachments/assets/b75b452b-adb9-4078-9412-ee9e584ffe15" />
The current `convert_to_hf.py` does not support `export_dtype`, which makes it `float32` by default. This PR adds support for export dtypes of `["float16", "bfloat16", "float32"]`.
This PR integrates the changes in pytorch#1970 into the compiler toolkit (applying `joint_ac_pass` on the joint graph to tag nodes based on the `reshard_after_forward` flag). Also refactored how graph passes are applied in the compiler toolkit experiments. We now have two kinds of passes: 1. joint_custom_passes: these are passes applied on the captured joint graph before the partitioner. By default we apply `validate_flex_attn_annotation_pass` and `fsdp_reshard_after_fwd_pass` 2. compiler_passes: these are passes applied on the partitioned fwd and bwd graphs as backend optimizations. By default there is none; we can specify `autobucketing_reordering_pass` and `regional_inductor_pass` using configs.
…ytorch#2056) This PR integrates the manual bucketing pass (transformer block bucketing) added in the SimpleFSDP experiment (pytorch#1881) into the compiler toolkit. The compiler toolkit can now also run the manual bucketing pass by specifying the config ``` NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing ``` Also updated the README and integration test to include the newly ported pass.
…h only (pytorch#2018) Addressing the following issues in this PR: Running the Torchtitan ROCm workflow on a cron schedule & only on pushes to the main branch. The CUDA workflow will run as is. Refactor the Torchtitan test run to address the older PR comment pytorch#1786 (comment)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#2049 * __->__ pytorch#2029 ## Summary This PR adds `scripts/loss_compare.py` for comparing training losses between different git commits and/or training configurations. ## Key Features - Commit Comparison: Compare losses between two different git commits with deterministic training - Configuration Comparison: Compare different training configurations on the same commit - Reproducibility: Automatically enables deterministic mode and seed checkpointing for reproducible comparisons - Real-time Output: Streams training output to both console and log files during execution - Statistical Analysis: Generates step-by-step loss comparisons and summary statistics - CI Testing: Includes --assert-equal flag for automated testing to verify identical losses ## Usage Examples #### Compare two commits ``` python3 ./scripts/loss_compare.py main my_branch ``` #### Compare two commits with custom configuration ``` python3 ./scripts/loss_compare.py main my_branch \ --baseline-config="./custom.toml" --baseline-options="--parallelism.tensor_parallel_degree=2" \ ``` #### Compare different parallelization strategies on same commit ``` python3 ./scripts/loss_compare.py . . \ --baseline-config="./llama3_8b.toml" --baseline-options="--parallelism.tensor_parallel_degree=2" \ --test-options="--parallelism.tensor_parallel_degree=1" \ ``` #### Assert equality for CI testing ``` python3 ./scripts/loss_compare.py main my_branch --assert-equal ``` ## Real Use Cases Compare full dtensor simple fsdp with fsdp2: ``` python3 scripts/loss_compare.py . . \ --baseline-options='--activation_checkpoint.mode="none"' \ --test-train-file='torchtitan.experiments.full_dtensor.train' \ --test-options='--model.name full_dtensor.llama3 --activation_checkpoint.mode="none"' \ --assert-equal --no-seed-checkpoint [LOSS_COMPARE] [LOSS_COMPARE] Asserting losses are equal... 
[LOSS_COMPARE] Baseline log: /tmp/baseline_training.log [LOSS_COMPARE] Test log: /tmp/test_training.log [LOSS_COMPARE] Extracted 100 steps from baseline log [LOSS_COMPARE] Extracted 100 steps from test log test_losses_equal (__main__.assert_losses_equal.<locals>.LossEqualityTest.test_losses_equal) ... ok ```
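The "Extracted 100 steps" stage above can be sketched as a small log parser (the log-line shape and function name are assumptions for illustration; the real parsing lives in `scripts/loss_compare.py`):

```python
import re

# Assumed log-line shape, e.g. "step: 10  loss:  3.7451"; the real
# torchtitan log format may differ -- this only illustrates the idea.
_STEP_LOSS = re.compile(r"step:\s*(\d+).*?loss:\s*([0-9.]+)")

def extract_losses(log_text):
    """Return a {step: loss} dict parsed from a training log."""
    return {
        int(m.group(1)): float(m.group(2))
        for m in _STEP_LOSS.finditer(log_text)
    }
```

Once both logs are reduced to `{step: loss}` dicts, the equality assertion is a plain comparison of the two dicts, step by step.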
All tests in experiments are broken due to the `gpu_arch_type` field added in pytorch#2018.
…#2064) Adding the CudaGraph pass (pytorch#2050) would require some custom logic in Trainer's close() method, so we create a Trainer subclass in the compiler toolkit.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#2063 * __->__ pytorch#2062 This will prevent errors when later doing git checkout
## Features - [x] Support SimpleFSDP and TP - [x] Support static input indices to reduce copy - [x] Support memory reuse to reduce memory consumption - [x] Cleanup cudagraph when training finishes to avoid nccl hang from destroy_process_group Command: ``` NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes cudagraph ``` Note: we use `NCCL_GRAPH_REGISTER=0` due to a known issue that nccl + cudagraphs + expandable segments result in IMA. pytorch/pytorch#158029 [trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces%2Ftree%2Fshared_trace%2Fboyuan_e1ef464b-ee61-4c61-82e5-f7a485e561bf_rank0_trace.json) ## Result **Numerics:** Achieved bitwise equivalence w/ and w/o cudagraph pass on llama3.1-8B AND llama3.1-70B. **Performance:** <img width="560" height="90" alt="image" src="https://github.com/user-attachments/assets/9d54c461-0eb1-4f7e-9652-3d52043ad74f" /> Raw log: [llama3-8b](https://www.internalfb.com/phabricator/paste/view/P2045444190), [llama3-70b](https://www.internalfb.com/phabricator/paste/view/P2045567416) **Memory:** On llama3.1-70b, cudagraph takes 6% more memory consumption (143 GiB vs 153 GiB). A few tricks to reduce memory consumption (use llama3.1-70b w/ cudagraph as an example): - Start: 161 GiB - \+ use the same stream for warmup and graph capture of both fwd and bwd: 160 GiB - \+ warmup in cudagraph memory pool instead of eager memory pool: 153 GiB **static input copy:** On llama3.1-70B, for forward, we copy 1 tensor of 128 bytes; for backward, we copy 1 tensor of 0.98 GB. 
This shows static input indices is handled correctly. ## Followup PR In the followup PR, I will enable fx graph partition for deepseek v3 pytorch/pytorch#165945.
This PR fixes access to args: it's an attribute, not a variable in the scope. The method itself, though, would not be used, because `should_check_address` seems to always be `False` and there doesn't seem to be a command-line argument for it. Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
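The bug pattern this fixes can be shown in miniature (the class and field names here are made up, not the actual code):

```python
class Checker:
    """Toy stand-in for a class that stores its CLI args at construction."""

    def __init__(self, args):
        self.args = args  # stored as an attribute, not a module-level name

    def should_check(self):
        # The buggy version referenced a bare `args`, which raises NameError
        # inside the method body; the fix is to go through `self.args`.
        return self.args.get("check_address", False)
```

Referencing `args` without `self.` only works if a variable of that name happens to exist in the enclosing scope, which is why the bug can hide until the method is actually called.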
# Context Reference PR: huggingface#1 This PR enables: - Llama-like HF models to work with 4D parallelism: FSDP, CP, TP, PP (and the combinations between them). The following models were tested: - `meta-llama/Llama-3.2-1B` - `microsoft/phi-2` - `Qwen/Qwen2.5-7B` - `mistralai/Mistral-7B-v0.1` - `ByteDance-Seed/Seed-Coder-8B-Instruct` - `Qwen/Qwen3-4B-Instruct-2507` - `arcee-ai/AFM-4.5B` - `ibm-granite/granite-3b-code-base-2k` - `baidu/ERNIE-4.5-0.3B-Base-PT` - `kyutai/helium-1-preview-2b` - `allenai/OLMo-7B-hf` - `mistralai/Ministral-8B-Instruct-2410` - Patching HF model weights initialisation. Without this, the `loss` and `grad_norm` start very high # Usage - Requirements `transformers==4.57.1` - Config: `torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3.toml` ```diff ... [model] - name = "llama3" + name = "transformers_backend" flavor = "debugmodel" hf_assets_path = "./tests/assets/tokenizer" +[hf_transformers] +model = "Qwen/Qwen3-4B-Instruct-2507" ... ``` - Train: `LOG_RANK=7 CONFIG_FILE=<YOUR_PATH>/torchtitan/experiments/transformers_backend/configs/qwen3.toml ./run_train.sh --job.custom_config_module=torchtitan.experiments.transformers_backend.job_config --compile.enable` <img width="1334" height="453" alt="image" src="https://github.com/user-attachments/assets/da459448-027b-4af9-8176-6a3e433a272c" /> # Testing methodology <img width="2672" height="2018" alt="image" src="https://github.com/user-attachments/assets/66d8689d-7ede-47e3-b389-d4fc1bdd70f7" /> - Following the [converging.md](https://github.com/pytorch/torchtitan/blob/main/docs/converging.md) guidelines, I am comparing the baseline `FSDP=2` vs `FSDP=2 & <other //-ism>` - More precisely, `test_hf_integration.py` is going to do: ```bash results/ |_ meta-llama |_ Llama-3.2-1B |_ debugmodel/ |_ seed_checkpoint/ |_ config.toml |_ seed.slurm |_ step-0/ |_ .... 
|_ fsdp2_tp1_cp1_pp1/ |_ config.toml |_ nd_parallelism.slurm |_ nd_parallelism.log |_ fsdp2_tp2_cp1_pp1/ |_ config.toml |_ nd_parallelism.slurm |_ nd_parallelism.log |_ diff_baseline_vs_nd_parallelism.log |_ fsdp2_tp1_cp1_pp2/ |_ config.toml |_ nd_parallelism.slurm |_ nd_parallelism.log |_ diff_baseline_vs_nd_parallelism.log |_ fsdp2_tp1_cp2_pp1/ |_ config.toml |_ nd_parallelism.slurm |_ nd_parallelism.log |_ diff_baseline_vs_nd_parallelism.log |_ fsdp2_tp1_cp2_pp2/ |_ config.toml |_ nd_parallelism.slurm |_ nd_parallelism.log |_ diff_baseline_vs_nd_parallelism.log` |_ full/ ... ``` - Here is the grid search to test the HF modelling ```shell #!/usr/bin/bash model_names=( "meta-llama/Llama-3.2-1B" "microsoft/phi-2" "Qwen/Qwen2.5-7B" "mistralai/Mistral-7B-v0.1" "ByteDance-Seed/Seed-Coder-8B-Instruct" "Qwen/Qwen3-4B-Instruct-2507" "arcee-ai/AFM-4.5B" "ibm-granite/granite-3b-code-base-2k" "baidu/ERNIE-4.5-0.3B-Base-PT" "kyutai/helium-1-preview-2b" "allenai/OLMo-7B-hf" "mistralai/Ministral-8B-Instruct-2410" ) for model_name in "${model_names[@]}"; do rm -rf slurm_results/${model_name} python test_hf_integration.py create_configs --model_name "$model_name" --out_dir slurm_results --flavor debugmodel python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel/seed_checkpoint --qos high while [ ! -f slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt ] || [ "$(cat slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt)" != "completed" ]; do echo "Waiting for seed checkpoint from ${model_name} to complete ..." 
sleep 1 done python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel --qos high echo "================" done ``` # Further tasks - MoE (handled in PR huggingface#3) - Missing `build_optimizers_with_moe_load_balancing` support for MoE - Missing TP/PP/EP support for MoE - When using the HF modeling, in the test `FSDP=2 vs FSDP=2 + PP=2`, the `loss` and `grad_norm` are not bitwise matching (but converging), while they are with the Torchtitan modeling. (issue is tracked in huggingface#4) - Add convergence tests to CI using a tiny model + gloo backend (once PP is bitwise matching) - the HF modeling has lower MFU than the Torchtitan MFU - NOTE: `import torch._dynamo.config; torch._dynamo.config.cache_size_limit = 128` to avoid graph recompilation when using `torch.compile` and `activation checkpointing`
**Summary** This PR adds variable length attention (varlen) support to the Llama 3 8b model in torchtitan. We replace `use_flex_attn` with `attn_type` (either "sdpa", "varlen", "flex"). If `attn_type = "varlen"`, the attention module calls a compiled `varlen_attn` defined [here](https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/varlen.py). **Testing** Ran loss and performance tests against flex attention. Loss is on par. <img width="947" height="505" alt="Screenshot 2025-11-19 at 3 24 26 PM" src="https://github.com/user-attachments/assets/d85dfc09-4f5e-4f82-abc9-49b870b34990" /> Varlen is slightly slower than Flex due to the cuda kernel speeds (varlen calls into `flash_attention_forward`/`flash_attention_backward` today). | | Varlen | Flex | | :---: | :------ | :---: | | Forward | 774us 357ns | 722us 317ns | | Backward | 1ms 955us 916ns | 1ms 558us 747ns |
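The `attn_type` switch replacing the old `use_flex_attn` boolean can be sketched as a simple dispatch (the three names come from this PR; the handler functions here are placeholders for the real SDPA/varlen/flex attention calls):

```python
def _sdpa(q, k, v):    # placeholder for F.scaled_dot_product_attention
    return "sdpa"

def _varlen(q, k, v):  # placeholder for the compiled varlen_attn call
    return "varlen"

def _flex(q, k, v):    # placeholder for flex_attention
    return "flex"

ATTN_IMPLS = {"sdpa": _sdpa, "varlen": _varlen, "flex": _flex}

def attention(attn_type, q, k, v):
    """Dispatch on the configured attn_type string, replacing the
    old two-way use_flex_attn flag with a three-way choice."""
    try:
        impl = ATTN_IMPLS[attn_type]
    except KeyError:
        raise ValueError(f"unknown attn_type: {attn_type!r}") from None
    return impl(q, k, v)
```

A string-keyed dispatch keeps the config surface open: a future backend only needs a new registry entry, not another boolean flag.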
Added troubleshooting tip for missing libnvshmem_host.so.
* Support mxfp8 on gfx950. It depends on TorchAO (pytorch/ao#3620).
Update README with libnvshmem_host.so troubleshooting
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom): * pytorch#2260 * pytorch#2246 * __->__ pytorch#2245 As title
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom): * pytorch#2260 * __->__ pytorch#2246 * pytorch#2245
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom): * __->__ pytorch#2260 * pytorch#2246 * pytorch#2245 This API is not needed any more. pyrefly will give us a warning.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom): * pytorch#2145 * __->__ pytorch#2144 **Summary** 1. Refactored CP Dispatching: - New apply_cp() function uses PyTorch's _ContextParallel parallelization plan to dispatch the attention call. - Enables the CP dispatcher for the SDPA attention type inside apply_cp() 2. New CP Data Sharding Approach: - Added a cp_shard() helper function that wraps PyTorch's _context_parallel_shard API - Uses _HeadTailLoadBalancer for SDPA attention load balancing - FlexAttention CP support deferred to a future PR - CP sharding now happens explicitly in post_dataloading_process(), where inputs, labels, and positions are sharded - The new positions argument allows us to not shard the freqs_cis. Note that this PR requires pytorch/pytorch#170200 **Test** ``` -> % python3 scripts/loss_compare.py . chienchin/loss_compare --baseline-options="--parallelism.context_parallel_degree=8" --test-options="--parallelism.context_parallel_degree=8" --steps=100 --assert-equal pick 5903566a Improve the loss_compare.sh logic [LOSS_COMPARE] [LOSS_COMPARE] Asserting losses are equal... [LOSS_COMPARE] Baseline log: /tmp/baseline_training.log [LOSS_COMPARE] Test log: /tmp/test_training.log [LOSS_COMPARE] Extracted 100 steps from baseline log [LOSS_COMPARE] Extracted 100 steps from test log test_losses_equal (__main__.assert_losses_equal.<locals>.LossEqualityTest.test_losses_equal) ... ok ---------------------------------------------------------------------- Ran 1 test in 0.000s OK [LOSS_COMPARE] All losses are equal. Assertion passed! 
[LOSS_COMPARE] ========================================== [LOSS_COMPARE] LOSS COMPARISON ANALYSIS [LOSS_COMPARE] ========================================== [LOSS_COMPARE] Step-by-step loss comparison: [LOSS_COMPARE] Step Baseline Loss Test Loss Difference [LOSS_COMPARE] ---- ------------- --------- ---------- [LOSS_COMPARE] 1 8.1309 8.1309 0.000000 [LOSS_COMPARE] 2 7.8268 7.8268 0.000000 [LOSS_COMPARE] 3 7.2284 7.2284 0.000000 [LOSS_COMPARE] 4 6.4669 6.4669 0.000000 [LOSS_COMPARE] 5 5.4017 5.4017 0.000000 [LOSS_COMPARE] 6 4.7656 4.7656 0.000000 [LOSS_COMPARE] 7 4.3587 4.3587 0.000000 [LOSS_COMPARE] 8 4.0938 4.0938 0.000000 [LOSS_COMPARE] 9 4.4019 4.4019 0.000000 [LOSS_COMPARE] 10 3.7451 3.7451 0.000000 .... [LOSS_COMPARE] 90 2.802 2.802 0.000000 [LOSS_COMPARE] 91 2.7207 2.7207 0.000000 [LOSS_COMPARE] 92 2.7454 2.7454 0.000000 [LOSS_COMPARE] 93 2.6992 2.6992 0.000000 [LOSS_COMPARE] 94 2.743 2.743 0.000000 [LOSS_COMPARE] 95 2.7534 2.7534 0.000000 [LOSS_COMPARE] 96 2.8403 2.8403 0.000000 [LOSS_COMPARE] 97 2.783 2.783 0.000000 [LOSS_COMPARE] 98 3.0892 3.0892 0.000000 [LOSS_COMPARE] 99 2.7905 2.7905 0.000000 [LOSS_COMPARE] 100 2.733 2.733 0.000000 [LOSS_COMPARE] [LOSS_COMPARE] Summary statistics: [LOSS_COMPARE] Average baseline loss: 3.1414940000000002 [LOSS_COMPARE] Average test loss: 3.1414940000000002 [LOSS_COMPARE] Average difference: 0.000000 [LOSS_COMPARE] [LOSS_COMPARE] Loss comparison complete. No results saved (no output folder specified). ``` **TODO** - This PR will invalidate torch.compile + CP due to pytorch/pytorch#170110. We will have to wait for Dynamo to fix the issue or refactor nn.Module core logic to avoid check hook_id.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom): * __->__ pytorch#2145 Summary: Continuing the previous PR, this PR enables FlexAttention + CP for llama3. FlexCP will use PTRRLoadBalancer. Note that this PR requires pytorch/pytorch#170201
Previously, individual experts were marked as `Replicate` in the EP dimension of the global `global_device_mesh`. Local experts are first created on `global_device_mesh` and then turned into a 2D tensor using `squeeze(0)`, which only removes the extra dimension; the leftover `Replicate` metadata is still there. The wrong metadata causes a bug when DCP saves the `DTensor`. This PR fixes the bug by: 1. Using a sub-mesh that excludes the expert dimension, i.e., dim 0. 2. Using a plain tensor instead of `DTensor` when the sub-mesh is empty.
Added installation of transformers package and updated sbatch script instructions.
`is_causal` flag has been deprecated in `varlen_attn`, use `window_size = [-1, 0]` instead, see [this PR](pytorch/pytorch#172245) *Test* <img width="1225" height="496" alt="Screenshot 2026-01-21 at 11 55 59 AM" src="https://github.com/user-attachments/assets/125c56af-76ed-4f6f-a433-da4d9a631dab" />
This PR adds ROCm CI support for the SimpleFSDP experiments test.
…nfigs Memory Tracking Tools: - Add DetailedMemoryTracker for per-phase memory tracking (before_forward, after_forward_backward, after_optimizer, step_end) - Add CUDAMemoryTracker for PyTorch vs nvidia-smi memory comparison - Add AggressiveMemoryManager for CUDA fragmentation reduction with modes: minimal, balanced, aggressive, maximum BF16 Optimizer States: - Add BF16StateOptimizersContainer wrapper for pre-initializing optimizer states in bfloat16 before first step (50% memory savings) - Add preinit_optimizer_states_bf16() to allocate exp_avg/exp_avg_sq in param dtype from the start, avoiding fp32 allocation spike - Fix device mismatch bug: state["step"] tensor now created on param device New Config Options: - optimizer.state_dtype: "float32" | "bfloat16" - training.enable_detailed_memory_tracking: bool - training.clear_cache_between_steps: bool - training.skip_optimizer_step: bool - training.aggressive_memory_mode: "minimal" | "balanced" | "aggressive" | "maximum" - training.aggressive_memory_verbose: bool Train Loop Integration: - Initialize memory trackers in Trainer.__init__ - Call tracking at forward_backward_step and train_step phases - Call aggressive memory manager at post_backward, post_optimizer, step_end - Pre-initialize BF16 optimizer states before training loop Configs Added: - qwen3_30b_a3b_memory_test.toml: Test config for memory features - kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml: 12-node production config - kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml: 36-node HSDP config
…reporting Config Options Added: [parallelism] - fsdp_disable_prefetch: Disable FSDP forward/backward prefetching (reduces memory at cost of less communication overlap) [debug] - enable_nan_tracker: Enable lightweight NaN/Inf tracking to find where NaN first appears in the model - nan_tracker_verbose: Print stats for every layer (very verbose) Enhanced Metrics: - Add nvidia-smi memory reporting to DeviceMemStats for verification - Add _get_nvidia_smi_memory() method to DeviceMemoryMonitor - Handles CUDA_VISIBLE_DEVICES remapping for SLURM environments FSDP Prefetch Control: - Add disable_prefetch parameter to apply_fsdp() in llama4 - Wire up fsdp_disable_prefetch config to apply_fsdp calls in: llama4, deepseek_v3, qwen3, gpt_oss Test Configs: - qwen3_30b_a3b_memory_test.toml: Added [debug] section with nan_tracker Note: fsdp_bucket_cap_mb not added as it's not supported by FSDP2 API
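The CUDA_VISIBLE_DEVICES remapping mentioned above can be sketched as a small helper (the function name is hypothetical; the real `_get_nvidia_smi_memory()` would shell out to `nvidia-smi` and use this index to pick the right row):

```python
import os

def physical_gpu_index(local_rank, env=os.environ):
    """Map a process-local device index to the physical GPU index that
    nvidia-smi reports.

    Under SLURM (or any launcher that sets CUDA_VISIBLE_DEVICES), device 0
    as seen by PyTorch may be, say, physical GPU 4, so nvidia-smi output
    must be indexed through the visible-devices list.
    """
    visible = env.get("CUDA_VISIBLE_DEVICES")
    if not visible:
        return local_rank  # no remapping in effect
    mapping = [int(x) for x in visible.split(",")]
    return mapping[local_rank]
```

Comparing PyTorch's allocator numbers against the nvidia-smi row selected this way is what lets the tracker flag memory held outside the PyTorch caching allocator.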
Allow fsdp_reshard_after_forward to accept an integer N for partial resharding to N-GPU groups. This reduces peak memory by limiting all-gather buffer size to N GPUs instead of full DP world. Use N=8 for intra-node resharding (fast NVLink communication). N must be a factor of the FSDP shard world size. Example: fsdp_reshard_after_forward = 8 Changes: - config/job_config.py: Update type to Literal[...] | int - llama4/parallelize.py: Handle int values in apply_fsdp()
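The factor constraint on N can be sketched as a small validator (a hypothetical helper, not the actual `parallelize.py` code):

```python
def resolve_reshard_after_forward(value, shard_world_size):
    """Normalize fsdp_reshard_after_forward.

    Booleans pass through unchanged; an int N (e.g. 8 for intra-node
    resharding over fast NVLink) must evenly divide the FSDP shard
    world size, since parameters are regrouped into N-GPU subgroups.
    """
    if isinstance(value, bool):  # check bool first: bool is a subclass of int
        return value
    if not isinstance(value, int) or value <= 0:
        raise ValueError(f"expected bool or positive int, got {value!r}")
    if shard_world_size % value != 0:
        raise ValueError(
            f"fsdp_reshard_after_forward={value} must be a factor of "
            f"the FSDP shard world size {shard_world_size}"
        )
    return value
```

Validating at config time turns a confusing runtime sharding failure into an immediate, readable error.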
Standalone diagnostic tool that visualizes GPU allocation across parallelism dimensions (DP, PP, TP, CP, EP). Only runs on rank 0 at initialization - no impact on training performance. Features: - Mesh structure visualization - GPU allocation grid - Expert parallel group allocation - Context parallel group allocation - FSDP sharding visualization Integrated into train.py to automatically log at startup.
…/torchtitan into upstream-2026-24-01