
Merge upstream 2026 10 02 #60

Merged
dmahan93 merged 199 commits into dev-updated-again from merge-upstream-2026-10-02
Mar 13, 2026
Conversation

@dmahan93

No description provided.

mori360 and others added 30 commits November 10, 2025 10:14
We should be able to control which passes run in the compiler. This PR
uses the config `compile.passes` to specify a list of graph passes to
apply to the captured gm.

By default, no pass is applied. Users can specify which passes to apply.

Currently there are `autobucketing_reordering_pass` and
`regional_inductor_pass`.

```
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor
```

Also updated CI to include this new config
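
For orientation, here is a hedged sketch of how a named pass list like the one in the command above could be applied to the captured graph module; the registry and `apply_compiler_passes` below are illustrative stand-ins, not the actual torchtitan code.

```python
from typing import Callable, Sequence

import torch.fx as fx


def noop_pass(gm: fx.GraphModule) -> fx.GraphModule:
    # A pass takes the captured graph module and returns a (possibly rewritten) one.
    return gm


# Hypothetical registry mapping the names accepted by --compile.passes to callables.
PASS_REGISTRY: dict[str, Callable[[fx.GraphModule], fx.GraphModule]] = {
    "autobucketing_reordering": noop_pass,
    "regional_inductor": noop_pass,
}


def apply_compiler_passes(gm: fx.GraphModule, pass_names: Sequence[str]) -> fx.GraphModule:
    # By default no pass runs; each configured name must exist in the registry.
    for name in pass_names:
        gm = PASS_REGISTRY[name](gm)
    gm.recompile()
    return gm
```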
After some offline discussion, we've concluded that life would be easier
if we move SimpleFSDP's checkpoint logic for `reshard_after_forward`
into the compiler. The AC annotation part is borrowed from AP:
[LINK](https://github.com/meta-pytorch/autoparallel/blob/main/autoparallel/activation_checkpointing.py#L69).
**Trace and Loss Check** (all with torch.compile enabled)

reshard_after_fwd = False
1. SAC + llama3
([trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-05-06_rank0_trace.json))
<img width="768" height="115" alt="Screenshot 2025-10-30 at 4 28 59 PM"
src="https://github.com/user-attachments/assets/e4e22335-2e3f-46c8-8def-a60d592fee0a"
/>

<img width="689" height="512" alt="Screenshot 2025-11-05 at 9 02 30 PM"
src="https://github.com/user-attachments/assets/40a71316-a457-4e72-9002-cc8beea8f32c"
/>


2. Full AC + llama3 [(trace)]()

<img width="729" height="105" alt="Screenshot 2025-10-30 at 4 30 53 PM"
src="https://github.com/user-attachments/assets/e8d63460-579b-4f0a-8504-851480e5b548"
/>

<img width="789" height="763" alt="Screenshot 2025-11-05 at 9 11 34 PM"
src="https://github.com/user-attachments/assets/1a13d09e-04c4-4db9-99fe-cf10d24bf7f5"
/>


3. No AC + llama3
[[trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-03-50_rank0_trace.json)]

<img width="748" height="115" alt="Screenshot 2025-10-30 at 4 32 05 PM"
src="https://github.com/user-attachments/assets/20104d24-9d45-4eba-b694-815e133b88d0"
/>

<img width="800" height="764" alt="Screenshot 2025-11-05 at 9 07 46 PM"
src="https://github.com/user-attachments/assets/55b104ce-8ec1-4ed6-95e7-300e96ad55af"
/>


reshard_after_fwd = True

1. SAC + llama3
([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-34-24_rank0_trace.json))

<img width="795" height="108" alt="Screenshot 2025-10-31 at 11 34 47 AM"
src="https://github.com/user-attachments/assets/a3988f72-7e87-4e52-90f9-8bee840cd6f4"
/>


2. Full AC + llama3
([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-36-27_rank0_trace.json))

<img width="593" height="110" alt="Screenshot 2025-10-31 at 11 38 02 AM"
src="https://github.com/user-attachments/assets/5ee61b2b-9600-4af8-9a24-61b3564f93ca"
/>


3. No AC + llama3
([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-02-44_rank0_trace.json))


<img width="701" height="109" alt="Screenshot 2025-10-31 at 11 43 04 AM"
src="https://github.com/user-attachments/assets/576b28f6-dae4-4ff7-b005-57b0cf9ad7cc"
/>
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* pytorch#2002
* __->__ pytorch#2001

Add typing, credit to Claude.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* __->__ pytorch#2013

When full_dtensor is True, the compute_placement will be preserved. This
means that `to_local()` won't be called for the FSDP-only case. The nD
parallelism case (FSDP + TP) will error out, as we have not implemented
it yet.

This argument doesn't affect the current simple_fsdp. We have verified
the `full_dtensor=True` case with the full dtensor skeleton PR, which will
be published once it is ready.

**This is a reland PR of
pytorch#2002. The previous one was
broken during rebase.**
Summary:
- we need to pass the global rank information to PyTorch so that the pg
name can include it
- this is necessary to differentiate the default pgs on different
replicas
- these need to be different because flight recorder matches collectives
based on pg name as well
- add ft training to the experiments folder; we'll move the remaining
pieces of ft there gradually, but make new features only available
through this folder

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1986).
* pytorch#1988
* pytorch#1987
* __->__ pytorch#1986

Co-authored-by: Tushar Jain <tushar00jain@users.noreply.github.com>
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* pytorch#2012
* __->__ pytorch#2011

It is not correct as JobConfig has changed.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* __->__ pytorch#2012
* pytorch#2011

Summary:
The current configuration validation requires torchx and GPUs. It can
waste time, resources, and energy. Polar bears are crying. Let's fix
this by providing a dry run mode. This PR doesn't verify everything; in
theory, we should be able to verify parallelism settings as well. This
PR is just a start, but it at least lets us catch typos quickly.
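
A minimal sketch of what the dry-run gate amounts to, assuming a `DRY_RUN` environment variable (as used later in this thread) and an illustrative validator; this is not the actual torchtitan entry point.

```python
import os
import sys
import tomllib


def parse_and_validate(config_path: str) -> dict:
    # A real validator would also check field names/types against JobConfig;
    # here we only make sure the TOML parses.
    with open(config_path, "rb") as f:
        return tomllib.load(f)


def main(config_path: str) -> None:
    job_config = parse_and_validate(config_path)
    if os.environ.get("DRY_RUN") == "1":
        # Dry run: stop after validation, before any torchx/GPU work.
        print(f"{config_path} validated OK (dry run)")
        sys.exit(0)
    # ... otherwise build the trainer and launch training ...
```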
As titled. `_clear_traced_params_buffers` is no longer used since we
have switched the dynamo graph capture API.
…nly (pytorch#2016)

Addressing the following issues in this PR:

- Running the Torchtitan ROCm workflow on a cron schedule & only on push to
the main branch. The CUDA workflow will run as-is.
- Refactor the Torchtitan test run to address an older PR comment:
pytorch#1786 (comment)
…& push to Main branch only" (pytorch#2017)

Reverts PR: pytorch#2016
Addressing the following issues in this PR:
- Running the Torchtitan ROCm workflow on a cron schedule & only on push to
the main branch. The CUDA workflow will run as-is.
- Refactor the Torchtitan test run to address an older PR comment:
pytorch#1786 (comment)

Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
…2015)

This PR adds utils to automatically check the training numerics
(losses, grad norms) of two runs and verify whether they are bitwise
equivalent.

The added script triggers two runs with user-defined configs. It then
loads the metrics saved during training and compares them to verify
bitwise equivalence. Currently we check losses and grad norms during
the training steps.

For example, suppose we want to compare the numerics between the compiler
toolkit with the aot_eager backend and eager mode on Llama 3 8B.
```
python torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py --ngpu 4 --config-file torchtitan/models/llama3/train_configs/llama3_8b.toml --dp-shard-degree 2 --tp-degree 2
```
It runs the `simple_fsdp` experiment without `torch.compile` as the eager
baseline and the `compiler_toolkit` experiment as the compiled run, then
compares the training numerics of the two runs to verify bitwise
equivalence.

When it is bitwise equivalent, we'll see the following output
```
Starting training: simple_fsdp.llama3
✓ Training completed: simple_fsdp.llama3

Starting training: compiler_toolkit.llama3
✓ Training completed: compiler_toolkit.llama3
  ✓ PASS: All 11 steps match exactly (bitwise equivalent)
  ✓ PASS: All 11 steps match exactly (bitwise equivalent)
✓ SUCCESS: All metrics are bitwise equivalent
```

Also added unit tests in `compiler_toolkit/tests/test_numerics.py` so
that we can guard, in CI, the parallelism combinations that already have
bitwise equivalence.
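
To make the comparison concrete, here is a minimal sketch of the per-metric check (names illustrative, not the actual `check_numerics.py` code):

```python
from typing import Sequence


def check_bitwise_equal(name: str, baseline: Sequence[float], test: Sequence[float]) -> bool:
    # Every training step must match exactly, i.e. bitwise equivalence of the
    # logged value (e.g. loss or grad_norm).
    assert len(baseline) == len(test), f"{name}: step count mismatch"
    mismatches = [i for i, (a, b) in enumerate(zip(baseline, test)) if a != b]
    if mismatches:
        print(f"  ✗ FAIL: {name} differs at steps {mismatches[:5]} ...")
        return False
    print(f"  ✓ PASS: All {len(baseline)} steps match exactly (bitwise equivalent)")
    return True
```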
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* pytorch#2029
* pytorch#2030
* pytorch#2028
* __->__ pytorch#2027
* pytorch#2026

Dry run mode works, but it doesn't exit gracefully in all cases. This PR
fixes that.

```
DRY_RUN=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.steps=10 --activation_checkpoint.mode="none" \
--debug.deterministic --debug.seed=42
```
…h#2030)

Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* pytorch#2029
* __->__ pytorch#2030

The current CompileModule will result in an "inner" prefix for everything.
This PR fixes it by overloading the relevant methods.

Also merged pytorch#2028 into this PR; something went wrong with ghstack.
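
A hedged reconstruction of the fix: a wrapper that stores the original model as `self.inner` would normally expose every FQN as `inner.<name>`; overloading the accessors keeps the original FQNs. This is a sketch, not the actual CompileModule code.

```python
import torch.nn as nn


class CompiledModuleSketch(nn.Module):
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, *args, **kwargs):
        return self.inner(*args, **kwargs)

    def named_parameters(self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True):
        # Delegate to the wrapped module so parameter FQNs match the original model.
        return self.inner.named_parameters(prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)

    def state_dict(self, *args, **kwargs):
        return self.inner.state_dict(*args, **kwargs)
```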
This PR adds support for aten-level manual bucketing in the
SimpleFSDP + `aot_eager` backend. Dependent on PyTorch
[PR](pytorch/pytorch#165487).

TODO List:
- [ ] We should have a better way of handling region info than the current
list of string FQNs in `manual_bucketed_modules`. It would be very easy to
miss some of the model modules. (cc. @xmfan @SherlockNoMad )
- [ ] Currently, the reordering happens under the hood and overlaps with the
last/next compute. We should allow users to specify which modules they
want to reorder.
- [ ] Loss difference on multi-node training
- [ ] DSV3 manual bucketing

I'll address the TODO items in follow-up PRs. Let's start with this
simple FSDP+TP+llama3 PR.

1. Performance (FSDP2 under eager mode, SimpleFSDP uses the `aot_eager`
backend)

**Llama 3-8B**

* Performance (all with batch_size = 1). (The slower TPS on a single node is
more or less expected, since FSDP2 handles copy-in/out in two different
streams, whereas SimpleFSDP handles copy-in/out in the same stream.)

|Node| Method | Parallelism | Memory | TPS | Trace|
|---------|---------|-----------|----------|------|------|
|1-Node (8H100)|SimpleFSDP | FSDP=8| 40.96GiB(43.12%) | 7,227| [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-16-10-48-48_rank0_trace.json)|
|1-Node (8H100)|FSDP2-eager| FSDP=8| 47.82GiB(50.35%) | 7,380 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-16-10-54-14_rank0_trace.json)|
|8-Node (64H100)|SimpleFSDP| FSDP=64 | 29.37GiB | 4,984| |
|8-Node (64H100)|FSDP2| FSDP=64 | 31.41GiB |5,097 | |
|1-Node (8H100)|SimpleFSDP| FSDP=4 TP=2 | 28.28GiB(29.77%) | 5,881 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-26-18-00-18_rank0_trace.json) |
|1-Node (8H100)|FSDP2| FSDP=4 TP=2 | 35.33GiB(37.20%) | 5,898 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-26-15-35-47_rank0_trace.json) |
|8-Node (64H100)|SimpleFSDP| FSDP=8 TP=8 | | | |
|8-Node (64H100)|FSDP2| FSDP=8 TP=8 | | | |

Example SimpleFSDP 1D overlapping trace:

<img width="1127" height="127" alt="Screenshot 2025-10-16 at 10 49
55 AM"
src="https://github.com/user-attachments/assets/2d9e3ff8-8e9b-40a7-a666-3c0a0975186e"
/>

Example SimpleFSDP 2D overlapping trace:
<img width="1162" height="166" alt="Screenshot 2025-10-26 at 6 00 51 PM"
src="https://github.com/user-attachments/assets/bc5cc031-5b6c-4e4d-a9da-70c43114f49a"
/>


- Bitwise Loss:

FSDP-only:
<img width="1266" height="837" alt="Screenshot 2025-10-17 at 10 41
56 AM"
src="https://github.com/user-attachments/assets/30f83d95-1eca-4f10-9e7e-47c45278cd8d"
/>

FSDP+TP:
<img width="1259" height="808" alt="Screenshot 2025-10-26 at 9 03 58 PM"
src="https://github.com/user-attachments/assets/b75b452b-adb9-4078-9412-ee9e584ffe15"
/>
The current `convert_to_hf.py` does not support `export_dtype`, which
makes it `float32` by default. This PR adds support for export dtypes of
`["float16", "bfloat16", "float32"]`.
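
A minimal sketch of the export-dtype handling (names illustrative, not the actual `convert_to_hf.py` code):

```python
import torch

EXPORT_DTYPES = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


def cast_state_dict(state_dict: dict[str, torch.Tensor], export_dtype: str) -> dict[str, torch.Tensor]:
    # Map the config string to a torch dtype and cast tensors on export.
    dtype = EXPORT_DTYPES[export_dtype]  # raises KeyError on unsupported dtypes
    return {key: tensor.to(dtype) for key, tensor in state_dict.items()}
```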
This PR integrates the changes in pytorch#1970 into the compiler toolkit
(applying `joint_ac_pass` on the joint graph to tag nodes based on the
`reshard_after_forward` flag).

Also did some refactoring of how graph passes are applied in the compiler
toolkit experiments. We now have two kinds of passes:

1. joint_custom_passes: these are passes applied on the captured joint
graph before the partitioner. By default we apply
`validate_flex_attn_annotation_pass` and `fsdp_reshard_after_fwd_pass`.

2. compiler_passes: these are passes applied on the partitioned fwd and
bwd graphs as backend optimizations. By default there are none. We can
enable `autobucketing_reordering_pass` and `regional_inductor_pass` via
configs.
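
A hedged sketch of how the two categories compose; the partition callable below is a placeholder, not the actual compiler-toolkit entry point.

```python
from typing import Callable, Sequence, Tuple

import torch.fx as fx

GraphPass = Callable[[fx.GraphModule], fx.GraphModule]


def compile_joint(
    joint_gm: fx.GraphModule,
    joint_custom_passes: Sequence[GraphPass],
    compiler_passes: Sequence[GraphPass],
    partition_fn: Callable[[fx.GraphModule], Tuple[fx.GraphModule, fx.GraphModule]],
):
    # 1. Joint passes (e.g. flex-attention annotation validation,
    #    fsdp_reshard_after_fwd) run on the joint graph before partitioning.
    for p in joint_custom_passes:
        joint_gm = p(joint_gm)
    fwd_gm, bwd_gm = partition_fn(joint_gm)
    # 2. Compiler passes (e.g. autobucketing_reordering, regional_inductor)
    #    run on the partitioned graphs; none are applied by default.
    for p in compiler_passes:
        fwd_gm, bwd_gm = p(fwd_gm), p(bwd_gm)
    return fwd_gm, bwd_gm
```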
…ytorch#2056)

This PR integrates the manual bucketing pass (transformer block
bucketing) added in the SimpleFSDP experiment (pytorch#1881) into the
compiler toolkit.

So now the compiler toolkit can also run the manual bucketing pass by
specifying the config
```
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing
``` 

Also updated the README and integration test to include the newly ported
pass.
…h only (pytorch#2018)

Addressing the following issues in this PR:

- Running the Torchtitan ROCm workflow on a cron schedule & only on push to
the main branch. The CUDA workflow will run as-is.
- Refactor the Torchtitan test run to address an older PR comment:
pytorch#1786 (comment)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* pytorch#2049
* __->__ pytorch#2029


## Summary
This PR adds `scripts/loss_compare.py` for comparing training losses
between different git commits and/or training configurations.

## Key Features

- Commit Comparison: Compare losses between two different git commits with deterministic training
- Configuration Comparison: Compare different training configurations on the same commit
- Reproducibility: Automatically enables deterministic mode and seed checkpointing for reproducible comparisons
- Real-time Output: Streams training output to both console and log files during execution
- Statistical Analysis: Generates step-by-step loss comparisons and summary statistics
- CI Testing: Includes an `--assert-equal` flag for automated testing to verify identical losses

## Usage Examples

#### Compare two commits
```
python3 ./scripts/loss_compare.py main my_branch
```
#### Compare two commits with custom configuration
```
python3 ./scripts/loss_compare.py main my_branch \
    --baseline-config="./custom.toml" \
    --baseline-options="--parallelism.tensor_parallel_degree=2"
```

#### Compare different parallelization strategies on the same commit
```
python3 ./scripts/loss_compare.py . . \
    --baseline-config="./llama3_8b.toml" \
    --baseline-options="--parallelism.tensor_parallel_degree=2" \
    --test-options="--parallelism.tensor_parallel_degree=1"
```

#### Assert equality for CI testing
```
python3 ./scripts/loss_compare.py main my_branch --assert-equal
```


## Real Use Cases
Compare full dtensor simple fsdp with fsdp2:
```
python3 scripts/loss_compare.py . . \
    --baseline-options='--activation_checkpoint.mode="none"' \
    --test-train-file='torchtitan.experiments.full_dtensor.train' \
    --test-options='--model.name full_dtensor.llama3 --activation_checkpoint.mode="none"' \
    --assert-equal --no-seed-checkpoint


[LOSS_COMPARE]
[LOSS_COMPARE] Asserting losses are equal...
[LOSS_COMPARE] Baseline log: /tmp/baseline_training.log
[LOSS_COMPARE] Test log: /tmp/test_training.log
[LOSS_COMPARE] Extracted 100 steps from baseline log
[LOSS_COMPARE] Extracted 100 steps from test log
test_losses_equal (__main__.assert_losses_equal.<locals>.LossEqualityTest.test_losses_equal) ... ok
```
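
For reference, a minimal sketch of the extraction/diff core; the log-line pattern below is an assumption about torchtitan's metric lines, not a guaranteed format.

```python
import re

STEP_RE = re.compile(r"step:\s*(\d+).*?loss:\s*([0-9.eE+-]+)")


def extract_losses(log_path: str) -> dict[int, float]:
    # Pull (step, loss) pairs out of a training log.
    losses: dict[int, float] = {}
    with open(log_path) as f:
        for line in f:
            if m := STEP_RE.search(line):
                losses[int(m.group(1))] = float(m.group(2))
    return losses


def max_abs_diff(baseline_log: str, test_log: str) -> float:
    base, test = extract_losses(baseline_log), extract_losses(test_log)
    common = sorted(set(base) & set(test))
    return max(abs(base[s] - test[s]) for s in common)
```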
All tests in experiments are broken due to the `gpu_arch_type` field
added in pytorch#2018.
…#2064)

Adding the CudaGraph pass (pytorch#2050) would require some custom logic
in the Trainer's close() method.

So we create a Trainer subclass in the compiler toolkit.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* pytorch#2063
* __->__ pytorch#2062

This will prevent errors when later doing git checkout.
## Features
- [x] Support SimpleFSDP and TP
- [x] Support static input indices to reduce copy
- [x] Support memory reuse to reduce memory consumption
- [x] Clean up cudagraphs when training finishes to avoid an NCCL hang from
destroy_process_group

Command:
```
NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4  --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes cudagraph
```


Note: we use `NCCL_GRAPH_REGISTER=0` due to a known issue where NCCL +
cudagraphs + expandable segments result in an IMA:
pytorch/pytorch#158029
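
For context, here is the capture/replay pattern the pass relies on, written against the public `torch.cuda.CUDAGraph` API; the actual pass does this at the fx-graph level and also handles memory pools and cleanup, which this sketch omits.

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()
static_x = torch.randn(8, 1024, device="cuda")

# Warm up on a side stream before capture (the memory notes below explain why
# the choice of stream/pool for warmup matters for peak memory).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_x)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = model(static_x)


def step(x: torch.Tensor) -> torch.Tensor:
    # Only non-static inputs need to be copied into the captured buffers;
    # parameters and other static tensors are reused in place.
    static_x.copy_(x)
    g.replay()
    return static_out
```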


[trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces%2Ftree%2Fshared_trace%2Fboyuan_e1ef464b-ee61-4c61-82e5-f7a485e561bf_rank0_trace.json)

## Result

**Numerics:**
Achieved bitwise equivalence w/ and w/o cudagraph pass on llama3.1-8B
AND llama3.1-70B.

**Performance:**
<img width="560" height="90" alt="image"
src="https://github.com/user-attachments/assets/9d54c461-0eb1-4f7e-9652-3d52043ad74f"
/>

Raw log:
[llama3-8b](https://www.internalfb.com/phabricator/paste/view/P2045444190),
[llama3-70b](https://www.internalfb.com/phabricator/paste/view/P2045567416)

**Memory:**
On llama3.1-70b, cudagraph takes 6% more memory consumption (143 GiB vs
153 GiB).

A few tricks to reduce memory consumption (using llama3.1-70B w/ cudagraph
as an example):
- Start: 161 GiB
- \+ use the same stream for warmup and graph capture of both fwd and bwd: 160 GiB
- \+ warmup in the cudagraph memory pool instead of the eager memory pool: 153 GiB


**Static input copy:**
On llama3.1-70B, for forward we copy 1 tensor of 128 bytes; for backward
we copy 1 tensor of 0.98 GB. This shows static input indices are handled
correctly.


## Followup PR
In a followup PR, I will enable fx graph partition for DeepSeek V3:
pytorch/pytorch#165945.
This PR fixes access to args; it's an attribute, not a variable in scope.
The method itself, though, would not be used, because
`should_check_address` seems to always be `False` and there doesn't seem
to be a command-line argument for it.

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
# Context
Reference PR: huggingface#1

This PR enables:
- Llama-like HF models to work with 4D parallelism: FSDP, CP, TP, PP
(and the combinations between them). The following models were tested:
  - `meta-llama/Llama-3.2-1B`
  - `microsoft/phi-2`
  - `Qwen/Qwen2.5-7B`
  - `mistralai/Mistral-7B-v0.1`
  - `ByteDance-Seed/Seed-Coder-8B-Instruct`
  - `Qwen/Qwen3-4B-Instruct-2507`
  - `arcee-ai/AFM-4.5B`
  - `ibm-granite/granite-3b-code-base-2k`
  - `baidu/ERNIE-4.5-0.3B-Base-PT`
  - `kyutai/helium-1-preview-2b`
  - `allenai/OLMo-7B-hf`
  - `mistralai/Ministral-8B-Instruct-2410`
- Patching HF models' weight initialisation. Without this, the `loss`
and `grad_norm` start very high

# Usage

- Requirements `transformers==4.57.1`
- Config:
`torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3.toml`
```diff
...
[model]
- name = "llama3"
+ name = "transformers_backend"
flavor = "debugmodel"
hf_assets_path = "./tests/assets/tokenizer"

+[hf_transformers]
+model = "Qwen/Qwen3-4B-Instruct-2507"
...
```
- Train: `LOG_RANK=7 CONFIG_FILE=<YOUR_PATH>/torchtitan/experiments/transformers_backend/configs/qwen3.toml ./run_train.sh --job.custom_config_module=torchtitan.experiments.transformers_backend.job_config --compile.enable`

<img width="1334" height="453" alt="image"
src="https://github.com/user-attachments/assets/da459448-027b-4af9-8176-6a3e433a272c"
/>

# Testing methodology

<img width="2672" height="2018" alt="image"
src="https://github.com/user-attachments/assets/66d8689d-7ede-47e3-b389-d4fc1bdd70f7"
/>

- Following the
[converging.md](https://github.com/pytorch/torchtitan/blob/main/docs/converging.md)
guidelines, I am comparing the baseline `FSDP=2` vs `FSDP=2 & <other
//-ism>`
- More precisely, `test_hf_integration.py` is going to do:

```bash
    results/
        |_ meta-llama
            |_ Llama-3.2-1B
                |_ debugmodel/
                    |_ seed_checkpoint/
                        |_ config.toml
                        |_ seed.slurm
                        |_ step-0/
                           |_ ....
                    |_ fsdp2_tp1_cp1_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                    |_ fsdp2_tp2_cp1_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp1_pp2/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp2_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp2_pp2/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                |_ full/
                ...
```
- Here is the grid search to test the HF modelling
```shell
#!/usr/bin/bash
model_names=(
     "meta-llama/Llama-3.2-1B"
     "microsoft/phi-2" 
     "Qwen/Qwen2.5-7B"
     "mistralai/Mistral-7B-v0.1"
     "ByteDance-Seed/Seed-Coder-8B-Instruct"
     "Qwen/Qwen3-4B-Instruct-2507" 
     "arcee-ai/AFM-4.5B" 
     "ibm-granite/granite-3b-code-base-2k" 
     "baidu/ERNIE-4.5-0.3B-Base-PT" 
     "kyutai/helium-1-preview-2b" 
     "allenai/OLMo-7B-hf"
     "mistralai/Ministral-8B-Instruct-2410" 
)

for model_name in "${model_names[@]}"; do
    rm -rf slurm_results/${model_name}

    python test_hf_integration.py create_configs --model_name "$model_name" --out_dir slurm_results --flavor debugmodel
    python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel/seed_checkpoint --qos high
    while [ ! -f slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt ] || [ "$(cat slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt)" != "completed" ]; do
        echo "Waiting for seed checkpoint from ${model_name} to complete ..."
        sleep 1
    done
    python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel --qos high
    echo "================"
done
```

# Further tasks

- MoE (handled in PR huggingface#3)
    - Missing `build_optimizers_with_moe_load_balancing` support for MoE
    - Missing TP/PP/EP support for MoE
- When using the HF modeling, in the test `FSDP=2 vs FSDP=2 + PP=2` the `loss`
and `grad_norm` are not bitwise matching (but converging), while they are
with the Torchtitan modeling. (Issue is tracked in huggingface#4)
- Add convergence tests to CI using a tiny model + gloo backend (once
PP is bitwise matching)
- The HF modeling has lower MFU than the Torchtitan modeling
- NOTE: set `import torch._dynamo.config;
torch._dynamo.config.cache_size_limit = 128` to avoid graph recompilation
when using `torch.compile` and activation checkpointing
**Summary**
This PR adds variable-length attention (varlen) support to the Llama 3
8B model in torchtitan. We replace `use_flex_attn` with `attn_type`
(one of "sdpa", "varlen", "flex"). If `attn_type = "varlen"`, the
attention module calls a compiled `varlen_attn` defined
[here](https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/varlen.py).
**Testing**
Ran loss and performance tests against flex attention. Loss is on par.

<img width="947" height="505" alt="Screenshot 2025-11-19 at 3 24 26 PM"
src="https://github.com/user-attachments/assets/d85dfc09-4f5e-4f82-abc9-49b870b34990"
/>

Varlen is slightly slower than Flex due to the cuda kernel speeds
(varlen calls into `flash_attention_forward`/`flash_attention_backward`
today).


| | Varlen | Flex |
| :---: | :------ | :---: |
| Forward  | 774us 357ns | 722us 317ns  |
| Backward   | 1ms 955us 916ns  | 1ms 558us 747ns    |
jquesnelle and others added 25 commits February 18, 2026 02:17
Rebased LLEP (Least-Loaded Expert Parallelism) from the old
phuc/kimi_k2_with_autotune_llep_optimized_llep branch onto the
latest upstream-2026-10-02 to resolve 86 merge conflicts caused
by 1000+ upstream commits since the original branch point.

New LLEP core files:
- torchtitan/distributed/llep.py: dispatch/combine with LPT routing
- torchtitan/distributed/llep_autotune.py: hyperparameter autotuning
- torchtitan/distributed/llep_kernels.py: Triton kernels

Integration points (surgical changes to upstream files):
- moe.py: LLEPConfig, fast_init_*, llep_state in GroupedExperts.forward
- expert_parallel.py: ExpertParallelLLEP class
- job_config.py: LLEP config dataclass
- llama4/parallelize.py: LLEP EP selection logic
- deepseek_v3/__init__.py: LLEP model variants
- train.py: LLEP autotune at startup
The [llep] TOML section (enabled, max_tokens_factor, etc.) was not being
applied to moe_args in DeepSeekV3ModelArgs.update_from_config(), so LLEP
was never actually activated. This caused OOM on the imbalanced GPU since
standard EP doesn't balance memory across ranks.
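
A hedged sketch of the fix: propagate the `[llep]` fields onto moe_args inside `update_from_config` so the LLEP path actually activates. Field and class names below are reconstructions from this thread, not the exact code.

```python
from dataclasses import dataclass


@dataclass
class LLEPConfig:
    enabled: bool = False
    max_tokens_factor: float = 1.0


def apply_llep_section(moe_args, job_config) -> None:
    llep = getattr(job_config, "llep", None)
    if llep is None:
        return
    # Without this propagation the [llep] TOML section is parsed but never
    # reaches the model, so standard EP runs (and can OOM on the most loaded rank).
    moe_args.llep = LLEPConfig(enabled=llep.enabled, max_tokens_factor=llep.max_tokens_factor)
```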
- Add debugmodel_ep8_llep_3b flavor (1.75B params, 64 experts, EP=8)
  and debug_model_ep8_llep_3b.toml for benchmarking on new upstream
  (the 9.5B config OOMs due to upstream memory regression)
- Copy 3 missing multinode Kimi K2 LLEP+Muon TOML configs from old branch
Keep only essential files:
- docs/llep.md (main documentation)
- debug_model_llep.toml (2-GPU smoke test)
- debug_model_ep8_llep_3b.toml (8-GPU benchmark)
- test_llep_toml_override.toml (unit test config)
- debugmodel_llep, debugmodel_ep8_llep_3b model flavors

Removed: optimization report, pr008 cleanup doc, loss comparison
scripts, baseline/stresstest/mini_kimi/kimi_k2/multinode TOMLs,
and their corresponding model flavor definitions.
- Get EP process group from experts' DTensor device_mesh["ep"] instead
  of non-existent MoE._ep_group attribute
- Use moe_module.use_llep and moe_module._llep_config instead of
  non-existent _llep_enabled/_llep_max_tokens_factor attributes
When both expert_parallel_comm_backend="deepep" and llep.enabled=true are
set, use DeepEP for balanced steps and fall back to LLEP for imbalanced
steps based on adaptive_threshold.

New classes:
- DeepEPLLEPExpertParallel: per-step dispatch/combine hook that checks
  imbalance ratio and routes to DeepEP or LLEP path accordingly
- DeepEPLLEPMoE: MoE module that passes 5-tuple routing info to experts
  and handles async combine overlap with shared_experts

Wiring: args.py derives moe_impl="deepep_llep" when both flags set,
build_moe() creates DeepEPLLEPMoE, apply_moe_ep_tp() installs the
adaptive expert parallel hooks.
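
Roughly, the per-step decision looks like the sketch below; the imbalance metric and threshold semantics are assumptions based on the `adaptive_threshold` knob mentioned here.

```python
import torch


def choose_backend(tokens_per_rank: torch.Tensor, adaptive_threshold: float) -> str:
    # Imbalance ratio: most loaded EP rank vs. the mean load this step.
    mean_load = max(tokens_per_rank.float().mean().item(), 1.0)
    imbalance = tokens_per_rank.max().item() / mean_load
    return "llep" if imbalance > adaptive_threshold else "deepep"
```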
Replace the 3B debugmodel_ep8_llep_3b (too small for LLEP to help)
with the 9.5B debugmodel_ep8_llep (dim=2048, moe_inter_dim=1536,
16 layers, lbs=8) that stresses GPU memory and shows LLEP's benefit.

Benchmark results on 8xB200 (steps 5-20):
- Speed: +10.9% mean TPS (26,370 vs 23,780)
- Memory: 7 GiB spread vs 42 GiB, max 82% vs 97%
- Without LLEP at lbs=10: OOM. With LLEP: runs fine.
- Loss correctness: <0.001 diff by step 130
DeepEPLLEPMoE.forward() was identical to DeepEPMoE.forward(). The
behavioral difference comes entirely from which ExpertParallel hooks
get installed (DeepEPExpertParallel vs DeepEPLLEPExpertParallel),
not from the MoE class itself.
4-config comparison on 9.5B model, 8xB200:
- LLEP (standard): 26,250 TPS, 6 GiB spread
- DeepEP+LLEP (adaptive): 25,820 TPS, 7 GiB spread
- Standard EP: 21,940 TPS, 41 GiB spread
- DeepEP only: 19,640 TPS, 52 GiB spread
The fused_silu_gate Triton kernel (fwd/bwd), FusedSiLUGate autograd
class, and silu_gate_reference were not wired into the training path
(llep.py never imported them). Delete the dead code from llep_kernels.py,
remove test_llep_correctness.py (all tests were for the deleted kernel),
and update docs/llep.md to reflect current state.
Move llep.py, llep_kernels.py, llep_autotune.py into
torchtitan/distributed/llep/ and docs/llep.md as its README.md.
All external imports (from torchtitan.distributed.llep import ...)
remain unchanged since llep/__init__.py preserves the public API.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…rmal_

Delete fast_init_trunc_normal_ and fast_init_normal_ from moe.py since
upstream utils.py already provides equivalent trunc_normal_ and normal_.
Wrap init_weights() call in test with torch.no_grad() to match how the
training pipeline invokes it.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
feat: Least-Loaded Expert Parallelism with new upstream
@dmahan93 dmahan93 changed the base branch from dev-updated-again to upstream-2026-10-02 March 12, 2026 18:28
@dmahan93 dmahan93 marked this pull request as ready for review March 12, 2026 18:32
@dmahan93 dmahan93 changed the base branch from upstream-2026-10-02 to dev-updated-again March 13, 2026 17:11
@dmahan93 dmahan93 merged commit efa8476 into dev-updated-again Mar 13, 2026