[test] check git diff#47

Draft
xrsrke wants to merge 131 commits into phuc/kimi1t_training from upstream-2026-24-01

Conversation

@xrsrke xrsrke commented Feb 3, 2026

No description provided.

mori360 and others added 30 commits November 10, 2025 10:14
We should be able to control which passes run in the compiler. This PR
uses the config `compile.passes` to specify a list of graph passes to
apply to the captured GraphModule.

By default, no pass is applied. Users can specify what passes to apply.

Currently there are `autobucketing_reordering_pass` and
`regional_inductor_pass`.

```
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor
```

Also updated CI to include this new config.
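The config-driven pass selection can be sketched as a small registry keyed by pass name. Only the pass names come from this PR; the registry and the `apply_passes` helper below are illustrative, not torchtitan's actual API:

```python
# Illustrative sketch of config-driven graph-pass selection. Lists stand in
# for the captured GraphModule; only the pass names are from the PR.
from typing import Callable

def autobucketing_reordering_pass(gm: list) -> list:
    # Placeholder: a real pass would rewrite the captured graph module.
    return gm + ["autobucketing_reordering"]

def regional_inductor_pass(gm: list) -> list:
    return gm + ["regional_inductor"]

PASS_REGISTRY: dict[str, Callable[[list], list]] = {
    "autobucketing_reordering": autobucketing_reordering_pass,
    "regional_inductor": regional_inductor_pass,
}

def apply_passes(gm: list, pass_names: list[str]) -> list:
    # By default (empty list) no pass is applied, matching the PR's behavior.
    for name in pass_names:
        if name not in PASS_REGISTRY:
            raise ValueError(f"unknown compile pass: {name}")
        gm = PASS_REGISTRY[name](gm)
    return gm

# "--compile.passes autobucketing_reordering,regional_inductor" parses to:
result = apply_passes(["captured_gm"], ["autobucketing_reordering", "regional_inductor"])
```

An empty pass list leaves the captured graph untouched, which is the default described above.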
After some offline discussion, we've concluded that life would be easier
if we put SimpleFSDP's checkpoint logic for `reshard_after_forward`
into the compiler. The AC annotation part is borrowed from AP:
[LINK](https://github.com/meta-pytorch/autoparallel/blob/main/autoparallel/activation_checkpointing.py#L69).
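The idea behind the AC annotation can be sketched with toy graph nodes: when `reshard_after_forward` is set, the all-gathered parameters are freed after forward, so the all-gather node must be marked for recompute in backward. Node and annotation names here are illustrative, not the real AP/torchtitan ones:

```python
# Toy sketch: annotate all-gather nodes so the partitioner knows whether to
# re-run them in backward (reshard_after_forward=True) or keep their output.
# The Node class and annotation strings are illustrative.
from dataclasses import dataclass, field

@dataclass
class Node:
    op: str
    meta: dict = field(default_factory=dict)

def annotate_fsdp_ac(nodes: list[Node], reshard_after_forward: bool) -> None:
    for n in nodes:
        if n.op == "all_gather":
            # MUST_RECOMPUTE -> the all-gather is replayed in backward;
            # MUST_SAVE -> the gathered tensor stays alive across forward.
            n.meta["recompute"] = (
                "MUST_RECOMPUTE" if reshard_after_forward else "MUST_SAVE"
            )

graph = [Node("all_gather"), Node("matmul"), Node("all_gather")]
annotate_fsdp_ac(graph, reshard_after_forward=True)
```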

**Trace and Loss Check** (all with torch.compile enabled)

reshard_after_fwd = False
1. SAC + llama3
([trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-05-06_rank0_trace.json))
<img width="768" height="115" alt="Screenshot 2025-10-30 at 4 28 59 PM"
src="https://github.com/user-attachments/assets/e4e22335-2e3f-46c8-8def-a60d592fee0a"
/>

<img width="689" height="512" alt="Screenshot 2025-11-05 at 9 02 30 PM"
src="https://github.com/user-attachments/assets/40a71316-a457-4e72-9002-cc8beea8f32c"
/>


2. Full AC + llama3 [(trace)]()

<img width="729" height="105" alt="Screenshot 2025-10-30 at 4 30 53 PM"
src="https://github.com/user-attachments/assets/e8d63460-579b-4f0a-8504-851480e5b548"
/>

<img width="789" height="763" alt="Screenshot 2025-11-05 at 9 11 34 PM"
src="https://github.com/user-attachments/assets/1a13d09e-04c4-4db9-99fe-cf10d24bf7f5"
/>


3. No AC + llama3
[[trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-03-50_rank0_trace.json)]

<img width="748" height="115" alt="Screenshot 2025-10-30 at 4 32 05 PM"
src="https://github.com/user-attachments/assets/20104d24-9d45-4eba-b694-815e133b88d0"
/>

<img width="800" height="764" alt="Screenshot 2025-11-05 at 9 07 46 PM"
src="https://github.com/user-attachments/assets/55b104ce-8ec1-4ed6-95e7-300e96ad55af"
/>


reshard_after_fwd = True

1. SAC + llama3
([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-34-24_rank0_trace.json))

<img width="795" height="108" alt="Screenshot 2025-10-31 at 11 34 47 AM"
src="https://github.com/user-attachments/assets/a3988f72-7e87-4e52-90f9-8bee840cd6f4"
/>


2. Full AC + llama3
([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-36-27_rank0_trace.json))

<img width="593" height="110" alt="Screenshot 2025-10-31 at 11 38 02 AM"
src="https://github.com/user-attachments/assets/5ee61b2b-9600-4af8-9a24-61b3564f93ca"
/>


3. No AC + llama3
([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-02-44_rank0_trace.json))


<img width="701" height="109" alt="Screenshot 2025-10-31 at 11 43 04 AM"
src="https://github.com/user-attachments/assets/576b28f6-dae4-4ff7-b005-57b0cf9ad7cc"
/>
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* pytorch#2002
* __->__ pytorch#2001

Add typing, credit to Claude.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* __->__ pytorch#2013

When full_dtensor is True, the compute_placement will be preserved. This
means that `to_local()` won't be called for the FSDP-only case. The nD
parallelism case (FSDP + TP) will error out as we have not implemented
it yet.

This argument doesn't affect the current simple_fsdp. We have verified
the `full_dtensor=True` case with the full dtensor skeleton PR, which
will be published once it is ready.

**This is a reland PR of
pytorch#2002. The previous one was
broken during rebase.**
Summary:
- we need to pass the global rank information to pytorch so that the pg
name can include the pg information
- this is necessary to differentiate the default pg's on different
replicas
- these need to be different because flight recorder matches collectives
based on pg name as well
- add ft training to experiments folder, we'll move remaining pieces of
ft to this gradually but make new features only available through this
folder

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1986).
* pytorch#1988
* pytorch#1987
* __->__ pytorch#1986

Co-authored-by: Tushar Jain <tushar00jain@users.noreply.github.com>
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* pytorch#2012
* __->__ pytorch#2011

It is not correct as JobConfig has changed.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* __->__ pytorch#2012
* pytorch#2011

Summary:
The current configuration validation requires torchx and GPUs. It can
waste time, resources, and energy. Polar bears are crying. Let's fix
this by providing a dry-run mode. This PR doesn't verify everything; in
theory, we should be able to verify parallelism settings as well. This
PR is just a start, but it at least lets us catch typos quickly.
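A dry-run validator of this kind can be sketched as a pure config check that never touches torchx or GPUs. The option names below are illustrative stand-ins for torchtitan's real JobConfig schema:

```python
# Minimal dry-run sketch: validate option names against a known schema and
# report typos without launching any job. Keys are illustrative, not the
# full torchtitan JobConfig.
KNOWN_OPTIONS = {
    "training.steps",
    "activation_checkpoint.mode",
    "parallelism.tensor_parallel_degree",
}

def dry_run_validate(cli_options: dict[str, str]) -> list[str]:
    # Return the unrecognized option names so typos surface immediately.
    return [k for k in cli_options if k not in KNOWN_OPTIONS]

errors = dry_run_validate(
    {"training.steps": "10", "activation_checkpoint.mod": "none"}  # typo: "mod"
)
```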
As titled. `_clear_traced_params_buffers` is no longer being used as we
have switched the dynamo graph capture API.
…nly (pytorch#2016)

Addressing the following issues in this PR:

- Running Torchtitan ROCm workflow on cron schedule & only when push to
Main branch. CUDA workflow will run as is.
- Refactor Torchtitan test run to address older PR comment
pytorch#1786 (comment)
…& push to Main branch only" (pytorch#2017)

Reverts PR: pytorch#2016
Addressing the following issues in this PR:
- Running Torchtitan ROCm workflow on cron schedule & only when push to
Main branch. CUDA workflow will run as is.
- Refactor Torchtitan test run to address older PR comment
pytorch#1786 (comment)

Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
…2015)

This PR adds the utils to automatically check the training numerics
(losses, grad norms) of two runs to verify if they have bitwise
equivalence.

The added script triggers two runs with user-defined configs. Then it
loads metrics saved during training and compares the numerics to verify
bitwise equivalence. Currently we check losses and grad norms during
training steps.

For example, we want to compare the numerics between compiler toolkit
with aot_eager backend and eager on llama3-8B.
```
python torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py --ngpu 4 --config-file torchtitan/models/llama3/train_configs/llama3_8b.toml --dp-shard-degree 2 --tp-degree 2
```
It'll run the `simple_fsdp` experiment without `torch.compile` as the eager
baseline, and the `compiler_toolkit` experiment as the compiled run. Then it
compares the training numerics of the two runs to verify bitwise
equivalence.

When it is bitwise equivalent, we'll see the following output
```
Starting training: simple_fsdp.llama3
✓ Training completed: simple_fsdp.llama3

Starting training: compiler_toolkit.llama3
✓ Training completed: compiler_toolkit.llama3
  ✓ PASS: All 11 steps match exactly (bitwise equivalent)
  ✓ PASS: All 11 steps match exactly (bitwise equivalent)
✓ SUCCESS: All metrics are bitwise equivalent
```

Also added unit-tests in `compiler_toolkit/tests/test_numerics.py` so
that we can guard working parallelism combinations that already have
bitwise equivalence in CI.
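The core of such a check is exact floating-point comparison, with no tolerance. A minimal sketch (the metric layout is illustrative):

```python
# Bitwise-equivalence sketch: two runs match only if every recorded step's
# metric is *exactly* equal. No epsilon: compiled and eager runs are expected
# to produce identical bits.
def check_bitwise(baseline: list[float], test: list[float]) -> tuple[bool, list[int]]:
    mismatched = [i for i, (a, b) in enumerate(zip(baseline, test)) if a != b]
    return len(baseline) == len(test) and not mismatched, mismatched

ok, bad_steps = check_bitwise([8.1309, 7.8268], [8.1309, 7.8268])
```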
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* pytorch#2029
* pytorch#2030
* pytorch#2028
* __->__ pytorch#2027
* pytorch#2026

Dry-run mode works, but it doesn't exit gracefully in all cases. This PR
fixes that.

```
DRY_RUN=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh   --training.steps=10 --activation_checkpoint.mode="none"
--debug.deterministic --debug.seed=42
```
…h#2030)

Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* pytorch#2029
* __->__ pytorch#2030

The current CompileModule will result in an "inner" prefix for
everything. This PR fixes it by overloading the methods.

Also merged pytorch#2028 into this PR; something went wrong with ghstack.
This PR adds support for aten-level manual bucketing in
SimpleFSDP+`aot_eager` backend. Dependent on PyTorch
[PR](pytorch/pytorch#165487)

TODO List:
- [ ] We should have a better way of handling region info than a
list of string FQNs in the current `manual_bucketed_modules`. It would be very
easy to miss some model modules. (cc. @xmfan @SherlockNoMad )
- [ ] Currently, the reordering happens under the hood and overlaps with the
last/next compute. We should allow users to specify which modules they
want to reorder.
- [ ] Loss difference on multi-node training
- [ ] DSV3 manual bucketing

I'll address the TODO items in follow up PRs. Let's start with this
simple FSDP+TP+llama3 PR.

1. Performance (FSDP2 under eager mode, SimpleFSDP uses `aot_eager`
backend)

**Llama 3-8B**

* Performance (All Batch_size = 1). (The slower TPS on Single Node is
sort of as expected, since FSDP2 handles copy-in/out in two different
streams, whereas SimpleFSDP handles copy-in/out in the same stream)

|Node| Method | Parallelism | Memory | TPS | Trace|
|---------|---------|-----------|----------|------|------|
|1-Node (8H100)|SimpleFSDP | FSDP=8| 40.96GiB(43.12%) | 7,227| [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-16-10-48-48_rank0_trace.json)|
|1-Node (8H100)|FSDP2-eager| FSDP=8| 47.82GiB(50.35%) | 7,380 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-16-10-54-14_rank0_trace.json)|
|8-Node (64H100)|SimpleFSDP| FSDP=64  | 29.37GiB | 4,984| |
|8-Node (64H100)|FSDP2| FSDP=64 | 31.41GiB  |5,097 | |
|1-Node (8H100)|SimpleFSDP| FSDP=4 TP=2 | 28.28GiB(29.77%) | 5,881 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-26-18-00-18_rank0_trace.json)|
|1-Node (8H100)|FSDP2| FSDP=4 TP=2 | 35.33GiB(37.20%) | 5,898 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-26-15-35-47_rank0_trace.json)|
|8-Node (64H100)|SimpleFSDP| FSDP=8 TP=8  |   |||
|8-Node (64H100)|FSDP2| FSDP=8 TP=8 |   |||

Example SimpleFSDP 1D overlapping trace:

<img width="1127" height="127" alt="Screenshot 2025-10-16 at 10 49
55 AM"
src="https://github.com/user-attachments/assets/2d9e3ff8-8e9b-40a7-a666-3c0a0975186e"
/>

Example SimpleFSDP 2D overlapping trace:
<img width="1162" height="166" alt="Screenshot 2025-10-26 at 6 00 51 PM"
src="https://github.com/user-attachments/assets/bc5cc031-5b6c-4e4d-a9da-70c43114f49a"
/>


- Bitwise Loss:

FSDP-only:
<img width="1266" height="837" alt="Screenshot 2025-10-17 at 10 41
56 AM"
src="https://github.com/user-attachments/assets/30f83d95-1eca-4f10-9e7e-47c45278cd8d"
/>

FSDP+TP:
<img width="1259" height="808" alt="Screenshot 2025-10-26 at 9 03 58 PM"
src="https://github.com/user-attachments/assets/b75b452b-adb9-4078-9412-ee9e584ffe15"
/>
The current `convert_to_hf.py` does not support `export_dtype`, which
makes it `float32` by default. This PR adds support for export dtypes of
`["float16", "bfloat16", "float32"]`.
This PR integrates the changes in pytorch#1970 into the compiler toolkit
(applying `joint_ac_pass` on the joint graph to tag nodes based on the
`reshard_after_forward` flag).

Also did some refactoring of how graph passes are applied in compiler
toolkit experiments. We will have two kinds of passes:

1. joint_custom_passes: these are passes to be applied on the captured
joint graph before the partitioner. By default we apply
`validate_flex_attn_annotation_pass` and `fsdp_reshard_after_fwd_pass`.

2. compiler_passes: these are passes to be applied on the partitioned fwd
and bwd graphs as backend optimizations. By default there are none. We
can enable `autobucketing_reordering_pass` and
`regional_inductor_pass` using configs.
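The two-stage structure can be sketched as follows; `partition` and the pass callables are stand-ins, not the real partitioner or pass implementations:

```python
# Sketch of the two-stage pass pipeline: joint_custom_passes run on the
# captured joint graph before partitioning; compiler_passes run on the
# partitioned forward/backward graphs. Lists stand in for FX graphs.
def partition(joint_graph: list) -> tuple[list, list]:
    # Stand-in for the real min-cut partitioner.
    return joint_graph + ["fwd"], joint_graph + ["bwd"]

def run_pipeline(joint_graph: list, joint_custom_passes, compiler_passes):
    for p in joint_custom_passes:          # stage 1: on the joint graph
        joint_graph = p(joint_graph)
    fwd, bwd = partition(joint_graph)
    for p in compiler_passes:              # stage 2: on fwd/bwd separately
        fwd, bwd = p(fwd), p(bwd)
    return fwd, bwd

tag = lambda g: g + ["tagged"]
fwd, bwd = run_pipeline([], joint_custom_passes=[tag], compiler_passes=[tag])
```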
…ytorch#2056)

This PR integrates the manual bucketing pass (transformer block
bucketing) added in SimpleFSDP experiment (pytorch#1881) to compiler toolkit

So now the compiler toolkit can also run the manual bucketing pass by
specifying the config:

```
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing
``` 

Also updated README and integration test to include the newly ported
pass
…h only (pytorch#2018)

Addressing the following issues in this PR:

- Running Torchtitan ROCm workflow on cron schedule & only when push to
Main branch. CUDA workflow will run as is.
- Refactor Torchtitan test run to address older PR comment
pytorch#1786 (comment)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* pytorch#2049
* __->__ pytorch#2029


## Summary
This PR adds `scripts/loss_compare.py` for comparing training losses
between different git commits and/or training configurations.

## Key Features

- Commit Comparison: Compare losses between two different git commits
with deterministic training
- Configuration Comparison: Compare different training configurations on
the same commit
- Reproducibility: Automatically enables deterministic mode and seed
checkpointing for reproducible
  comparisons
- Real-time Output: Streams training output to both console and log
files during execution
- Statistical Analysis: Generates step-by-step loss comparisons and
summary statistics
- CI Testing: Includes --assert-equal flag for automated testing to
verify identical losses
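The log-scraping step such a script needs can be sketched with a regex. The `step: N loss: X` line format below is an assumed approximation of the training logs, not an exact match:

```python
# Sketch of per-step loss extraction from a training log. The line format is
# an assumption for illustration; adapt the regex to the real log layout.
import re

STEP_RE = re.compile(r"step:\s*(\d+).*?loss:\s*([\d.]+)")

def extract_losses(log_text: str) -> dict[int, float]:
    # Map step number -> loss for every matching line.
    return {int(m.group(1)): float(m.group(2)) for m in STEP_RE.finditer(log_text)}

log = "step: 1 loss: 8.1309 grad_norm: 5.2\nstep: 2 loss: 7.8268 grad_norm: 4.9\n"
losses = extract_losses(log)
```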

## Usage Examples

#### Compare two commits
```
python3 ./scripts/loss_compare.py main my_branch
```
#### Compare two commits with custom configuration
```
python3 ./scripts/loss_compare.py main my_branch \
--baseline-config="./custom.toml" \
--baseline-options="--parallelism.tensor_parallel_degree=2"
```

#### Compare different parallelization strategies on same commit
```
python3 ./scripts/loss_compare.py . . \
--baseline-config="./llama3_8b.toml" \
--baseline-options="--parallelism.tensor_parallel_degree=2" \
--test-options="--parallelism.tensor_parallel_degree=1"
```

#### Assert equality for CI testing
```
python3 ./scripts/loss_compare.py main my_branch --assert-equal
```


## Real Use Cases
Compare full dtensor simple fsdp with fsdp2:
```
python3 scripts/loss_compare.py . . \
--baseline-options='--activation_checkpoint.mode="none"' \
--test-train-file='torchtitan.experiments.full_dtensor.train' \
--test-options='--model.name full_dtensor.llama3 --activation_checkpoint.mode="none"' \
--assert-equal --no-seed-checkpoint


[LOSS_COMPARE]
[LOSS_COMPARE] Asserting losses are equal...
[LOSS_COMPARE] Baseline log: /tmp/baseline_training.log
[LOSS_COMPARE] Test log: /tmp/test_training.log
[LOSS_COMPARE] Extracted 100 steps from baseline log
[LOSS_COMPARE] Extracted 100 steps from test log
test_losses_equal (__main__.assert_losses_equal.<locals>.LossEqualityTest.test_losses_equal) ... ok
```
All tests in experiments are broken due to the `gpu_arch_type` field
added in pytorch#2018.
…#2064)

Adding CudaGraph pass (pytorch#2050)
would require some custom logic in Trainer's close() method.

So we create a Trainer subclass in compiler toolkit
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* pytorch#2063
* __->__ pytorch#2062

This will prevent errors when doing `git checkout` later.
## Features
- [x] Support SimpleFSDP and TP
- [x] Support static input indices to reduce copy
- [x] Support memory reuse to reduce memory consumption
- [x] Cleanup cudagraph when training finishes to avoid nccl hang from
destroy_process_group

Command:
```
NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4  --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes cudagraph
```


Note: we use `NCCL_GRAPH_REGISTER=0` due to a known issue where NCCL +
cudagraphs + expandable segments results in an IMA (illegal memory access):
pytorch/pytorch#158029


[trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces%2Ftree%2Fshared_trace%2Fboyuan_e1ef464b-ee61-4c61-82e5-f7a485e561bf_rank0_trace.json)

## Result

**Numerics:**
Achieved bitwise equivalence w/ and w/o cudagraph pass on llama3.1-8B
AND llama3.1-70B.

**Performance:**
<img width="560" height="90" alt="image"
src="https://github.com/user-attachments/assets/9d54c461-0eb1-4f7e-9652-3d52043ad74f"
/>

Raw log:
[llama3-8b](https://www.internalfb.com/phabricator/paste/view/P2045444190),
[llama3-70b](https://www.internalfb.com/phabricator/paste/view/P2045567416)

**Memory:**
On llama3.1-70b, cudagraph uses about 6% more memory (143 GiB without vs
153 GiB with).

A few tricks to reduce memory consumption (use llama3.1-70b w/ cudagraph
as an example):
- Start: 161 GiB
- \+ use the same stream for warmup and graph capture of both fwd and
bwd: 160 GiB
- \+ warmup in cudagraph memory pool instead of eager memory pool: 153
GiB


**Static input copy:**
On llama3.1-70B, for forward, we copy 1 tensor of 128 bytes; for
backward, we copy 1 tensor of 0.98 GB. This shows static input indices
are handled correctly.
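The static-input mechanics behind these numbers can be sketched abstractly: a captured graph always reads from fixed buffers, so before each replay only the non-static inputs need to be copied into those buffers in place. Lists stand in for CUDA tensors here; indices and sizes are illustrative:

```python
# Sketch of static-input handling for graph replay: buffers keep their
# identity (the captured graph points at them), and only non-static inputs
# are copied in before replay. Parameters etc. are "static" and skipped.
def prepare_replay(static_buffers: list[list[float]],
                   new_inputs: list[list[float]],
                   static_indices: set[int]) -> None:
    for i, new in enumerate(new_inputs):
        if i in static_indices:
            continue  # already lives in the captured buffer, no copy needed
        static_buffers[i][:] = new  # in-place copy preserves buffer identity

bufs = [[0.0, 0.0], [1.0, 1.0]]       # buffer 1 is "static" (e.g. a weight)
ids_before = [id(b) for b in bufs]
prepare_replay(bufs, [[3.0, 4.0], [9.0, 9.0]], static_indices={1})
```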


## Followup PR
In the followup PR, I will enable fx graph partition for deepseek v3
pytorch/pytorch#165945.
This PR fixes access to `args`; it's an attribute, not a variable in the
enclosing scope.
The method itself would not be exercised anyway, because
`should_check_address` seems to always be `False` and there doesn't seem
to be a command-line argument for it.

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
# Context
Reference PR: huggingface#1

This PR enables:
- Llama-like HF models to work with 4D parallelism: FSDP, CP, TP, PP
(and the combinations between them). The following models were tested:
  - `meta-llama/Llama-3.2-1B`
  - `microsoft/phi-2`
  - `Qwen/Qwen2.5-7B`
  - `mistralai/Mistral-7B-v0.1`
  - `ByteDance-Seed/Seed-Coder-8B-Instruct`
  - `Qwen/Qwen3-4B-Instruct-2507`
  - `arcee-ai/AFM-4.5B`
  - `ibm-granite/granite-3b-code-base-2k`
  - `baidu/ERNIE-4.5-0.3B-Base-PT`
  - `kyutai/helium-1-preview-2b`
  - `allenai/OLMo-7B-hf`
  - `mistralai/Ministral-8B-Instruct-2410`
- Patching HF model weights initialisation. Without this, the `loss`
and `grad_norm` start very high.

# Usage

- Requirements: `transformers==4.57.1`
- Config:
`torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3.toml`
```diff
...
[model]
- name = "llama3"
+ name = "transformers_backend"
flavor = "debugmodel"
hf_assets_path = "./tests/assets/tokenizer"

+[hf_transformers]
+model = "Qwen/Qwen3-4B-Instruct-2507"
...
```
- Train: `LOG_RANK=7
CONFIG_FILE=<YOUR_PATH>/torchtitan/experiments/transformers_backend/configs/qwen3.toml
./run_train.sh
--job.custom_config_module=torchtitan.experiments.transformers_backend.job_config
--compile.enable`

<img width="1334" height="453" alt="image"
src="https://github.com/user-attachments/assets/da459448-027b-4af9-8176-6a3e433a272c"
/>

# Testing methodology

<img width="2672" height="2018" alt="image"
src="https://github.com/user-attachments/assets/66d8689d-7ede-47e3-b389-d4fc1bdd70f7"
/>

- Following the
[converging.md](https://github.com/pytorch/torchtitan/blob/main/docs/converging.md)
guidelines, I am comparing the baseline `FSDP=2` vs `FSDP=2 & <other
//-ism>`
- More precisely, `test_hf_integration.py` is going to do:

```bash
    results/
        |_ meta-llama
            |_ Llama-3.2-1B
                |_ debugmodel/
                    |_ seed_checkpoint/
                        |_ config.toml
                        |_ seed.slurm
                        |_ step-0/
                           |_ ....
                    |_ fsdp2_tp1_cp1_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                    |_ fsdp2_tp2_cp1_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp1_pp2/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp2_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp2_pp2/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                |_ full/
                ...
```
- Here is the grid search to test the HF modelling
```shell
#!/usr/bin/bash
model_names=(
     "meta-llama/Llama-3.2-1B"
     "microsoft/phi-2" 
     "Qwen/Qwen2.5-7B"
     "mistralai/Mistral-7B-v0.1"
     "ByteDance-Seed/Seed-Coder-8B-Instruct"
     "Qwen/Qwen3-4B-Instruct-2507" 
     "arcee-ai/AFM-4.5B" 
     "ibm-granite/granite-3b-code-base-2k" 
     "baidu/ERNIE-4.5-0.3B-Base-PT" 
     "kyutai/helium-1-preview-2b" 
     "allenai/OLMo-7B-hf"
     "mistralai/Ministral-8B-Instruct-2410" 
)

for model_name in "${model_names[@]}"; do
    rm -rf slurm_results/${model_name}

    python test_hf_integration.py create_configs --model_name "$model_name" --out_dir slurm_results --flavor debugmodel
    python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel/seed_checkpoint --qos high
    while [ ! -f slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt ] || [ "$(cat slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt)" != "completed" ]; do
        echo "Waiting for seed checkpoint from ${model_name} to complete ..."
        sleep 1
    done
    python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel --qos high
    echo "================"
done
```

# Further tasks

- MoE (handled in PR huggingface#3)
    - Missing `build_optimizers_with_moe_load_balancing` support for MoE
    - Missing TP/PP/EP support for MoE
- When using the HF modeling, in the `FSDP=2 vs FSDP=2 + PP=2` test, the
`loss` and `grad_norm` are not bitwise matching (but converging), while
they are with the Torchtitan modeling. (issue is tracked in
huggingface#4)
- Add convergence tests to CI by doing tiny model + gloo backend (once
PP is bitwise matching)
- The HF modeling has lower MFU than the Torchtitan modeling
- NOTE: use `import torch._dynamo.config;
torch._dynamo.config.cache_size_limit = 128` to avoid graph recompilation
when using `torch.compile` and activation checkpointing
**Summary**
This PR adds variable length attention (varlen) support to the Llama 3
8b model in torchtitan. We replace `use_flex_attn` with `attn_type`
(either "sdpa", "varlen", "flex"). If `attn_type = "varlen"`, the
attention module calls a compiled `varlen_attn` defined
[here](https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/varlen.py).
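The `attn_type` switch can be sketched as a simple dispatch table; the callables below are stubs (the real varlen kernel lives in `torch/nn/attention/varlen.py`), and `build_attention` is an illustrative helper, not torchtitan's API:

```python
# Sketch of the attn_type dispatch replacing the old use_flex_attn boolean.
# The implementations are stubs; only the three type names are from the PR.
def sdpa(*args): return "sdpa"
def flex_attention(*args): return "flex"
def varlen_attn(*args): return "varlen"

ATTN_IMPLS = {"sdpa": sdpa, "flex": flex_attention, "varlen": varlen_attn}

def build_attention(attn_type: str):
    # Fail fast on typos instead of silently falling back to a default.
    try:
        return ATTN_IMPLS[attn_type]
    except KeyError:
        raise ValueError(
            f"attn_type must be one of {sorted(ATTN_IMPLS)}, got {attn_type!r}"
        )

attn = build_attention("varlen")
```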

**Testing**
Ran loss and performance tests against flex attention. Loss is on par.

<img width="947" height="505" alt="Screenshot 2025-11-19 at 3 24 26 PM"
src="https://github.com/user-attachments/assets/d85dfc09-4f5e-4f82-abc9-49b870b34990"
/>

Varlen is slightly slower than Flex due to the cuda kernel speeds
(varlen calls into `flash_attention_forward`/`flash_attention_backward`
today).


| | Varlen | Flex |
| :---: | :------ | :---: |
| Forward  | 774us 357ns | 722us 317ns  |
| Backward   | 1ms 955us 916ns  | 1ms 558us 747ns    |
francesco-bertolotti and others added 30 commits January 18, 2026 23:08
Added troubleshooting tip for missing libnvshmem_host.so.
* Support mxfp8 on gfx950.

It depends on TorchAO (pytorch/ao#3620).
Update README with libnvshmem_host.so troubleshooting
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0)
(oldest at bottom):
* __->__ pytorch#2260
* pytorch#2246
* pytorch#2245

This API is not needed any more. pyrefly will give us a warning.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0)
(oldest at bottom):
* pytorch#2145
* __->__ pytorch#2144

**Summary**

1. Refactored CP Dispatching:
- New apply_cp() function uses PyTorch's _ContextParallel
parallelization plan to dispatch attention call.
  - Enables CP dispatcher for SDPA attention type inside apply_cp()
2. New CP Data Sharding Approach:
- Added a cp_shard() helper function that wraps PyTorch's
_context_parallel_shard API
  - Uses _HeadTailLoadBalancer for SDPA attention load balancing
  - FlexAttention CP support deferred to a future PR
- CP sharding now happens explicitly in post_dataloading_process() where
inputs, labels, and positions are sharded
  - The new positions argument allows us to not shard the freqs_cis.

Note that this PR requires pytorch/pytorch#170200
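Head-tail load balancing itself is easy to sketch: split the sequence into `2*W` chunks and give rank `i` chunk `i` plus chunk `2W-1-i`, so every rank sees a comparable mix of cheap early and expensive late causal positions. This mirrors the idea behind `_HeadTailLoadBalancer`, not its actual API:

```python
# Sketch of head-tail sharding for context parallelism. With causal
# attention, later positions attend to more keys, so pairing chunk i with
# chunk 2W-1-i balances compute across ranks. Indexing is illustrative.
def head_tail_shard(seq: list, world_size: int, rank: int) -> list:
    chunk = len(seq) // (2 * world_size)
    chunks = [seq[i * chunk:(i + 1) * chunk] for i in range(2 * world_size)]
    return chunks[rank] + chunks[2 * world_size - 1 - rank]

tokens = list(range(8))                                  # toy sequence
shard0 = head_tail_shard(tokens, world_size=2, rank=0)   # chunks 0 and 3
shard1 = head_tail_shard(tokens, world_size=2, rank=1)   # chunks 1 and 2
```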

**Test**
```
-> % python3 scripts/loss_compare.py . chienchin/loss_compare --baseline-options="--parallelism.context_parallel_degree=8" --test-options="--parallelism.context_parallel_degree=8" --steps=100 --assert-equal

pick 5903566a Improve the loss_compare.sh logic

[LOSS_COMPARE]
[LOSS_COMPARE] Asserting losses are equal...
[LOSS_COMPARE] Baseline log: /tmp/baseline_training.log
[LOSS_COMPARE] Test log: /tmp/test_training.log
[LOSS_COMPARE] Extracted 100 steps from baseline log
[LOSS_COMPARE] Extracted 100 steps from test log
test_losses_equal
(__main__.assert_losses_equal.<locals>.LossEqualityTest.test_losses_equal)
... ok

----------------------------------------------------------------------
Ran 1 test in 0.000s

OK
[LOSS_COMPARE] All losses are equal. Assertion passed!
[LOSS_COMPARE] ==========================================
[LOSS_COMPARE] LOSS COMPARISON ANALYSIS
[LOSS_COMPARE] ==========================================

[LOSS_COMPARE] Step-by-step loss comparison:
[LOSS_COMPARE] Step    Baseline Loss    Test Loss   Difference
[LOSS_COMPARE] ----    -------------    ---------   ----------
[LOSS_COMPARE] 1       8.1309           8.1309           0.000000
[LOSS_COMPARE] 2       7.8268           7.8268           0.000000
[LOSS_COMPARE] 3       7.2284           7.2284           0.000000
[LOSS_COMPARE] 4       6.4669           6.4669           0.000000
[LOSS_COMPARE] 5       5.4017           5.4017           0.000000
[LOSS_COMPARE] 6       4.7656           4.7656           0.000000
[LOSS_COMPARE] 7       4.3587           4.3587           0.000000
[LOSS_COMPARE] 8       4.0938           4.0938           0.000000
[LOSS_COMPARE] 9       4.4019           4.4019           0.000000
[LOSS_COMPARE] 10      3.7451           3.7451           0.000000
....
[LOSS_COMPARE] 90      2.802            2.802            0.000000
[LOSS_COMPARE] 91      2.7207           2.7207           0.000000
[LOSS_COMPARE] 92      2.7454           2.7454           0.000000
[LOSS_COMPARE] 93      2.6992           2.6992           0.000000
[LOSS_COMPARE] 94      2.743            2.743            0.000000
[LOSS_COMPARE] 95      2.7534           2.7534           0.000000
[LOSS_COMPARE] 96      2.8403           2.8403           0.000000
[LOSS_COMPARE] 97      2.783            2.783            0.000000
[LOSS_COMPARE] 98      3.0892           3.0892           0.000000
[LOSS_COMPARE] 99      2.7905           2.7905           0.000000
[LOSS_COMPARE] 100     2.733            2.733            0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Summary statistics:
[LOSS_COMPARE] Average baseline loss:  3.1414940000000002
[LOSS_COMPARE] Average test loss: 3.1414940000000002
[LOSS_COMPARE] Average difference:     0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Loss comparison complete. No results saved (no output
folder specified).
```

**TODO**
- This PR will invalidate torch.compile + CP due to
pytorch/pytorch#170110. We will have to wait
for Dynamo to fix the issue or refactor nn.Module core logic to avoid
check hook_id.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0)
(oldest at bottom):
* __->__ pytorch#2145

Summary:

Continuing from the previous PR, this PR enables FlexAttention + CP for
llama3. FlexCP will use PTRRLoadBalancer.

Note that this PR requires
pytorch/pytorch#170201
Previously, individual experts are marked as `Replicate` in the EP dimension
of the global `global_device_mesh`. Local experts are first created on
`global_device_mesh` and are turned into a 2d tensor using `squeeze(0)`,
which only removes the extra dimension; the remaining `Replicate` metadata
is still there. The wrong metadata results in a bug when DCP
saves the `DTensor`. This PR fixes the bug by:
1. Using a sub-mesh that excludes the expert dimension, i.e., dim 0.
2. When the sub-mesh is empty, using a plain tensor instead of `DTensor`.
Added installation of transformers package and updated sbatch script instructions.
The `is_causal` flag has been deprecated in `varlen_attn`; use
`window_size = [-1, 0]` instead, see [this
PR](pytorch/pytorch#172245).

*Test*
<img width="1225" height="496" alt="Screenshot 2026-01-21 at 11 55
59 AM"
src="https://github.com/user-attachments/assets/125c56af-76ed-4f6f-a433-da4d9a631dab"
/>
In this PR we added ROCm CI support for the SimpleFSDP experiments tests.
…nfigs

Memory Tracking Tools:
- Add DetailedMemoryTracker for per-phase memory tracking (before_forward,
  after_forward_backward, after_optimizer, step_end)
- Add CUDAMemoryTracker for PyTorch vs nvidia-smi memory comparison
- Add AggressiveMemoryManager for CUDA fragmentation reduction with modes:
  minimal, balanced, aggressive, maximum

BF16 Optimizer States:
- Add BF16StateOptimizersContainer wrapper for pre-initializing optimizer
  states in bfloat16 before first step (50% memory savings)
- Add preinit_optimizer_states_bf16() to allocate exp_avg/exp_avg_sq in
  param dtype from the start, avoiding fp32 allocation spike
- Fix device mismatch bug: state["step"] tensor now created on param device
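The memory arithmetic behind the 50% figure is straightforward: Adam keeps two state tensors per parameter (`exp_avg`, `exp_avg_sq`), so halving the state dtype width halves state memory. A back-of-envelope sketch (the parameter count is illustrative):

```python
# Back-of-envelope sketch of the bf16 optimizer-state saving. Adam stores
# exp_avg and exp_avg_sq per parameter; bf16 is 2 bytes/element vs 4 for
# fp32, so state memory halves. Real containers pre-allocate these states
# before the first step to avoid the fp32 allocation spike.
def adam_state_bytes(num_params: int, state_dtype_bytes: int) -> int:
    return 2 * num_params * state_dtype_bytes  # exp_avg + exp_avg_sq

fp32_states = adam_state_bytes(1_000_000_000, 4)  # fp32 states for 1B params
bf16_states = adam_state_bytes(1_000_000_000, 2)  # bf16 states: half the size
```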

New Config Options:
- optimizer.state_dtype: "float32" | "bfloat16"
- training.enable_detailed_memory_tracking: bool
- training.clear_cache_between_steps: bool
- training.skip_optimizer_step: bool
- training.aggressive_memory_mode: "minimal" | "balanced" | "aggressive" | "maximum"
- training.aggressive_memory_verbose: bool

Train Loop Integration:
- Initialize memory trackers in Trainer.__init__
- Call tracking at forward_backward_step and train_step phases
- Call aggressive memory manager at post_backward, post_optimizer, step_end
- Pre-initialize BF16 optimizer states before training loop

Configs Added:
- qwen3_30b_a3b_memory_test.toml: Test config for memory features
- kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml: 12-node production config
- kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml: 36-node HSDP config
…reporting

Config Options Added:

[parallelism]
- fsdp_disable_prefetch: Disable FSDP forward/backward prefetching
  (reduces memory at cost of less communication overlap)

[debug]
- enable_nan_tracker: Enable lightweight NaN/Inf tracking to find
  where NaN first appears in the model
- nan_tracker_verbose: Print stats for every layer (very verbose)

Enhanced Metrics:
- Add nvidia-smi memory reporting to DeviceMemStats for verification
- Add _get_nvidia_smi_memory() method to DeviceMemoryMonitor
- Handles CUDA_VISIBLE_DEVICES remapping for SLURM environments

FSDP Prefetch Control:
- Add disable_prefetch parameter to apply_fsdp() in llama4
- Wire up fsdp_disable_prefetch config to apply_fsdp calls in:
  llama4, deepseek_v3, qwen3, gpt_oss

Test Configs:
- qwen3_30b_a3b_memory_test.toml: Added [debug] section with nan_tracker

Note: fsdp_bucket_cap_mb not added as it's not supported by FSDP2 API
Allow fsdp_reshard_after_forward to accept an integer N for partial
resharding to N-GPU groups. This reduces peak memory by limiting
all-gather buffer size to N GPUs instead of full DP world.

Use N=8 for intra-node resharding (fast NVLink communication).
N must be a factor of the FSDP shard world size.

Example: fsdp_reshard_after_forward = 8

Changes:
- config/job_config.py: Update type to Literal[...] | int
- llama4/parallelize.py: Handle int values in apply_fsdp()
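The integer semantics can be sketched as a small helper; the name `reshard_group_size` and the exact True/False mapping below are illustrative, not the FSDP2 API:

```python
# Sketch of integer reshard_after_forward handling: an int N reshards to
# N-GPU groups after forward, so all-gather buffers are bounded by N shards
# instead of the full DP world. N must divide the shard world size.
def reshard_group_size(shard_world_size: int, reshard_after_forward) -> int:
    if reshard_after_forward is True:
        return shard_world_size        # reshard across the full shard world
    if reshard_after_forward is False:
        return 1                       # keep params gathered after forward
    n = int(reshard_after_forward)
    if shard_world_size % n != 0:
        raise ValueError(f"{n} must be a factor of shard world size {shard_world_size}")
    return n

group = reshard_group_size(64, 8)  # e.g. intra-node resharding on 8-GPU nodes
```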
Standalone diagnostic tool that visualizes GPU allocation across
parallelism dimensions (DP, PP, TP, CP, EP). Only runs on rank 0
at initialization - no impact on training performance.

Features:
- Mesh structure visualization
- GPU allocation grid
- Expert parallel group allocation
- Context parallel group allocation
- FSDP sharding visualization

Integrated into train.py to automatically log at startup.
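The rank-to-coordinate math such a visualization rests on is a repeated divmod over the mesh dimensions; the (dp, tp, pp) ordering below is illustrative:

```python
# Sketch of the rank -> mesh-coordinate decomposition behind the allocation
# grid: the last mesh dimension varies fastest, so repeated divmod over the
# reversed shape recovers each rank's coordinates.
def rank_to_coords(rank: int, mesh_shape: tuple[int, ...]) -> tuple[int, ...]:
    coords = []
    for dim in reversed(mesh_shape):
        rank, c = divmod(rank, dim)
        coords.append(c)
    return tuple(reversed(coords))

coords = rank_to_coords(5, (2, 2, 2))  # with dims (dp, tp, pp)
```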