From ae0349410ead46298e65bc242c023a8d09db7ac2 Mon Sep 17 00:00:00 2001 From: yifanmao Date: Mon, 10 Nov 2025 10:14:09 -0800 Subject: [PATCH 001/127] [TorchComms] add testing badge at experiments readme (#2010) --- torchtitan/experiments/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 0d6db0d2a1..14f8ba6544 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -27,7 +27,7 @@ We provide this `experiments/` folder to host experiments that add significant v | [simple_fsdp](./simple_fsdp/) | [![SimpleFSDP 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml?query=branch%3Amain) | [@ruisizhang123](https://github.com/ruisizhang123) [@tianyu-l](https://github.com/tianyu-l) | | [vlm](./vlm/) | [![VLM 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_vlm.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_vlm.yaml?query=branch%3Amain) | [@lkhphuc](https://github.com/lkhphuc) | | [forge](./forge/) | TBA | [@allenwang28](https://github.com/allenwang28) [@ebsmothers](https://github.com/ebsmothers) [@joecummings](https://github.com/joecummings) [@pbontrager](https://github.com/pbontrager) | -| [torchcomms](./torchcomms/) | TBA | [@d4l3k](https://https://github.com/d4l3k) [@fduwjj](https://github.com/fduwjj) [@mori360 ](https://github.com/mori360) | +| [torchcomms](./torchcomms/) | [![TorchComms 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_torchcomms.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_torchcomms.yaml?query=branch%3Amain) | [@d4l3k](https://https://github.com/d4l3k) [@fduwjj](https://github.com/fduwjj) [@mori360 ](https://github.com/mori360) | | [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) | | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) | From f4514efce69d1c6651aa5100296719f0296f2241 Mon Sep 17 00:00:00 2001 From: Yiming Zhou <61480007+yiming0416@users.noreply.github.com> Date: Mon, 10 Nov 2025 11:45:15 -0800 Subject: [PATCH 002/127] [compiler toolkit] specify passes through config (#2006) We should be able to control what passes to run in the compiler. This PR uses the config compile.passes to indicate in a list of graph passes to apply on the captured gm. By default, no pass is applied. Users can specify what passes to apply. Currently there are `autobucketing_reordering_pass` and `regional_inductor_pass`. ``` NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor ``` Also updated CI to include this new config --- .../experiments/compiler_toolkit/README.md | 11 +++ .../compiler_toolkit/common_utils.py | 10 +++ .../deepseek_v3/parallelize.py | 43 +++++---- .../compiler_toolkit/graph_utils.py | 88 ++++++++++++++++++- .../compiler_toolkit/job_config.py | 23 +++++ .../compiler_toolkit/llama3/parallelize.py | 62 ++++--------- .../experiments/compiler_toolkit/passes.py | 46 ++++++++++ .../tests/integration_tests.py | 31 +++++++ 8 files changed, 248 insertions(+), 66 deletions(-) create mode 100644 torchtitan/experiments/compiler_toolkit/job_config.py create mode 100644 torchtitan/experiments/compiler_toolkit/passes.py diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index a75b3a17b2..61207fc63b 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ b/torchtitan/experiments/compiler_toolkit/README.md @@ -29,7 +29,18 @@ NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.tom NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 ``` +**SimpleFSDP + TP + auto-bucketing** +```shell +NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering +``` + **SimpleFSDP + TP + FlexAttention** ```shell NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn ``` + +**SimpleFSDP + TP + FlexAttention + auto-bucketing + regional-inductor** + +```shell +NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor +``` diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py index 965e027bdb..b7499b2f79 100644 --- a/torchtitan/experiments/compiler_toolkit/common_utils.py +++ b/torchtitan/experiments/compiler_toolkit/common_utils.py @@ -53,3 +53,13 @@ def register_blockmask_pytree_node(): flatten_with_keys_fn=BlockMask._flatten_with_keys, serialized_type_name="torch.nn.attention.flex_attention.BlockMask", ) + + +def validate_flex_attention_annotation(joint_with_descriptors): + """Verify user annotations show up in the graph.""" + for node in joint_with_descriptors.graph_module.graph.nodes: + if node.target in { + torch.ops.higher_order.flex_attention, + torch.ops.higher_order.flex_attention_backward, + }: + assert "compile_with_inductor" in node.meta.get("custom", {}) diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py index 5c8ffb45c5..bc6859af61 100644 --- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py @@ -17,37 +17,19 @@ disable_compile, parallelize_inputs, register_blockmask_pytree_node, + validate_flex_attention_annotation, ) from torchtitan.experiments.compiler_toolkit.graph_utils import ( CompiledModule, + get_compiler_passes_from_config, joint_graph_builder, + make_compiler_with_passes, ) from torchtitan.experiments.simple_fsdp.deepseek_v3.parallelize import ( parallelize_deepseekv3 as simple_fsdp_parallelize_deepseekv3, ) -from torchtitan.tools.logging import logger - - -def compiler(name: str, gm: torch.fx.GraphModule, example_inputs): - logger.info(f"{name} before compiler:") - logger.info(gm.print_readable(print_output=False)) - - # TODO: regional_inductor should work with deepseek_v3 - # gm = regional_inductor(gm, example_inputs) - - logger.info(f"{name} after compiler:") - logger.info(gm.print_readable(print_output=False)) - return gm - - -def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("fwd_gm", gm, example_inputs) - - -def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("bwd_gm", gm, example_inputs) def annotate_deepseekv3() -> None: @@ -75,7 +57,17 @@ def parallelize_deepseekv3( parallel_dims: ParallelDims, job_config: JobConfig, ) -> CompiledModule: + """ + Parallelize and compile a DeepSeek v3 model with optional custom compiler passes. + + Args: + model: The model to parallelize + parallel_dims: Parallel dimensions configuration + job_config: Job configuration + Returns: + CompiledModule wrapping the parallelized and compiled model + """ annotate_deepseekv3() register_blockmask_pytree_node() @@ -84,11 +76,18 @@ def parallelize_deepseekv3( with disable_compile(job_config): model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config) + # Get compiler passes from config + compiler_passes = get_compiler_passes_from_config(job_config) + + # Create compilers with specified passes (defaults to no passes) + fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes) + + # Create custom joint_graph_builder with deepseekv3-specific compilers deepseekv3_joint_graph_builder = functools.partial( joint_graph_builder, fw_compiler=fw_compiler, bw_compiler=bw_compiler, - joint_custom_pass=None, + joint_custom_pass=validate_flex_attention_annotation, ) # TODO: CompiledModule should take sample input as well, so that we can diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 4ff6c8187b..dbe0a8a257 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import contextlib -from typing import Callable, Optional +from typing import Callable, List, Optional import torch from torch._dynamo.functional_export import dynamo_graph_capture_for_export @@ -16,6 +16,7 @@ ) from torch._guards import tracing, TracingContext from torch.distributed.tensor import DTensor +from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims from torchtitan.tools.logging import logger @@ -180,3 +181,88 @@ def forward(self, *args, **kwargs): # calling the line below returns control to torchtitan's runner # letting it call the backward, and optimizer. return self.joint_graph_module(args, kwargs) + + +# Default compiler pass configuration - no passes by default +DEFAULT_COMPILER_PASSES = [] + + +def compiler( + name: str, + gm: torch.fx.GraphModule, + example_inputs, + passes: List[Callable] = None, +): + """ + Compile a graph module by applying a sequence of compiler passes. + + Args: + name: Name for logging purposes + gm: The graph module to compile + example_inputs: Example inputs for the graph module + passes: List of compiler pass functions to apply. Each function should take + (gm, example_inputs) and return a transformed gm. If None, uses + DEFAULT_COMPILER_PASSES. + """ + if passes is None: + passes = DEFAULT_COMPILER_PASSES + + logger.info(f"{name} before compiler:") + logger.info(gm.print_readable(print_output=False)) + + for pass_fn in passes: + logger.info(f"Applying pass: {pass_fn.__name__}") + gm = pass_fn(gm, example_inputs) + + logger.info(f"{name} after compiler:") + logger.info(gm.print_readable(print_output=False)) + return gm + + +def make_compiler_with_passes(passes: List[Callable] = None): + """ + Create forward and backward compilers with specified passes. + + Args: + passes: List of compiler pass functions to apply. If None, uses DEFAULT_COMPILER_PASSES. + + Returns: + Tuple of (fw_compiler, bw_compiler) functions + """ + + def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: + return compiler("fwd_gm", gm, example_inputs, passes=passes) + + def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: + return compiler("bwd_gm", gm, example_inputs, passes=passes) + + return fw_compiler, bw_compiler + + +def get_compiler_passes_from_config(job_config: JobConfig): + """ + Extract and validate compiler passes from job config. + + Args: + job_config: Job configuration containing compile.passes + + Returns: + List of compiler pass functions + """ + from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_PASSES + + pass_names = getattr(job_config.compile, "passes", []) + compiler_passes = [] + + for pass_name in pass_names: + if pass_name not in AVAILABLE_PASSES: + raise ValueError( + f"Unknown compiler pass: {pass_name}. " + f"Available passes: {list(AVAILABLE_PASSES.keys())}" + ) + compiler_passes.append(AVAILABLE_PASSES[pass_name]) + + if pass_names: + logger.info(f"Using compiler passes from config: {pass_names}") + + return compiler_passes diff --git a/torchtitan/experiments/compiler_toolkit/job_config.py b/torchtitan/experiments/compiler_toolkit/job_config.py new file mode 100644 index 0000000000..ec5829a6c9 --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/job_config.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +@dataclass +class Compile: + """ + List of compiler pass names to apply in the compiler toolkit workflow. + By default, no passes are applied. + Example: --compile.passes autobucketing_reordering,regional_inductor + """ + + passes: list[str] = field(default_factory=list) + + +@dataclass +class JobConfig: + compile: Compile = field(default_factory=Compile) diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index 0ed8452148..e3dca203e9 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -8,9 +8,6 @@ import functools import torch -from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing - -from torch.fx.passes.regional_inductor import regional_inductor from torch.fx.traceback import annotate_fn from torchtitan.config import JobConfig @@ -19,56 +16,19 @@ disable_compile, parallelize_inputs, register_blockmask_pytree_node, + validate_flex_attention_annotation, ) from torchtitan.experiments.compiler_toolkit.graph_utils import ( CompiledModule, + get_compiler_passes_from_config, joint_graph_builder, + make_compiler_with_passes, ) from torchtitan.experiments.simple_fsdp.llama3.parallelize import ( parallelize_llama as simple_fsdp_parallelize_llama, ) -from torchtitan.tools.logging import logger - - -# TODO: support passing configs into schedule_overlap_bucketing -def autobucketing_reordering_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: - schedule_overlap_bucketing(gm, collective_bucketing=True) - gm.recompile() - return gm - - -def compiler(name: str, gm: torch.fx.GraphModule, example_inputs): - logger.info(f"{name} before compiler:") - logger.info(gm.print_readable(print_output=False)) - - gm = autobucketing_reordering_pass(gm) - - gm = regional_inductor(gm, example_inputs) - - logger.info(f"{name} after compiler:") - logger.info(gm.print_readable(print_output=False)) - return gm - - -def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("fwd_gm", gm, example_inputs) - - -def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("bwd_gm", gm, example_inputs) - - -def validate_flex_attention_annotation(joint_with_descriptors): - """Verify user annotations show up in the graph.""" - for node in joint_with_descriptors.graph_module.graph.nodes: - if node.target in { - torch.ops.higher_order.flex_attention, - torch.ops.higher_order.flex_attention_backward, - }: - assert "compile_with_inductor" in node.meta.get("custom", {}) - def annotate_llama() -> None: from torchtitan.models.attention import FlexAttentionWrapper @@ -84,7 +44,17 @@ def parallelize_llama( parallel_dims: ParallelDims, job_config: JobConfig, ) -> CompiledModule: + """ + Parallelize and compile a Llama model with optional custom compiler passes. + + Args: + model: The model to parallelize + parallel_dims: Parallel dimensions configuration + job_config: Job configuration + Returns: + CompiledModule wrapping the parallelized and compiled model + """ annotate_llama() register_blockmask_pytree_node() @@ -93,6 +63,12 @@ def parallelize_llama( with disable_compile(job_config): model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config) + # Get compiler passes from config + compiler_passes = get_compiler_passes_from_config(job_config) + + # Create compilers with specified passes (defaults to no passes) + fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes) + # Create custom joint_graph_builder with llama-specific compilers and validation llama_joint_graph_builder = functools.partial( joint_graph_builder, diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py new file mode 100644 index 0000000000..1c00fd5c1b --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Compiler passes for the compiler toolkit. + +This module provides various compiler passes that can be applied to graph modules +during compilation. Passes can be selected and configured via job config. +""" + +import torch +from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing +from torch.fx.passes.regional_inductor import regional_inductor + + +def autobucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs=None +) -> torch.fx.GraphModule: + """ + Apply autobucketing and reordering optimization. + + This pass applies schedule_overlap_bucketing with collective_bucketing enabled + to optimize communication patterns in distributed training. + """ + schedule_overlap_bucketing(gm, collective_bucketing=True) + gm.recompile() + return gm + + +def regional_inductor_pass( + gm: torch.fx.GraphModule, example_inputs +) -> torch.fx.GraphModule: + """ + Apply regional inductor compilation based on user annotation. + """ + return regional_inductor(gm, example_inputs) + + +# Registry mapping pass names to pass functions +AVAILABLE_PASSES = { + "autobucketing_reordering": autobucketing_reordering_pass, + "regional_inductor": regional_inductor_pass, +} diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py index bb64160db2..e33149fe2f 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py +++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py @@ -31,6 +31,21 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "llama3_fsdp_tp", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--activation_checkpoint.mode none", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes autobucketing_reordering", + ], + ], + "llama3 FSDP+TP autobucketing", + "llama3_fsdp_tp_autobucketing", + ngpu=4, + ), OverrideDefinitions( [ [ @@ -45,6 +60,22 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "llama3_fsdp_tp_flexattn", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--model.flavor debugmodel_flex_attn", + "--activation_checkpoint.mode none", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes autobucketing_reordering,regional_inductor", + ], + ], + "llama3 FSDP+TP+FlexAttn autobucketing regional_inductor", + "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor", + ngpu=4, + ), # deepseek_v3 tests OverrideDefinitions( [ From 02990b0fe5e6cc9a6dfd74821be3c4bb89ab0c65 Mon Sep 17 00:00:00 2001 From: Ruisi Zhang Date: Mon, 10 Nov 2025 14:58:09 -0800 Subject: [PATCH 003/127] [simplefsdp] fix region ac in zero2-style FSDP (#1970) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After some offline discussion, we've concluded that life would be easier if we can put simplefsdp's checkpoint logic for `reshard_after_forward` to compiler. The ac annotation part is borrowed form AP: [LINK](https://github.com/meta-pytorch/autoparallel/blob/main/autoparallel/activation_checkpointing.py#L69). **Trace and Loss Check** (all with torch.compile enable) reshard_after_fwd = False 1. SAC + llama3 ([trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-05-06_rank0_trace.json)) Screenshot 2025-10-30 at 4 28 59 PM Screenshot 2025-11-05 at 9 02 30 PM 2. Full AC + llama3 [(trace)]() Screenshot 2025-10-30 at 4 30 53 PM Screenshot 2025-11-05 at 9 11 34 PM 3. No AC + llama3 [[trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-03-50_rank0_trace.json)] Screenshot 2025-10-30 at 4 32 05 PM Screenshot 2025-11-05 at 9 07 46 PM reshard_after_fwd = True 1. SAC + llama3 ([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-34-24_rank0_trace.json)) Screenshot 2025-10-31 at 11 34 47 AM 2. Full AC + llama3 ([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-36-27_rank0_trace.json)) Screenshot 2025-10-31 at 11 38 02 AM 3. No AC + llama3 ([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-02-44_rank0_trace.json)) Screenshot 2025-10-31 at 11 43 04 AM --- torchtitan/experiments/simple_fsdp/README.md | 4 +- torchtitan/experiments/simple_fsdp/backend.py | 35 ++++++-- .../simple_fsdp/deepseek_v3/parallelize.py | 49 +++++----- .../simple_fsdp/llama3/parallelize.py | 33 ++++--- .../simple_fsdp/reshard_after_forward.py | 90 +++++++++++++++++++ .../experiments/simple_fsdp/simple_fsdp.py | 53 +---------- 6 files changed, 167 insertions(+), 97 deletions(-) create mode 100644 torchtitan/experiments/simple_fsdp/reshard_after_forward.py diff --git a/torchtitan/experiments/simple_fsdp/README.md b/torchtitan/experiments/simple_fsdp/README.md index a49fa8ad56..ea4fb3272f 100644 --- a/torchtitan/experiments/simple_fsdp/README.md +++ b/torchtitan/experiments/simple_fsdp/README.md @@ -3,11 +3,13 @@ [![integration and numerics tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml?query=branch%3Amain) [![arXiv](https://img.shields.io/badge/arXiv-2411.00284-b31b1b.svg)](https://arxiv.org/abs/2411.00284) -💡 **Note**: SimpleFSDP's composability with Mixed Precision Training and Tensor Parallel requires updates from latest PyTorch, which can be installed (e.g., for CUDA 12.6) via +💡 **Note 1**: SimpleFSDP's composability with Mixed Precision Training and Tensor Parallel requires updates from latest PyTorch, which can be installed (e.g., for CUDA 12.6) via ```bash pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall ``` +💡 **Note 2**: Some of SimpleFSDP's functionalities (e.g., reshard_after_forward) is implemented with torch.compile. It is always recommended to open compile (`--compile.enable`) to see desired correct functionality. + This folder includes an experimental frontend implementation for [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework, which has a simple implementation for maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancement via compiler backend optimizations. ### Run SimpleFSDP Training on Llama3 & DeepSeek_v3 diff --git a/torchtitan/experiments/simple_fsdp/backend.py b/torchtitan/experiments/simple_fsdp/backend.py index 36abe4ad0b..d51e6668c1 100644 --- a/torchtitan/experiments/simple_fsdp/backend.py +++ b/torchtitan/experiments/simple_fsdp/backend.py @@ -4,20 +4,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Union +from typing import Any import torch +import torch._functorch.config as functorch_config +from .reshard_after_forward import annotate_fsdp_all_gather -def get_compile_backend(backend_name: str) -> Union[str, callable]: + +def get_compile_backend( + backend_name: str, fsdp_reshard_after_forward: bool +) -> callable: # return the compile backends used in SimpleFSDP training # Step1: check if backend_name is inside available torch.compile backends # Step2: check if the backend_name has been registered as a customized backend available_torch_backend = torch._dynamo.list_backends(exclude_tags=()) - if backend_name in available_torch_backend: - return backend_name - if backend_name == "aot_eager_autobucketing": + if backend_name in available_torch_backend: + backend = torch._dynamo.lookup_backend(backend_name) + elif backend_name == "aot_eager_autobucketing": # Perform auto optimization in aten fx-level and execute code in aot_eager backend # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960 from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend @@ -46,4 +51,22 @@ def aten_autobucketing_reordering_pass( else: raise AssertionError(f"Unsupported customized backend: {backend_name}") - return backend + def joint_ac_pass( + gm: torch.fx.GraphModule, example_inputs: Any + ) -> torch.fx.GraphModule: + # this pass implements simplefsdp's fsdp_reshard_after_forward behavior + # when fsdp_reshard_after_forward set to True, it will annotate simple_fsdp AG + # to CheckpointPolicy.MUST_RECOMPUTE. + # when fsdp_reshard_after_forward set to False, it will annotate simple_fsdp AG + # to CheckpointPolicy.MUST_SAVE. + gm = annotate_fsdp_all_gather(gm, fsdp_reshard_after_forward) + gm.recompile() + return gm + + def simple_fsdp_custom_pass(*args, **kwargs): + # the ac pass has to operate in a joint graph before partitioner for ac + # annotation to take into effect. + with functorch_config.patch("joint_custom_pass", joint_ac_pass): + return backend(*args, **kwargs) + + return simple_fsdp_custom_pass diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index ac6f9bdc9b..2ae1c517f3 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -10,16 +10,18 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims + +from torchtitan.distributed.activation_checkpoint import apply_ac from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.models.deepseek_v3.infra.parallelize import ( - apply_ac, apply_moe_ep_tp, apply_non_moe_tp, ) from torchtitan.tools.logging import logger -from ..simple_fsdp import data_parallel, MixedPrecisionPolicy +from ..backend import get_compile_backend +from ..simple_fsdp import data_parallel, MixedPrecisionPolicy # Adapted from llama4/infra/parallelize.py def parallelize_deepseekv3( @@ -91,20 +93,6 @@ def parallelize_deepseekv3( reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) - match job_config.parallelism.fsdp_reshard_after_forward: - case "always": - reshard_after_forward = True - case "never": - reshard_after_forward = False - case "default": - # For PP, by default do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = not parallel_dims.pp_enabled - case _: - raise ValueError( - f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." - ) - # apply data parallel dp_mesh: DeviceMesh | None = None if ( @@ -155,9 +143,7 @@ def parallelize_deepseekv3( transformer_block.moe.experts, dp_mod_ep_mesh, dp_mode, - ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, shard_dim=experts_shard_dim, reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) @@ -166,9 +152,7 @@ def parallelize_deepseekv3( model, dp_mesh, dp_mode, - ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, ) logger.info( @@ -178,6 +162,29 @@ def parallelize_deepseekv3( if job_config.compile.enable: torch._inductor.config.reorder_for_peak_memory = False torch._dynamo.config.capture_scalar_outputs = True - model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True) + + match job_config.parallelism.fsdp_reshard_after_forward: + case "always": + fsdp_reshard_after_forward = True + case "never": + fsdp_reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + fsdp_reshard_after_forward = not parallel_dims.pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." + ) + + backend = ( + getattr(job_config.compile, "model_backend_override", None) + or job_config.compile.backend + ) + model = torch.compile( + model, + backend=get_compile_backend(backend, fsdp_reshard_after_forward), + fullgraph=True, + ) return model diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index d61e74a5dd..1d8bfc500f 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -112,27 +112,11 @@ def parallelize_llama( reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) - match job_config.parallelism.fsdp_reshard_after_forward: - case "always": - reshard_after_forward = True - case "never": - reshard_after_forward = False - case "default": - # For PP, by default do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = not parallel_dims.pp_enabled - case _: - raise ValueError( - f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." - ) - model = data_parallel( model, parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], mode=dp_mode, - ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, ) logger.info( "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode @@ -140,13 +124,28 @@ def parallelize_llama( if job_config.compile.enable and "model" in job_config.compile.components: torch._inductor.config.reorder_for_peak_memory = False + + match job_config.parallelism.fsdp_reshard_after_forward: + case "always": + fsdp_reshard_after_forward = True + case "never": + fsdp_reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + fsdp_reshard_after_forward = not parallel_dims.pp_enabled + case _: + raise ValueError( + f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." + ) + backend = ( getattr(job_config.compile, "model_backend_override", None) or job_config.compile.backend ) model = torch.compile( model, - backend=get_compile_backend(backend), + backend=get_compile_backend(backend, fsdp_reshard_after_forward), fullgraph=True, ) diff --git a/torchtitan/experiments/simple_fsdp/reshard_after_forward.py b/torchtitan/experiments/simple_fsdp/reshard_after_forward.py new file mode 100644 index 0000000000..dac010bfcd --- /dev/null +++ b/torchtitan/experiments/simple_fsdp/reshard_after_forward.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.utils.checkpoint import CheckpointPolicy + + +def is_graph_input(node: torch.fx.Node) -> bool: + return node.op == "placeholder" + + +def is_wait_tensor(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target == torch.ops._c10d_functional.wait_tensor.default + ) + + +def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default + ) + + +def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool: + """ + Returns True if the node is a wait_tensor node that is the result of an all_gather + that can be arbitrarily prefetched, i.e., if all its recursive inputs are + single-input operators that leads to a graph input. + """ + if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]): + n: torch.fx.Node = node.all_input_nodes[0] + while len(n.all_input_nodes) == 1: + if is_graph_input(n.all_input_nodes[0]): + return True + n = n.all_input_nodes[0] + return False + + +def annotate_fsdp_all_gather( + gm: torch.fx.GraphModule, reshard_after_forward: bool +) -> None: + """ + Force recompute all_gather nodes from simple fsdp in the graph. + This pass should be added in torch._inductor.config.joint_custom_post_pass + """ + graph = gm.graph + + def force_recompute_node(node): + if reshard_after_forward: + node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE + else: + node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + # ac_graph_id is used in the partitioner to decide + # if two nodes which have AC applied come from a different + # AC regions. This is needed because nodes in the boundary + # of two AC regions are marked as MUST_SAVE. In our case + # we just add a large value of ac_graph_id so that + # all nodes we tag for recomputation do indeed get recomputed + # and are not influenced by other nodes in the graph with + # nearby ac_graph_id values + node.meta["ac_graph_id"] = 100000 + + # Make all-gather nodes (and related nodes) recomputable, to circumvent + # https://github.com/pytorch/pytorch/issues/136433 + for node in graph.nodes: + if is_wait_tensor_from_fsdp(node): + ag_node = node.args[0] + force_recompute_node(ag_node) # all_gather + force_recompute_node(node) # wait_tensor + # Force-recompute slice that comes after wait + for user in node.users: + if ( + user.op == "call_function" + and user.target == torch.ops.aten.slice.Tensor + ): + force_recompute_node(user) + # Force-recompute potential dtype casts from all_gather + if ( + ag_node.all_input_nodes[0].op == "call_function" + and ag_node.args[0].target + == torch.ops.prims.convert_element_type.default + ): + force_recompute_node(ag_node.all_input_nodes[0]) + + return gm diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index 737b6d3ec2..c731da3800 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -22,12 +22,6 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._redistribute import redistribute_local_tensor from torch.distributed.tensor.placement_types import _StridedShard, Placement -from torch.utils.checkpoint import ( - checkpoint, - CheckpointPolicy, - create_selective_checkpoint_contexts, -) - _active_parametrization = True @@ -183,34 +177,13 @@ def _register_parametrization( module.__class__ = module_cls -def fsdp_policy(): - def _fsdp_recomp_policy(): - def _custom_policy(ctx, func, *args, **kwargs): - to_recompute = func in { - torch.ops._c10d_functional.all_gather_into_tensor.default, - torch.ops._c10d_functional.wait_tensor.default, - torch.ops.aten._to_copy.default, # for dtype cast in FSDP - } - return ( - CheckpointPolicy.MUST_RECOMPUTE - if to_recompute - else CheckpointPolicy.MUST_SAVE - ) - - return _custom_policy - - return create_selective_checkpoint_contexts(_fsdp_recomp_policy()) - - class ReplicateComputation(torch.nn.Module): def __init__( self, device_mesh, param_sharding, mode, - regional_ac, mp_policy, - reshard_after_forward, reduction_divide_factor, ): super().__init__() @@ -225,11 +198,9 @@ def __init__( if reduction_divide_factor is not None else Partial(reduce_op="avg") ] * self.device_mesh.ndim - self.regional_ac = regional_ac mp_policy = mp_policy or MixedPrecisionPolicy() self.param_dtype = mp_policy.param_dtype self.reduce_dtype = mp_policy.reduce_dtype - self.reshard_after_forward = reshard_after_forward def replicate_compute(self, x: DTensor) -> torch.Tensor: # data parallel runtime replicate parameters and do local compute @@ -292,21 +263,7 @@ def forward(self, x: DTensor) -> torch.Tensor: if not _active_parametrization: return x - if ( - self.regional_ac - and self.mode in ("fully_shard", "hybrid_shard") - and self.reshard_after_forward - ): - # apply checkpointing to implement reshard_after_forward - output = checkpoint( - self.replicate_compute, - x, - use_reentrant=False, - context_fn=fsdp_policy, - ) - else: - output = self.replicate_compute(x) - + output = self.replicate_compute(x) return output @@ -314,9 +271,7 @@ def data_parallel( model: nn.Module, device_mesh: DeviceMesh, mode: str = "replicate", - ac_mode: str = "none", mp_policy: MixedPrecisionPolicy | None = None, - reshard_after_forward: bool = True, shard_dim: int = 0, reduction_divide_factor: float | None = None, ): @@ -335,9 +290,6 @@ def data_parallel( modules = list(model.modules()) - # apply regional ac (with fsdp_policy) if no global ac is to be applied - regional_ac = ac_mode == "none" - for mod in modules: params_dict = dict(mod.named_parameters(recurse=False)) # we shouldn't apply data parallel to the modules that are already @@ -366,7 +318,6 @@ def data_parallel( # device_mesh, # param_sharding, # mode, - # regional_ac, # mp_policy=mp_policy, # ), # unsafe=True, @@ -379,9 +330,7 @@ def data_parallel( device_mesh, param_sharding, mode, - regional_ac, mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, reduction_divide_factor=reduction_divide_factor, ), ) From fddd9ebbbae1d042634cafb2c24db398bc502731 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 10 Nov 2025 18:39:28 -0800 Subject: [PATCH 004/127] [SimpleFSDP] Add typing to simple_fsdp.py (#2001) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #2002 * __->__ #2001 Add typing, credit to Claude. --- .../experiments/simple_fsdp/simple_fsdp.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index c731da3800..391ae74dff 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from collections.abc import Sequence +from collections.abc import Generator, Sequence from contextlib import contextmanager from dataclasses import dataclass @@ -27,7 +27,7 @@ @contextmanager -def disable_active_parametrization(): +def disable_active_parametrization() -> Generator[None, None, None]: global _active_parametrization try: _active_parametrization = False @@ -180,18 +180,18 @@ def _register_parametrization( class ReplicateComputation(torch.nn.Module): def __init__( self, - device_mesh, - param_sharding, - mode, - mp_policy, - reduction_divide_factor, - ): + device_mesh: DeviceMesh, + param_sharding: tuple[Placement, ...], + mode: str, + mp_policy: MixedPrecisionPolicy | None, + reduction_divide_factor: float | None, + ) -> None: super().__init__() self.device_mesh = device_mesh self.param_sharding = param_sharding self.mode = mode - self.compute_placements = [Replicate()] * self.device_mesh.ndim - self.grad_placements = [ + self.compute_placements: list[Placement] = [Replicate()] * self.device_mesh.ndim + self.grad_placements: list[Placement] = [ _ScaledPartial( reduction_divide_factor=reduction_divide_factor, ) @@ -199,8 +199,8 @@ def __init__( else Partial(reduce_op="avg") ] * self.device_mesh.ndim mp_policy = mp_policy or MixedPrecisionPolicy() - self.param_dtype = mp_policy.param_dtype - self.reduce_dtype = mp_policy.reduce_dtype + self.param_dtype: torch.dtype | None = mp_policy.param_dtype + self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype def replicate_compute(self, x: DTensor) -> torch.Tensor: # data parallel runtime replicate parameters and do local compute @@ -274,7 +274,8 @@ def data_parallel( mp_policy: MixedPrecisionPolicy | None = None, shard_dim: int = 0, reduction_divide_factor: float | None = None, -): +) -> nn.Module: + param_sharding: tuple[Placement, ...] if mode == "replicate": param_sharding = (Replicate(),) elif mode == "fully_shard": From e37f83f58b35fdbceed9a5916b3490c16247ac9c Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 10 Nov 2025 23:55:02 -0800 Subject: [PATCH 005/127] [Full DTensor][Reland] Add full_dtensor flag (#2013) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * __->__ #2013 When full_dtensor is True, the compute_placement will be preserved. This means that `to_local()` won't be called for fsdp only case. nD parallelism case (fsdp + tp) will error out as we have not implemented this case. This argument doesn't affect the current simple_fsdp. We have verified `full_dtensor=True` case with the full dtensor skleton PR, which will be published once it is ready. **This is a reland PR of https://github.com/pytorch/torchtitan/pull/2002. The previous one was broken during rebase.** --- torchtitan/experiments/simple_fsdp/simple_fsdp.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index 391ae74dff..6597c45f9d 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -185,6 +185,7 @@ def __init__( mode: str, mp_policy: MixedPrecisionPolicy | None, reduction_divide_factor: float | None, + full_dtensor: bool = False, ) -> None: super().__init__() self.device_mesh = device_mesh @@ -201,6 +202,7 @@ def __init__( mp_policy = mp_policy or MixedPrecisionPolicy() self.param_dtype: torch.dtype | None = mp_policy.param_dtype self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype + self.full_dtensor = full_dtensor def replicate_compute(self, x: DTensor) -> torch.Tensor: # data parallel runtime replicate parameters and do local compute @@ -210,6 +212,10 @@ def replicate_compute(self, x: DTensor) -> torch.Tensor: non_dp_mesh_dims = x._spec.mesh.ndim - self.device_mesh.ndim assert non_dp_mesh_dims <= 2, "Only DP + EP/TP/EP+TP is supported" if non_dp_mesh_dims > 0: + if self.full_dtensor: + raise NotImplementedError( + "full_dtensor not implemented for nD parallelisms" + ) dp_mesh = self.device_mesh # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather sharded_local_tensor = x.to_local() @@ -245,7 +251,10 @@ def replicate_compute(self, x: DTensor) -> torch.Tensor: placements=self.compute_placements, forward_dtype=self.param_dtype, backward_dtype=self.reduce_dtype, - ).to_local(grad_placements=self.grad_placements) + ) + + if not self.full_dtensor: + output = output.to_local(grad_placements=self.grad_placements) else: raise AssertionError( f"Unsupported replicate compute on placement {x._spec.placements} for DTensor {x}" @@ -274,6 +283,7 @@ def data_parallel( mp_policy: MixedPrecisionPolicy | None = None, shard_dim: int = 0, reduction_divide_factor: float | None = None, + full_dtensor: bool = False, ) -> nn.Module: param_sharding: tuple[Placement, ...] if mode == "replicate": @@ -333,6 +343,7 @@ def data_parallel( mode, mp_policy=mp_policy, reduction_divide_factor=reduction_divide_factor, + full_dtensor=full_dtensor, ), ) return model From 20fcfd715fdb189beed651c777236461376de253 Mon Sep 17 00:00:00 2001 From: Tushar Jain <8455015+tushar00jain@users.noreply.github.com> Date: Tue, 11 Nov 2025 10:45:31 -0500 Subject: [PATCH 006/127] set pg names (#1986) Summary: - we need to pass the global rank information to pytorch so that the pg name can include the pg information - this is necessary to differentiate the default pg's on different replicas - these need to different because flight recorder matches collectives based on pg name as well - add ft training to experiments folder, we'll move remaining pieces of ft to this gradually but make new features only available through this folder --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1986). * #1988 * #1987 * __->__ #1986 Co-authored-by: Tushar Jain --- torchtitan/distributed/utils.py | 6 +++- torchtitan/experiments/ft/train.py | 51 ++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 torchtitan/experiments/ft/train.py diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index f424276a3c..b209ddfd68 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -259,7 +259,10 @@ def maybe_enable_amp( def init_distributed( - comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = "" + comm_config: CommConfig, + enable_cpu_backend: bool = False, + base_folder: str = "", + ranks: list[int] | None = None, ): def _warn_overwrite_env(env, val): if env in os.environ: @@ -303,6 +306,7 @@ def _get_distributed_backend(enable_cpu_backend): torch.distributed.init_process_group( backend=_get_distributed_backend(enable_cpu_backend), timeout=timedelta(seconds=comm_config.init_timeout_seconds), + _ranks=ranks if ranks is not None else [], ) diff --git a/torchtitan/experiments/ft/train.py b/torchtitan/experiments/ft/train.py new file mode 100644 index 0000000000..891f6c5554 --- /dev/null +++ b/torchtitan/experiments/ft/train.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + +from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.train import main, Trainer + + +class FTTrainer(Trainer): + def init_distributed(self) -> ParallelDims: + job_config = self.job_config + + # determine the global ranks when fault tolerance is enabled + global_ranks = [] + ft_config = job_config.fault_tolerance + if ft_config.enable: + group_size = ft_config.group_size + replica_id = ft_config.replica_id + first_rank = replica_id * group_size + last_rank = first_rank + group_size - 1 + global_ranks = list(range(first_rank, last_rank + 1)) + + # init distributed and build meshes + dist_utils.init_distributed( + job_config.comm, + enable_cpu_backend=job_config.training.enable_cpu_offload, + base_folder=job_config.job.dump_folder, + ranks=global_ranks, + ) + + world_size = int(os.environ["WORLD_SIZE"]) + parallelism_config = job_config.parallelism + + return ParallelDims( + dp_shard=parallelism_config.data_parallel_shard_degree, + dp_replicate=parallelism_config.data_parallel_replicate_degree, + cp=parallelism_config.context_parallel_degree, + tp=parallelism_config.tensor_parallel_degree, + pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, + etp=parallelism_config.expert_tensor_parallel_degree, + world_size=world_size, + ) + + +if __name__ == "__main__": + main(FTTrainer) From 11d73a2b91981434e52c77aaafb378122f2ada3f Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 11 Nov 2025 09:15:49 -0800 Subject: [PATCH 007/127] Fix the error message of maybe_enable_async_tp() (#2011) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * #2012 * __->__ #2011 It is not correct as JobConfig has changed. --- torchtitan/distributed/tensor_parallel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtitan/distributed/tensor_parallel.py b/torchtitan/distributed/tensor_parallel.py index a2749f4c11..04e4e36c3a 100644 --- a/torchtitan/distributed/tensor_parallel.py +++ b/torchtitan/distributed/tensor_parallel.py @@ -17,7 +17,9 @@ def maybe_enable_async_tp(job_config: JobConfig, tp_mesh: DeviceMesh): return if not (job_config.compile.enable and "model" in job_config.compile.components): - raise RuntimeError("Async TP requires --training.compile") + raise RuntimeError( + "Async TP requires 'model' in --compile.components and --compile.enable" + ) from torch.distributed._symmetric_memory import enable_symm_mem_for_group From f5d2b18b3b51bc13862f6b69ea60b55d64bf3097 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 11 Nov 2025 09:44:28 -0800 Subject: [PATCH 008/127] Add dry run mode (#2012) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * __->__ #2012 * #2011 Summary: The current configuration validation requires torchx and GPUs. It can waste time, resources, ane engery. Polar bears are crying. Let's fix this by providing a dry run mode. This PR doesn't verify everything. In theory, we should be able to verify parallelisms settings as well. This PR is just a start but it at least can let us catch the typos quickly. --- run_train.sh | 19 ++++-- scripts/dry_run.py | 156 ++++++++++++++++++++++++++++++++++++++++++++ torchtitan/train.py | 4 +- 3 files changed, 172 insertions(+), 7 deletions(-) create mode 100644 scripts/dry_run.py diff --git a/run_train.sh b/run_train.sh index 8aaf55de28..83319816fe 100755 --- a/run_train.sh +++ b/run_train.sh @@ -10,15 +10,24 @@ set -ex # use envs as local overwrites for convenience # e.g. # LOG_RANK=0,1 NGPU=4 ./run_train.sh +# DRY_RUN=1 ./run_train.sh # for config validation without GPU NGPU=${NGPU:-"8"} export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"} TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"} +DRY_RUN=${DRY_RUN:-0} TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} -PYTORCH_ALLOC_CONF="expandable_segments:True" \ -TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ -torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ ---local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ --m ${TRAIN_FILE} --job.config_file ${CONFIG_FILE} "$@" +if [ "$DRY_RUN" = "1" ]; then + # Dry run mode: validate configuration without GPU/distributed setup + echo "Running in DRY RUN mode - configuration validation only" + python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@" +else + # Normal training with torchrun + PYTORCH_ALLOC_CONF="expandable_segments:True" \ + TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ + torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ + --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ + -m ${TRAIN_FILE} --job.config_file ${CONFIG_FILE} "$@" +fi diff --git a/scripts/dry_run.py b/scripts/dry_run.py new file mode 100644 index 0000000000..2552ca0d78 --- /dev/null +++ b/scripts/dry_run.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Dry run trainer for fast configuration validation without GPU/distributed setup. + +This module provides a lightweight trainer that validates job configurations, +model architecture, and dataloader setup without requiring GPU resources or +distributed initialization. Useful for rapid iteration on configuration files +and CI/CD validation pipelines. +""" + +import os +import sys + +# Add parent directory to path to import torchtitan +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import torch + +import torchtitan.protocols.train_spec as train_spec_module +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.tools import utils +from torchtitan.tools.logging import logger +from torchtitan.train import main, Trainer + + +class DryRunTrainer(Trainer): + """ + A lightweight trainer that validates configurations without GPU allocation. + + This trainer performs comprehensive validation of the training configuration + without allocating GPU resources or initializing distributed setup. It validates: + + - Configuration file parsing and structure + - Model architecture (constructed on meta device) + - Tokenizer initialization + - Dataloader configuration + - Parallelism settings + - Model converters (if specified) + + Unlike the regular Trainer, this does not: + - Allocate GPU memory + - Initialize distributed process groups + - Create optimizers or learning rate schedulers + - Set up checkpointing or metrics + - Run any actual training + + Args: + job_config: JobConfig containing all training configuration parameters + + Note: + Validation completes immediately after initialization. No training loop is executed. + All operations use CPU and meta devices for zero-cost validation. + """ + + def __init__(self, job_config: JobConfig): + torch._C._log_api_usage_once("torchtitan.dry_run") + + self.job_config = job_config + + logger.info(f"Starting job: {job_config.job.description}") + logger.info("DRY RUN MODE - Configuration validation only") + + # Use CPU device (no GPU required) + self.device = torch.device("cpu") + + # Log and validate config + job_config.maybe_log() + logger.info("Configuration parsed successfully") + + # Get train spec + self.train_spec = train_spec_module.get_train_spec(job_config.model.name) + logger.info(f"Train spec loaded for model: {job_config.model.name}") + + # Build tokenizer + self.tokenizer = ( + self.train_spec.build_tokenizer_fn(job_config) + if self.train_spec.build_tokenizer_fn is not None + else None + ) + if self.tokenizer: + logger.info("Tokenizer built successfully") + + # Validate model configuration + model_args = self.train_spec.model_args[job_config.model.flavor] + model_args.update_from_config(job_config) + self.model_args = model_args + + logger.info( + f"Model args validated: {job_config.model.name} {job_config.model.flavor}" + ) + + # Build model on meta device (validates architecture without memory allocation) + logger.info("Validating model architecture...") + with ( + torch.device("meta"), + utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), + ): + model = self.train_spec.model_cls(model_args) + + # Calculate and log model size + model_param_count, _ = model_args.get_nparams_and_flops( + model, job_config.training.seq_len + ) + logger.info( + f"Model architecture validated: {job_config.model.name} " + f"with {model_param_count:,} parameters" + ) + + # Validate dataloader configuration (build with minimal params) + logger.info("Validating dataloader configuration...") + try: + # Use dp_world_size=1 and dp_rank=0 for dry run + dataloader = self.train_spec.build_dataloader_fn( + dp_world_size=1, + dp_rank=0, + tokenizer=self.tokenizer, + job_config=job_config, + ) + logger.info("Dataloader configuration validated successfully") + except Exception as e: + logger.warning(f"Dataloader validation encountered issue: {e}") + logger.info( + "Note: Some dataloader issues may only appear with actual data paths" + ) + + # Validate model converters if specified + if job_config.model.converters: + logger.info(f"Model converters specified: {job_config.model.converters}") + + # Validate parallelism configuration + parallelism_config = job_config.parallelism + logger.info( + f"Parallelism config: " + f"DP-shard={parallelism_config.data_parallel_shard_degree}, " + f"DP-replicate={parallelism_config.data_parallel_replicate_degree}, " + f"TP={parallelism_config.tensor_parallel_degree}, " + f"PP={parallelism_config.pipeline_parallel_degree}, " + f"CP={parallelism_config.context_parallel_degree}" + ) + + # Summary + logger.info("=" * 80) + logger.info("DRY RUN VALIDATION COMPLETE") + logger.info("=" * 80) + logger.info("All configurations validated successfully!") + logger.info("Configuration is ready for training execution.") + logger.info("=" * 80) + + +if __name__ == "__main__": + main(DryRunTrainer) diff --git a/torchtitan/train.py b/torchtitan/train.py index 0070806e94..18a876c4bb 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -698,9 +698,9 @@ def load_state_dict(self, state_dict: dict[str, Any]): self.ntokens_seen = state_dict["ntokens_seen"] def close(self) -> None: - if self.checkpointer: + if hasattr(self, "checkpointer") and self.checkpointer: self.checkpointer.close() - if self.metrics_processor: + if hasattr(self, "metrics_processor") and self.metrics_processor: self.metrics_processor.close() From edbf3491d1b02bb36f4b5aca07283e0804802459 Mon Sep 17 00:00:00 2001 From: Yiming Zhou <61480007+yiming0416@users.noreply.github.com> Date: Tue, 11 Nov 2025 12:41:16 -0800 Subject: [PATCH 009/127] [easy] [compiler toolkit] Clean up unused function (#2014) As titled. `_clear_traced_params_buffers` is no longer being used as we have switched the dynamo graph capture API. --- .../experiments/compiler_toolkit/graph_utils.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index dbe0a8a257..aee089cad9 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -21,19 +21,6 @@ from torchtitan.tools.logging import logger -def _clear_traced_params_buffers( - traced_module: torch.fx.GraphModule, const_keys: list[str] -) -> None: - """Remove all parameters and buffers from traced module before restoring.""" - for key in const_keys: - assert key in traced_module._buffers.keys() - # We don't want constants to show up as a buffer in the state dict. - # Instead they should just be a direct attribute. - buffer = getattr(traced_module, key) - torch.fx.graph_module._del_attr(traced_module, key) - setattr(traced_module, key, buffer) - - def export_joint( model, args, kwargs=None ) -> tuple[JointWithDescriptors, TracingContext]: From 2f9b44da8eea448d7194eb180331c0734133961b Mon Sep 17 00:00:00 2001 From: akashveramd Date: Tue, 11 Nov 2025 17:10:43 -0800 Subject: [PATCH 010/127] Run Torchtitan ROCm workflow on cron schedule & push to Main branch only (#2016) Addressing following issues in this PR- - Running Torchtitan ROCm workflow on cron schedule & only when push to Main branch. CUDA workflow will run as is. - Refactor Torchtitan test run to address older PR comment https://github.com/pytorch/torchtitan/pull/1786#discussion_r2476279289 --- .github/workflows/integration_test_8gpu_features.yaml | 7 +++++-- tests/integration_tests/run_tests.py | 11 +++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index c6e8ed30d5..e97d22c3b7 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -26,6 +26,10 @@ permissions: jobs: build-test: + if: | + matrix.gpu-arch-type == 'cuda' || + (matrix.gpu-arch-type == 'rocm' && + (github.event_name == 'push' && github.ref == 'refs/heads/main' || github.event_name == 'schedule')) uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main strategy: fail-fast: false @@ -73,8 +77,7 @@ jobs: sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded" sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded" - export TEST_WITH_ROCM=$([[ "${{ matrix.gpu-arch-type }}" == "rocm" ]] && echo 1 || echo 0) - python -m tests.integration_tests.run_tests --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 + python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint rm -rf artifacts-to-be-uploaded/*/checkpoint diff --git a/tests/integration_tests/run_tests.py b/tests/integration_tests/run_tests.py index 011fa25554..b2cb8ea503 100644 --- a/tests/integration_tests/run_tests.py +++ b/tests/integration_tests/run_tests.py @@ -25,9 +25,6 @@ } -TEST_WITH_ROCM = os.getenv("TEST_WITH_ROCM", "0") == "1" - - def _run_cmd(cmd): return subprocess.run([cmd], text=True, shell=True) @@ -92,7 +89,7 @@ def run_tests(args, test_list: list[OverrideDefinitions]): continue # Skip the test for ROCm - if TEST_WITH_ROCM and test_flavor.skip_rocm_test: + if args.gpu_arch_type == "rocm" and test_flavor.skip_rocm_test: continue # Check if we have enough GPUs @@ -110,6 +107,12 @@ def main(): parser.add_argument( "output_dir", help="Directory to dump results generated by tests" ) + parser.add_argument( + "--gpu_arch_type", + default="cuda", + choices=["cuda", "rocm"], + help="GPU architecture type. Must be specified as either 'cuda' or 'rocm'.", + ) parser.add_argument( "--test_suite", default="features", From 55c63c14594107363b8e286c1742efe3efbeda7c Mon Sep 17 00:00:00 2001 From: akashveramd Date: Tue, 11 Nov 2025 18:56:57 -0800 Subject: [PATCH 011/127] Revert PR-2016 & Redo "Run Torchtitan ROCm workflow on cron schedule & push to Main branch only" (#2017) Reverts PR: https://github.com/pytorch/torchtitan/pull/2016 Addressing following issues in this PR- - Running Torchtitan ROCm workflow on cron schedule & only when push to Main branch. CUDA workflow will run as is. - Refactor Torchtitan test run to address older PR comment https://github.com/pytorch/torchtitan/pull/1786#discussion_r2476279289 Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com> --- .github/workflows/integration_test_8gpu_features.yaml | 7 ++----- tests/integration_tests/run_tests.py | 11 ++++------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index e97d22c3b7..c6e8ed30d5 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -26,10 +26,6 @@ permissions: jobs: build-test: - if: | - matrix.gpu-arch-type == 'cuda' || - (matrix.gpu-arch-type == 'rocm' && - (github.event_name == 'push' && github.ref == 'refs/heads/main' || github.event_name == 'schedule')) uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main strategy: fail-fast: false @@ -77,7 +73,8 @@ jobs: sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded" sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded" - python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 + export TEST_WITH_ROCM=$([[ "${{ matrix.gpu-arch-type }}" == "rocm" ]] && echo 1 || echo 0) + python -m tests.integration_tests.run_tests --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint rm -rf artifacts-to-be-uploaded/*/checkpoint diff --git a/tests/integration_tests/run_tests.py b/tests/integration_tests/run_tests.py index b2cb8ea503..011fa25554 100644 --- a/tests/integration_tests/run_tests.py +++ b/tests/integration_tests/run_tests.py @@ -25,6 +25,9 @@ } +TEST_WITH_ROCM = os.getenv("TEST_WITH_ROCM", "0") == "1" + + def _run_cmd(cmd): return subprocess.run([cmd], text=True, shell=True) @@ -89,7 +92,7 @@ def run_tests(args, test_list: list[OverrideDefinitions]): continue # Skip the test for ROCm - if args.gpu_arch_type == "rocm" and test_flavor.skip_rocm_test: + if TEST_WITH_ROCM and test_flavor.skip_rocm_test: continue # Check if we have enough GPUs @@ -107,12 +110,6 @@ def main(): parser.add_argument( "output_dir", help="Directory to dump results generated by tests" ) - parser.add_argument( - "--gpu_arch_type", - default="cuda", - choices=["cuda", "rocm"], - help="GPU architecture type. Must be specified as either 'cuda' or 'rocm'.", - ) parser.add_argument( "--test_suite", default="features", From cbfb8e1def4942294228a2b74d1d16c2cc8aa061 Mon Sep 17 00:00:00 2001 From: Yiming Zhou <61480007+yiming0416@users.noreply.github.com> Date: Wed, 12 Nov 2025 11:17:39 -0800 Subject: [PATCH 012/127] [compiler toolkit] Add tests and scripts for numerics check (#2015) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds the utils to automatically check the training numerics (losses, grad norms) of two runs to verify if they have bitwise equivalence. The added script triggers two runs with user defined configs. Then it loads metrics saved during training and compare the numerics to verify bitwise equivalence. Currently we check for losses and grad norms during training steps For example, we want to compare the numerics between compiler toolkit with aot_eager backend and eager on llama3-8B. ``` python torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py --ngpu 4 --config-file torchtitan/models/llama3/train_configs/llama3_8b.toml --dp-shard-degree 2 --tp-degree 2 ``` It'll run `simple_fsdp` experiment without `torch.compile` as the eager baseline, and `compile_toolkit` experiment as the compiled run. Then it compares the training numerics of these two runs to verify bitwise equivalence. When it is bitwise equivalent, we'll see the following output ``` Starting training: simple_fsdp.llama3 ✓ Training completed: simple_fsdp.llama3 Starting training: compiler_toolkit.llama3 ✓ Training completed: compiler_toolkit.llama3 ✓ PASS: All 11 steps match exactly (bitwise equivalent) ✓ PASS: All 11 steps match exactly (bitwise equivalent) ✓ SUCCESS: All metrics are bitwise equivalent ``` Also added unit-tests in `compiler_toolkit/tests/test_numerics.py` so that we can guard working parallelism combinations that already have bitwise equivalence in CI. --- .../scripts/check_numerics.py | 126 ++++++++ .../compiler_toolkit/tests/numerics_utils.py | 270 ++++++++++++++++++ .../compiler_toolkit/tests/test_numerics.py | 71 +++++ 3 files changed, 467 insertions(+) create mode 100644 torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py create mode 100644 torchtitan/experiments/compiler_toolkit/tests/numerics_utils.py create mode 100644 torchtitan/experiments/compiler_toolkit/tests/test_numerics.py diff --git a/torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py b/torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py new file mode 100644 index 0000000000..06c1717957 --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import sys +from pathlib import Path + +# Add parent directory to path to import numerics_utils +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from tests.numerics_utils import run_numerics_test + + +def main(): + parser = argparse.ArgumentParser( + description="Run two training jobs and compare their tensorboard metrics" + ) + parser.add_argument( + "--ngpu", + type=int, + required=True, + help="Number of GPUs to use", + ) + parser.add_argument( + "--config-file", + type=str, + required=True, + help="Path to config file", + ) + parser.add_argument( + "--dp-shard-degree", + type=int, + default=1, + help="Data parallel shard degree", + ) + parser.add_argument( + "--tp-degree", + type=int, + default=1, + help="Tensor parallel degree", + ) + parser.add_argument( + "--cp-degree", + type=int, + default=1, + help="Context parallel degree", + ) + parser.add_argument( + "--ep-degree", + type=int, + default=1, + help="Expert parallel degree", + ) + parser.add_argument( + "--ac-mode", + type=str, + default="selective", + choices=["selective", "none", "full"], + help="Activation checkpoint mode", + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="Number of training steps", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for deterministic training", + ) + parser.add_argument( + "--eager-tb-folder", + type=str, + default="tb/eager_run", + help="Tensorboard folder for eager run", + ) + parser.add_argument( + "--compiled-tb-folder", + type=str, + default="tb/compiled_run", + help="Tensorboard folder for compiled run", + ) + parser.add_argument( + "--metrics", + nargs="+", + default=["loss_metrics/global_avg_loss", "grad_norm"], + help="Metrics to compare", + ) + parser.add_argument( + "--passes", + type=str, + default=None, + help=( + "Comma-separated list of compiler passes to apply " + "(e.g., 'autobucketing_reordering' or 'autobucketing_reordering,regional_inductor')" + ), + ) + + args = parser.parse_args() + + success = run_numerics_test( + ngpu=args.ngpu, + config_file=args.config_file, + dp_shard_degree=args.dp_shard_degree, + tp_degree=args.tp_degree, + cp_degree=args.cp_degree, + ep_degree=args.ep_degree, + ac_mode=args.ac_mode, + steps=args.steps, + seed=args.seed, + eager_tb_folder=args.eager_tb_folder, + compiled_tb_folder=args.compiled_tb_folder, + metrics=args.metrics, + passes=args.passes, + ) + + return 0 if success else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/torchtitan/experiments/compiler_toolkit/tests/numerics_utils.py b/torchtitan/experiments/compiler_toolkit/tests/numerics_utils.py new file mode 100644 index 0000000000..0d7741b1a2 --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/tests/numerics_utils.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Shared utilities for numerics testing.""" + +import glob +import os +import subprocess + +import torch +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + + +def load_metrics(event_path, metric_names): + """Load metrics from tensorboard event files.""" + event_acc = EventAccumulator(event_path) + event_acc.Reload() + + metrics = {} + for metric_name in metric_names: + try: + scalars = event_acc.Scalars(metric_name) + metrics[metric_name] = {scalar.step: scalar.value for scalar in scalars} + except KeyError: + print(f"Warning: Metric {metric_name!r} not found in event file") + metrics[metric_name] = {} + + return metrics + + +def compare_metrics(metrics1, metrics2, label1="Eager", label2="Compiled"): + """Compare two sets of metrics and verify bitwise equivalence using torch.equal().""" + + all_metrics = set(metrics1.keys()) | set(metrics2.keys()) + all_match = True + + for metric_name in sorted(all_metrics): + + steps1 = set(metrics1[metric_name].keys()) + steps2 = set(metrics2[metric_name].keys()) + + if steps1 != steps2: + print(" ERROR: Step mismatch!") + print(f" {label1} steps: {sorted(steps1)}") + print(f" {label2} steps: {sorted(steps2)}") + all_match = False + continue + + # Convert values to tensors for each step and compare + values1 = [metrics1[metric_name][step] for step in sorted(steps1)] + values2 = [metrics2[metric_name][step] for step in sorted(steps2)] + + tensor1 = torch.tensor(values1) + tensor2 = torch.tensor(values2) + + if torch.equal(tensor1, tensor2): + print( + f" ✓ PASS: All {len(steps1)} steps match exactly (bitwise equivalent)" + ) + else: + # Find and report mismatches + mismatches = [] + for idx, step in enumerate(sorted(steps1)): + val1 = values1[idx] + val2 = values2[idx] + if val1 != val2: + mismatches.append((step, val1, val2, abs(val1 - val2))) + + print( + f" ERROR: Found {len(mismatches)} mismatches out of {len(steps1)} steps" + ) + + return all_match + + +def find_latest_event_dir(base_path): + """Find the latest timestamped directory in the base path.""" + if not os.path.exists(base_path): + raise ValueError(f"Path does not exist: {base_path}") + + subdirs = [d for d in glob.glob(os.path.join(base_path, "*")) if os.path.isdir(d)] + if not subdirs: + return base_path + + latest = max(subdirs, key=os.path.getmtime) + return latest + + +def run_training( + ngpu, + config_file, + model_name, + dp_shard_degree, + tp_degree, + cp_degree, + ep_degree, + ac_mode, + steps, + seed, + deterministic, + tb_folder, + passes=None, +): + """Run a training job with the specified configuration.""" + print(f"\nStarting training: {model_name}") + + env = os.environ.copy() + env["NGPU"] = str(ngpu) + env["CONFIG_FILE"] = config_file + + cmd = [ + "./run_train.sh", + "--model.name", + model_name, + "--parallelism.data_parallel_shard_degree", + str(dp_shard_degree), + "--parallelism.tensor_parallel_degree", + str(tp_degree), + ] + + if cp_degree > 1: + cmd.extend(["--parallelism.context_parallel_degree", str(cp_degree)]) + if ep_degree > 1: + cmd.extend(["--parallelism.expert_parallel_degree", str(ep_degree)]) + + cmd.extend( + [ + "--activation_checkpoint.mode", + ac_mode, + "--training.steps", + str(steps), + "--debug.seed", + str(seed), + "--debug.deterministic", + "--metrics.enable_tensorboard", + "--metrics.save_tb_folder", + tb_folder, + ] + ) + + if passes: + cmd.extend( + [ + "--job.custom_config_module", + "torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes", + passes, + ] + ) + + print(f"Environment: NGPU={env['NGPU']}, CONFIG_FILE={env['CONFIG_FILE']}") + print(f"Running command: {' '.join(cmd)}") + + try: + result = subprocess.run( + cmd, + env=env, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + print(f"✓ Training completed: {model_name}") + return True + except subprocess.CalledProcessError as e: + print(f"✗ Training failed: {model_name}") + print(f"Error output:\n{e.stdout}") + return False + + +def determine_model_names(config_file): + """Determine model names based on config file.""" + if "deepseek" in config_file: + model_name = "deepseek_v3" + elif "llama3" in config_file: + model_name = "llama3" + else: + raise ValueError( + f"Unable to determine model names from config file: {config_file}" + ) + + eager_model = f"simple_fsdp.{model_name}" + compiled_model = f"compiler_toolkit.{model_name}" + + return eager_model, compiled_model + + +def run_numerics_test( + ngpu, + config_file, + dp_shard_degree, + tp_degree, + cp_degree, + ep_degree, + ac_mode, + steps, + seed, + eager_tb_folder, + compiled_tb_folder, + metrics, + passes=None, +): + """ + Run numerics test by training both eager and compiled models and comparing metrics. + + Returns: + bool: True if all metrics match, False otherwise. + """ + # Determine model names + eager_model, compiled_model = determine_model_names(config_file) + + # Run eager training + eager_success = run_training( + ngpu=ngpu, + config_file=config_file, + model_name=eager_model, + dp_shard_degree=dp_shard_degree, + tp_degree=tp_degree, + cp_degree=cp_degree, + ep_degree=ep_degree, + ac_mode=ac_mode, + steps=steps, + seed=seed, + deterministic=True, + tb_folder=eager_tb_folder, + ) + + if not eager_success: + print("✗ Eager training failed") + return False + + # Run compiled training + compiled_success = run_training( + ngpu=ngpu, + config_file=config_file, + model_name=compiled_model, + dp_shard_degree=dp_shard_degree, + tp_degree=tp_degree, + cp_degree=cp_degree, + ep_degree=ep_degree, + ac_mode=ac_mode, + steps=steps, + seed=seed, + deterministic=True, + tb_folder=compiled_tb_folder, + passes=passes, + ) + + if not compiled_success: + print("✗ Compiled training failed") + return False + + # Compare metrics + eager_path = find_latest_event_dir(f"./outputs/{eager_tb_folder}") + compiled_path = find_latest_event_dir(f"./outputs/{compiled_tb_folder}") + + eager_metrics = load_metrics(eager_path, metrics) + compiled_metrics = load_metrics(compiled_path, metrics) + + all_match = compare_metrics(eager_metrics, compiled_metrics) + + if all_match: + print("✓ SUCCESS: All metrics are bitwise equivalent") + else: + print("✗ FAILURE: Metrics differ between runs") + + return all_match diff --git a/torchtitan/experiments/compiler_toolkit/tests/test_numerics.py b/torchtitan/experiments/compiler_toolkit/tests/test_numerics.py new file mode 100644 index 0000000000..3bf5650e55 --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/tests/test_numerics.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import unittest + +from .numerics_utils import run_numerics_test + + +class TestNumerics(unittest.TestCase): + """Test numerics equivalence between simple_fsdp and compiler_toolkit implementations.""" + + def test_llama3_fsdp_tp(self): + """Test Llama3 with FSDP + TP configuration.""" + result = run_numerics_test( + ngpu=4, + config_file="./torchtitan/models/llama3/train_configs/debug_model.toml", + dp_shard_degree=2, + tp_degree=2, + cp_degree=1, + ep_degree=1, + ac_mode="selective", + steps=10, + seed=42, + eager_tb_folder="tb/test_llama3_fsdp_tp_eager", + compiled_tb_folder="tb/test_llama3_fsdp_tp_compiled", + metrics=["loss_metrics/global_avg_loss", "grad_norm"], + ) + self.assertTrue(result, "Llama3 FSDP+TP numerics test failed") + + def test_llama3_fsdp_tp_autobucketing(self): + result = run_numerics_test( + ngpu=4, + config_file="./torchtitan/models/llama3/train_configs/debug_model.toml", + dp_shard_degree=2, + tp_degree=2, + cp_degree=1, + ep_degree=1, + ac_mode="selective", + steps=10, + seed=42, + eager_tb_folder="tb/test_llama3_fsdp_tp_eager", + compiled_tb_folder="tb/test_llama3_fsdp_tp_compiled", + metrics=["loss_metrics/global_avg_loss", "grad_norm"], + passes="autobucketing_reordering", + ) + + def test_deepseek_v3_fsdp_tp_ep(self): + """Test DeepSeek V3 with FSDP + TP + EP configuration.""" + result = run_numerics_test( + ngpu=4, + config_file="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml", + dp_shard_degree=2, + tp_degree=2, + cp_degree=1, + ep_degree=4, + ac_mode="none", + steps=10, + seed=42, + eager_tb_folder="tb/test_deepseek_v3_fsdp_tp_ep_eager", + compiled_tb_folder="tb/test_deepseek_v3_fsdp_tp_ep_compiled", + metrics=["loss_metrics/global_avg_loss", "grad_norm"], + ) + self.assertTrue(result, "DeepSeek V3 FSDP+TP+EP numerics test failed") + + +if __name__ == "__main__": + unittest.main() From 4b2b31ccd2fff8f167e58243a58cf57fb0847011 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 12 Nov 2025 23:46:22 -0800 Subject: [PATCH 013/127] Add .claude to .gitignore (#2026) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * #2029 * #2030 * #2028 * #2027 * __->__ #2026 As title --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 45a8f5752a..415631ff9c 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,6 @@ Sessionx.vim # env files .env + +# Vibe coding +.claude From ce1c0fc26b53525445e88eb006eb694f86e53a52 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 12 Nov 2025 23:46:50 -0800 Subject: [PATCH 014/127] Fix dry run mode (#2027) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * #2029 * #2030 * #2028 * __->__ #2027 * #2026 Dry run mode works but it doesn't exit gracefully for all cases. This PR fixes it ``` DRY_RUN=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.steps=10 --activation_checkpoint.mode="none" --debug.deterministic --debug.seed=42 ``` --- scripts/dry_run.py | 3 +++ torchtitan/train.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/dry_run.py b/scripts/dry_run.py index 2552ca0d78..fa8e1b4c17 100644 --- a/scripts/dry_run.py +++ b/scripts/dry_run.py @@ -151,6 +151,9 @@ def __init__(self, job_config: JobConfig): logger.info("Configuration is ready for training execution.") logger.info("=" * 80) + def train(self): + return + if __name__ == "__main__": main(DryRunTrainer) diff --git a/torchtitan/train.py b/torchtitan/train.py index 18a876c4bb..5cfab998b2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -735,7 +735,8 @@ def main(trainer_class: type[Trainer]) -> None: raise else: trainer.close() - torch.distributed.destroy_process_group() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() logger.info("Process group destroyed") From e7ee95ab441faa6987a0101180f46d8415ac3832 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 13 Nov 2025 00:39:50 -0800 Subject: [PATCH 015/127] [Compiler Toolkit] Make compiler toolkit work with checkpoint (#2030) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * #2029 * __->__ #2030 The current CompileModule will result in an "inner" prefix for everything. This PR fixes it by overloading the methods. Also merge https://github.com/pytorch/torchtitan/pull/2028 to this PR. Something wrong with ghstack. --- .../deepseek_v3/parallelize.py | 5 +- .../compiler_toolkit/graph_utils.py | 65 +++++++++++++++---- .../compiler_toolkit/llama3/parallelize.py | 5 +- 3 files changed, 59 insertions(+), 16 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py index bc6859af61..20ad17f301 100644 --- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py @@ -80,7 +80,9 @@ def parallelize_deepseekv3( compiler_passes = get_compiler_passes_from_config(job_config) # Create compilers with specified passes (defaults to no passes) - fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes) + fw_compiler, bw_compiler = make_compiler_with_passes( + compiler_passes, dump_folder=job_config.job.dump_folder + ) # Create custom joint_graph_builder with deepseekv3-specific compilers deepseekv3_joint_graph_builder = functools.partial( @@ -88,6 +90,7 @@ def parallelize_deepseekv3( fw_compiler=fw_compiler, bw_compiler=bw_compiler, joint_custom_pass=validate_flex_attention_annotation, + dump_folder=job_config.job.dump_folder, ) # TODO: CompiledModule should take sample input as well, so that we can diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index aee089cad9..62c19cb3de 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -5,7 +5,8 @@ # LICENSE file in the root directory of this source tree. import contextlib -from typing import Callable, List, Optional +from pathlib import Path +from typing import Any, Callable, List, Optional import torch from torch._dynamo.functional_export import dynamo_graph_capture_for_export @@ -21,8 +22,18 @@ from torchtitan.tools.logging import logger +def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None: + # TODO: make the dump rank configurable + if not dump_folder or torch.distributed.get_rank() != 0: + return + + output_path = Path(dump_folder) / "compiler" / f"{name}.txt" + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(gm.print_readable(print_output=False)) + + def export_joint( - model, args, kwargs=None + model, args, kwargs=None, dump_folder: str | None = None ) -> tuple[JointWithDescriptors, TracingContext]: if kwargs is None: kwargs = {} @@ -35,8 +46,10 @@ def export_joint( torch.fx.traceback.preserve_node_meta(), ): gm = dynamo_graph_capture_for_export(model)(*args, **kwargs) - logger.info("Dynamo gm:") - logger.info(gm.print_readable(print_output=False)) + logger.debug("Dynamo gm:") + logger.debug(gm.print_readable(print_output=False)) + _dump_gm(dump_folder, gm, "dynamo_gm") + tracing_context = gm.meta["tracing_context"] with tracing(tracing_context): @@ -68,6 +81,7 @@ def joint_graph_builder( fw_compiler: Optional[Callable] = None, bw_compiler: Optional[Callable] = None, joint_custom_pass: Optional[Callable] = None, + dump_folder: str | None = None, ): """ Build a joint forward-backward graph for the model with optional custom compilers. @@ -79,16 +93,17 @@ def joint_graph_builder( fw_compiler: Optional custom forward compiler function bw_compiler: Optional custom backward compiler function joint_custom_pass: Optional custom pass to run on the joint graph + dump_folder: Optional folder to dump the graph to """ assert isinstance(model_args, tuple) - for arg in model_args: - assert isinstance(arg, DTensor) + for idx, arg in enumerate(model_args): + assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}" # get joint graph ( joint_with_descriptors, tracing_context, - ) = export_joint(model, model_args, model_kwargs) + ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder) # Optional validation if joint_custom_pass is not None: @@ -153,6 +168,18 @@ def __delattr__(self, name: str) -> None: else: super().__delattr__(name) + def state_dict(self, *args, **kwargs) -> Any: + return self.inner.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs) -> Any: + return self.inner.load_state_dict(*args, **kwargs) + + def name_parameters(self, *args, **kwargs) -> Any: + return self.inner.named_parameters(*args, **kwargs) + + def parameters(self, *args, **kwargs) -> Any: + return self.inner.parameters(*args, **kwargs) + def forward(self, *args, **kwargs): assert "forward" not in self._overrides, "forward cannot be overridden" @@ -179,6 +206,7 @@ def compiler( gm: torch.fx.GraphModule, example_inputs, passes: List[Callable] = None, + dump_folder: str | None = None, ): """ Compile a graph module by applying a sequence of compiler passes. @@ -190,23 +218,28 @@ def compiler( passes: List of compiler pass functions to apply. Each function should take (gm, example_inputs) and return a transformed gm. If None, uses DEFAULT_COMPILER_PASSES. + dump_folder: Optional folder to dump the graph to """ if passes is None: passes = DEFAULT_COMPILER_PASSES - logger.info(f"{name} before compiler:") - logger.info(gm.print_readable(print_output=False)) + logger.debug(f"{name} before compiler:") + logger.debug(gm.print_readable(print_output=False)) + _dump_gm(dump_folder, gm, f"{name}_before_compiler") for pass_fn in passes: logger.info(f"Applying pass: {pass_fn.__name__}") gm = pass_fn(gm, example_inputs) - logger.info(f"{name} after compiler:") - logger.info(gm.print_readable(print_output=False)) + logger.debug(f"{name} after compiler:") + logger.debug(gm.print_readable(print_output=False)) + _dump_gm(dump_folder, gm, f"{name}_after_compiler") return gm -def make_compiler_with_passes(passes: List[Callable] = None): +def make_compiler_with_passes( + passes: List[Callable] = None, dump_folder: str | None = None +): """ Create forward and backward compilers with specified passes. @@ -218,10 +251,14 @@ def make_compiler_with_passes(passes: List[Callable] = None): """ def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("fwd_gm", gm, example_inputs, passes=passes) + return compiler( + "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + ) def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("bwd_gm", gm, example_inputs, passes=passes) + return compiler( + "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + ) return fw_compiler, bw_compiler diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index e3dca203e9..62def3ef00 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -67,7 +67,9 @@ def parallelize_llama( compiler_passes = get_compiler_passes_from_config(job_config) # Create compilers with specified passes (defaults to no passes) - fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes) + fw_compiler, bw_compiler = make_compiler_with_passes( + compiler_passes, dump_folder=job_config.job.dump_folder + ) # Create custom joint_graph_builder with llama-specific compilers and validation llama_joint_graph_builder = functools.partial( @@ -75,6 +77,7 @@ def parallelize_llama( fw_compiler=fw_compiler, bw_compiler=bw_compiler, joint_custom_pass=validate_flex_attention_annotation, + dump_folder=job_config.job.dump_folder, ) # TODO: CompiledModule should take sample input as well, so that we can From 23c993ca3f61da6916a4f480581504f00d6a4527 Mon Sep 17 00:00:00 2001 From: Riccardo Mereu Date: Thu, 13 Nov 2025 21:55:10 +0200 Subject: [PATCH 016/127] [Flux] Update integration test badge in README.md (#2019) Fixes the badge in the `README.md` file --- torchtitan/models/flux/README.md | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torchtitan/models/flux/README.md b/torchtitan/models/flux/README.md index 2498d1a346..aa83b845db 100644 --- a/torchtitan/models/flux/README.md +++ b/torchtitan/models/flux/README.md @@ -1,11 +1,5 @@ -
- # FLUX model in torchtitan -[![integration tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_flux.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_flux.yaml/badge.svg?branch=main) - -
- ## Overview This directory contains the implementation of the [FLUX](https://github.com/black-forest-labs/flux/tree/main) model in torchtitan. In torchtitan, we showcase the pre-training process of text-to-image part of the FLUX model. From 028a455c1e3c0ea5aeb34e04f6ae25401b900bfc Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Fri, 14 Nov 2025 10:51:59 -0800 Subject: [PATCH 017/127] Print device and stride when print module (#2045) Before: image After: image --- .../compiler_toolkit/graph_utils.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 62c19cb3de..cd758438b3 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -29,7 +29,9 @@ def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> No output_path = Path(dump_folder) / "compiler" / f"{name}.txt" output_path.parent.mkdir(parents=True, exist_ok=True) - output_path.write_text(gm.print_readable(print_output=False)) + output_path.write_text( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) def export_joint( @@ -47,7 +49,11 @@ def export_joint( ): gm = dynamo_graph_capture_for_export(model)(*args, **kwargs) logger.debug("Dynamo gm:") - logger.debug(gm.print_readable(print_output=False)) + logger.debug( + gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + ) _dump_gm(dump_folder, gm, "dynamo_gm") tracing_context = gm.meta["tracing_context"] @@ -224,7 +230,9 @@ def compiler( passes = DEFAULT_COMPILER_PASSES logger.debug(f"{name} before compiler:") - logger.debug(gm.print_readable(print_output=False)) + logger.debug( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) _dump_gm(dump_folder, gm, f"{name}_before_compiler") for pass_fn in passes: @@ -232,7 +240,9 @@ def compiler( gm = pass_fn(gm, example_inputs) logger.debug(f"{name} after compiler:") - logger.debug(gm.print_readable(print_output=False)) + logger.debug( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) _dump_gm(dump_folder, gm, f"{name}_after_compiler") return gm From d9bdfbb206953b244bce73b316e605cc24d1d9ae Mon Sep 17 00:00:00 2001 From: Ruisi Zhang Date: Sat, 15 Nov 2025 22:58:08 -0800 Subject: [PATCH 018/127] [SimpleFSDP] add manual bucketing pass (#1881) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds support for aten-level manual bucketing in SimpleFSDP+`aot_eager` backend. Dependent on PyTorch [PR](https://github.com/pytorch/pytorch/pull/165487) TODO List: - [ ] We should have better way of handling region info other than a list of str FQNs in current `manual_bucketed_modules`. It would be very easy to miss some of model modules. (cc. @xmfan @SherlockNoMad ) - [ ] Currently, the reordering happens under the hood and overlap with last/next compute. We should allow users to specify which module they want to reorder. - [ ] Loss difference on multi-node training - [ ] DSV3 manual bucketing I'll address the TODO items in follow up PRs. Let's start with this simple FSDP+TP+llama3 PR. 1. Performance (FSDP2 under eager mode, SimpleFSDP uses `aot_eager` backend) **Llama 3-8B** * Performance (All Batch_size = 1). (The slower TPS on Single Node is sort of as expected, since FSDP2 handles copy-in/out in two different streams, whereas SimpleFSDP handles copy-in/out in the same stream) |Node| Method | Parallelism | Memory | TPS | Trace| |---------|---------|-----------|----------|------|------| |1-Node (8H100)|SimpleFSDP | FSDP=8| 40.96GiB(43.12%) | 7,227| [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-16-10-48-48_rank0_trace.json)| |1-Node (8H100)|FSDP2-eager| FSDP=8| 47.82GiB(50.35%) | 7,380 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-16-10-54-14_rank0_trace.json)| |8-Node (64H100)|SimpleFSDP| FSDP=64 | 29.37GiB | 4,984| | |8-Node (64H100)|FSDP2| FSDP=64 | 31.41GiB |5,097 | | |1-Node (8H100)|SimpleFSDP| FSDP=4 TP=2 | 28.28GiB(29.77%) | 5,881 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-26-18-00-18_rank0_trace.json) | |1-Node (8H100)|FSDP2| FSDP=4 TP=2 | 35.33GiB(37.20%) | 5,898 | [LINK](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-26-15-35-47_rank0_trace.json) | |8-Node (64H100)|SimpleFSDP| FSDP=8 TP=8 | ||| |8-Node (64H100)|FSDP2| FSDP=8 TP=8 | ||| Example SimpleFSDP 1D overlapping trace: Screenshot 2025-10-16 at 10 49
55 AM Example SimpleFSDP 2D overlapping trace: Screenshot 2025-10-26 at 6 00 51 PM - Bitwise Loss: FSDP-only: Screenshot 2025-10-17 at 10 41
56 AM FSDP+TP: Screenshot 2025-10-26 at 9 03 58 PM --- torchtitan/experiments/simple_fsdp/README.md | 18 +-- torchtitan/experiments/simple_fsdp/backend.py | 135 ++++++++++++++---- .../simple_fsdp/deepseek_v3/parallelize.py | 36 ++++- .../experiments/simple_fsdp/job_config.py | 8 +- .../simple_fsdp/llama3/parallelize.py | 36 ++++- .../simple_fsdp/tests/integration_tests.py | 20 ++- 6 files changed, 204 insertions(+), 49 deletions(-) diff --git a/torchtitan/experiments/simple_fsdp/README.md b/torchtitan/experiments/simple_fsdp/README.md index ea4fb3272f..a1d40cf2b1 100644 --- a/torchtitan/experiments/simple_fsdp/README.md +++ b/torchtitan/experiments/simple_fsdp/README.md @@ -52,14 +52,16 @@ SimpleFSDP relies on compiler backend to perform optimizations (i.e., bucketing 1. no optimization: default torch.compile backends (e.g., "inductor", "aot_eager", "eager") 2. auto optimization: perform auto-bucketing & reordering without user inputs. **Note: it is not guaranteed that users will get the most optimized training performance** - - "aot_eager_autobucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend. - - -users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs: - -```bash ---compile.model_backend_override "aot_eager_autobucketing" -``` + - "auto_bucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend. (We also support `inductor` backend). + ```bash + --compile.backend "aot_eager" --compile.graph_passes "auto_bucketing" + ``` + +3. manual optimization: perform manual bucketing & reordering with user FQN inputs. + - "transformer_block_bucketing": perform bucketing by transformer blocks at aten fx-level, and perform code execution with aot_eager backend. (We also support `inductor` backend). + ```bash + --compile.backend "aot_eager" --compile.graph_passes "transformer_block_bucketing" + ``` ### Citation diff --git a/torchtitan/experiments/simple_fsdp/backend.py b/torchtitan/experiments/simple_fsdp/backend.py index d51e6668c1..7fc9d13bf4 100644 --- a/torchtitan/experiments/simple_fsdp/backend.py +++ b/torchtitan/experiments/simple_fsdp/backend.py @@ -8,49 +8,132 @@ import torch import torch._functorch.config as functorch_config +from torchtitan.tools.logging import logger + +from .job_config import Compile as CompileConfig from .reshard_after_forward import annotate_fsdp_all_gather -def get_compile_backend( - backend_name: str, fsdp_reshard_after_forward: bool +def get_compile_backend_with_passes( + compile_config: CompileConfig, + fsdp_reshard_after_forward: bool, + fsdp_manual_buckets: list[list[str] | str] | None, ) -> callable: - # return the compile backends used in SimpleFSDP training - # Step1: check if backend_name is inside available torch.compile backends - # Step2: check if the backend_name has been registered as a customized backend - available_torch_backend = torch._dynamo.list_backends(exclude_tags=()) - - if backend_name in available_torch_backend: - backend = torch._dynamo.lookup_backend(backend_name) - elif backend_name == "aot_eager_autobucketing": - # Perform auto optimization in aten fx-level and execute code in aot_eager backend - # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960 - from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend + """ + Apply compile backend and additional graph passes. + Args: + compile_config: compile configs to apply torch.compile. + fsdp_reshard_after_forward: whether to enable reshard_after_forward in SimpleFSDP, + which is implemented via a customized AC graph pass. + fsdp_manual_buckets: used in transformer_block_bucketing to define which modules should be bucketed. + Returns: + compile backend with applied graph passes. + """ + backend = torch._dynamo.lookup_backend(compile_config.backend) + # Apply bucketing and overlapping pass on fwd and bwd graph separately + if compile_config.graph_passes == "auto_bucketing": + # Perform auto optimization in aten fx-level and execute code in aot_eager/inductor backend + # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960 from torch._inductor.config import aten_distributed_optimizations as dist_opts from torch._inductor.fx_passes.overlap_scheduling import ( schedule_overlap_bucketing, ) dist_opts.collective_bucketing = True - dist_opts.insert_overlap_deps = False torch._inductor.config.allow_buffer_reuse = False - def aten_autobucketing_reordering_pass( - gm: torch.fx.GraphModule, example_inputs: Any - ) -> torch.fx.GraphModule: - schedule_overlap_bucketing(gm) - gm.recompile() - return gm - - backend = aot_autograd_backend( - fw_compiler=aten_autobucketing_reordering_pass, - bw_compiler=aten_autobucketing_reordering_pass, - keep_inference_input_mutations=True, + if compile_config.backend == "aot_eager": + from torch._dynamo.backends.common import ( + aot_autograd as aot_autograd_backend, + ) + + def aot_eager_autobucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs: Any + ) -> torch.fx.GraphModule: + schedule_overlap_bucketing(gm) + gm.recompile() + return gm + + dist_opts.insert_overlap_deps = False + backend = aot_autograd_backend( + fw_compiler=aot_eager_autobucketing_reordering_pass, + bw_compiler=aot_eager_autobucketing_reordering_pass, + keep_inference_input_mutations=True, + ) + elif compile_config.backend == "inductor": + + def inductor_autobucketing_reordering_pass( + gm: torch.fx.Graph, + ) -> torch.fx.GraphModule: + return schedule_overlap_bucketing(gm.owning_module) + + dist_opts.insert_overlap_deps = True + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.reorder_for_compute_comm_overlap = False + torch._inductor.config.post_grad_custom_post_pass = ( + inductor_autobucketing_reordering_pass + ) + else: + raise ValueError( + f"Unsupported backend {compile_config.backend} for auto_bucketing pass" + ) + logger.info("Auto bucketing pass is applied") + + elif compile_config.graph_passes == "transformer_block_bucketing": + # Perform manual optimization in aten fx-level and execute code in aot_eager/inductor backend + # The manualbucketing logic is here: https://github.com/pytorch/pytorch/pull/165487 + from functools import partial + + from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend + from torch._inductor.fx_passes.overlap_manual_scheduling import ( + manual_overlap_bucketing, ) + + torch._inductor.config.allow_buffer_reuse = False + manual_overlap_bucketing = partial( + manual_overlap_bucketing, + module_bucket_plans=fsdp_manual_buckets, + ) + + if compile_config.backend == "aot_eager": + + def aot_eager_transformer_block_bucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs: Any + ) -> torch.fx.GraphModule: + manual_overlap_bucketing(gm, insert_overlap_deps=False) + return gm + + backend = aot_autograd_backend( + fw_compiler=aot_eager_transformer_block_bucketing_reordering_pass, + bw_compiler=aot_eager_transformer_block_bucketing_reordering_pass, + keep_inference_input_mutations=True, + ) + elif compile_config.backend == "inductor": + + def inductor_transformer_block_bucketing_reordering_pass( + gm: torch.fx.Graph, + ) -> torch.fx.GraphModule: + return manual_overlap_bucketing( + gm.owning_module, insert_overlap_deps=True + ) + + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.reorder_for_compute_comm_overlap = False + torch._inductor.config.post_grad_custom_post_pass = ( + inductor_transformer_block_bucketing_reordering_pass + ) + else: + raise ValueError( + f"Unsupported backend {compile_config.backend} for transformer_block_bucketing pass" + ) + logger.info("Transformer block bucketing pass is applied") + else: - raise AssertionError(f"Unsupported customized backend: {backend_name}") + logger.info("No bucketing or overlapping pass is applied") + # Apply activation checkpointing on joint graph before partitioner def joint_ac_pass( gm: torch.fx.GraphModule, example_inputs: Any ) -> torch.fx.GraphModule: diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 2ae1c517f3..6d415004cc 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -19,10 +19,35 @@ ) from torchtitan.tools.logging import logger -from ..backend import get_compile_backend +from ..backend import get_compile_backend_with_passes from ..simple_fsdp import data_parallel, MixedPrecisionPolicy + +def get_transformer_block_buckets(model) -> list[list[str] | str]: + module_list = [ + model.tok_embeddings, + [model.norm, model.output], + ] + for layer_id, transformer_block in model.layers.items(): + # [TODO](ruisizhang123) add EP support for transformer block bucketing + module_list.append(transformer_block) + + def convert_modules_to_fqns(modules, module_to_fqn_mapping): + """Convert a (possibly nested) list of modules to FQN strings.""" + result = [] + for m in modules: + if isinstance(m, list): + result.append(convert_modules_to_fqns(m, module_to_fqn_mapping)) + else: + result.append(module_to_fqn_mapping.get(m, None)) + return result + + module_to_name = {m: n for n, m in model.named_modules()} + module_fqns = convert_modules_to_fqns(module_list, module_to_name) + return module_fqns + + # Adapted from llama4/infra/parallelize.py def parallelize_deepseekv3( model: nn.Module, @@ -177,13 +202,14 @@ def parallelize_deepseekv3( f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." ) - backend = ( - getattr(job_config.compile, "model_backend_override", None) - or job_config.compile.backend + backend = get_compile_backend_with_passes( + job_config.compile, + fsdp_reshard_after_forward, + get_transformer_block_buckets(model), ) model = torch.compile( model, - backend=get_compile_backend(backend, fsdp_reshard_after_forward), + backend=backend, fullgraph=True, ) diff --git a/torchtitan/experiments/simple_fsdp/job_config.py b/torchtitan/experiments/simple_fsdp/job_config.py index a7e7c4c22f..f752fa1170 100644 --- a/torchtitan/experiments/simple_fsdp/job_config.py +++ b/torchtitan/experiments/simple_fsdp/job_config.py @@ -5,12 +5,16 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass, field +from typing import Literal @dataclass class Compile: - model_backend_override: str | None = None - """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing""" + graph_passes: Literal["auto_bucketing", "transformer_block_bucketing"] | None = None + """ + Bucketing and overlapping passes in simplefsdp. Additional passes include: + auto_bucketing, transformer_block_bucketing + """ @dataclass diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index 1d8bfc500f..67a012a3f7 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -14,7 +14,7 @@ from torchtitan.models.llama3.infra.parallelize import apply_tp from torchtitan.tools.logging import logger -from ..backend import get_compile_backend +from ..backend import get_compile_backend_with_passes from ..simple_fsdp import data_parallel, MixedPrecisionPolicy @@ -33,6 +33,31 @@ } +def get_transformer_block_buckets(model) -> list[list[str] | str]: + module_list = [ + model.tok_embeddings, + [model.norm, model.output], + ] + for layer_id, transformer_block in model.layers.items(): + module_list.append(transformer_block) + + def convert_modules_to_fqns(modules, module_to_fqn_mapping): + """Convert a (possibly nested) list of modules to FQN strings.""" + result = [] + for m in modules: + if isinstance(m, list): + if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping): + result.append(fqn_list) + else: + if fqn := module_to_fqn_mapping.get(m): + result.append(fqn) + return result + + module_to_name = {m: n for n, m in model.named_modules()} + module_fqns = convert_modules_to_fqns(module_list, module_to_name) + return module_fqns + + def parallelize_llama( model: nn.Module, parallel_dims: ParallelDims, @@ -139,13 +164,14 @@ def parallelize_llama( f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." ) - backend = ( - getattr(job_config.compile, "model_backend_override", None) - or job_config.compile.backend + backend = get_compile_backend_with_passes( + job_config.compile, + fsdp_reshard_after_forward, + get_transformer_block_buckets(model), ) model = torch.compile( model, - backend=get_compile_backend(backend, fsdp_reshard_after_forward), + backend=backend, fullgraph=True, ) diff --git a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py index f18ee95528..c3cee7b52f 100755 --- a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py +++ b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py @@ -35,11 +35,25 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: "--model.name simple_fsdp.llama3", "--compile.enable", "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config", - "--compile.model_backend_override aot_eager_autobucketing", + "--compile.backend aot_eager", + "--compile.graph_passes auto_bucketing", ], ], - "1D+aot_eager_autobucketing", - "1d_aot_eager_autobucketing", + "1D+autobucketing", + "1d_autobucketing", + ), + OverrideDefinitions( + [ + [ + "--model.name simple_fsdp.llama3", + "--compile.enable", + "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config", + "--compile.backend aot_eager", + "--compile.graph_passes transformer_block_bucketing", + ], + ], + "1D+transformer_block_bucketing", + "1d_transformer_block_bucketing", ), OverrideDefinitions( [ From 22e959a38b588b39db98acd418da28b912c57bc2 Mon Sep 17 00:00:00 2001 From: Ido Hakimi <5303103+idoh@users.noreply.github.com> Date: Mon, 17 Nov 2025 02:31:37 +0200 Subject: [PATCH 019/127] Add export_dtype parameter to `convert_to_hf` function (#2041) The current `convert_to_hf.py` does not support `export_dtype`, which makes it `float32` by default. This PR adds support for export dtypes of `["float16", "bfloat16", "float32"]`. --- .../checkpoint_conversion/convert_to_hf.py | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/scripts/checkpoint_conversion/convert_to_hf.py b/scripts/checkpoint_conversion/convert_to_hf.py index f0ea17cc63..ad13850b82 100644 --- a/scripts/checkpoint_conversion/convert_to_hf.py +++ b/scripts/checkpoint_conversion/convert_to_hf.py @@ -12,12 +12,18 @@ import torchtitan.protocols.train_spec as train_spec_module from torch.distributed.checkpoint import HuggingFaceStorageWriter from torchtitan.components.checkpoint import ModelWrapper +from torchtitan.config import TORCH_DTYPE_MAP @torch.inference_mode() -def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_path): - if model_name == "flux": - import torchtitan.experiments.flux # noqa: F401 +def convert_to_hf( + input_dir, + output_dir, + model_name, + model_flavor, + hf_assets_path, + export_dtype, +): # load model and model args so that we can get the state dict shape train_spec = train_spec_module.get_train_spec(model_name) model_args = train_spec.model_args[model_flavor] @@ -49,6 +55,11 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat thread_count_consolidation=5, ) + # map and apply export dtype if needed + target_dtype = TORCH_DTYPE_MAP[export_dtype] + if target_dtype != torch.float32: + hf_state_dict = {k: v.to(target_dtype) for k, v in hf_state_dict.items()} + dcp.save( hf_state_dict, storage_writer=storage_writer, @@ -71,6 +82,14 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat ) parser.add_argument("--model_name", type=str, nargs="?", default="llama3") parser.add_argument("--model_flavor", type=str, nargs="?", default="8B") + parser.add_argument( + "--export_dtype", + type=str, + nargs="?", + choices=["float16", "bfloat16", "float32"], + default="float32", + help="Export dtype for HF checkpoint (default: float32)", + ) args = parser.parse_args() convert_to_hf( @@ -79,4 +98,5 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat args.model_name, args.model_flavor, args.hf_assets_path, + args.export_dtype, ) From 3819737fab042fdfd5443b1d99753b951b59696d Mon Sep 17 00:00:00 2001 From: Yiming Zhou <61480007+yiming0416@users.noreply.github.com> Date: Mon, 17 Nov 2025 21:34:40 -0800 Subject: [PATCH 020/127] [compiler toolkit] Port joint_ac_pass from simplefsdp (#2051) This PR integrates the changes in #1970 to compiler toolkit (applying `joint_ac_pass` on the joint graph graph to tag nodes based on `reshard_after_forward` flag) Also did some refactor for applying graph passes in compiler toolkit experiments. We will have two kinds of passes 1. joint_custom_passes: these are passes to be applied on the captured joint graph before partitioner. By default we `validate_flex_attn_annotation_pass` and `fsdp_reshard_after_fwd_pass` 2. compiler_passes: there are passes to be applied on partitioned fwd and bwd graphs as backend optimizations. By default there is none. We can indicate `autobucketing_reordering_pass` and `regional_inductor_pass` using configs. --- .../compiler_toolkit/common_utils.py | 10 --- .../deepseek_v3/parallelize.py | 7 +- .../compiler_toolkit/graph_utils.py | 64 ++++++++++++++++--- .../compiler_toolkit/llama3/parallelize.py | 9 ++- .../experiments/compiler_toolkit/passes.py | 32 +++++++++- 5 files changed, 98 insertions(+), 24 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py index b7499b2f79..965e027bdb 100644 --- a/torchtitan/experiments/compiler_toolkit/common_utils.py +++ b/torchtitan/experiments/compiler_toolkit/common_utils.py @@ -53,13 +53,3 @@ def register_blockmask_pytree_node(): flatten_with_keys_fn=BlockMask._flatten_with_keys, serialized_type_name="torch.nn.attention.flex_attention.BlockMask", ) - - -def validate_flex_attention_annotation(joint_with_descriptors): - """Verify user annotations show up in the graph.""" - for node in joint_with_descriptors.graph_module.graph.nodes: - if node.target in { - torch.ops.higher_order.flex_attention, - torch.ops.higher_order.flex_attention_backward, - }: - assert "compile_with_inductor" in node.meta.get("custom", {}) diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py index 20ad17f301..982843bb24 100644 --- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py @@ -17,12 +17,12 @@ disable_compile, parallelize_inputs, register_blockmask_pytree_node, - validate_flex_attention_annotation, ) from torchtitan.experiments.compiler_toolkit.graph_utils import ( CompiledModule, get_compiler_passes_from_config, + get_joint_custom_passes_from_config, joint_graph_builder, make_compiler_with_passes, ) @@ -76,6 +76,9 @@ def parallelize_deepseekv3( with disable_compile(job_config): model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config) + # Get joint custom passes from config + joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config) + # Get compiler passes from config compiler_passes = get_compiler_passes_from_config(job_config) @@ -89,7 +92,7 @@ def parallelize_deepseekv3( joint_graph_builder, fw_compiler=fw_compiler, bw_compiler=bw_compiler, - joint_custom_pass=validate_flex_attention_annotation, + joint_custom_passes=joint_custom_passes, dump_folder=job_config.job.dump_folder, ) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index cd758438b3..fa93b02b63 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import contextlib +import functools from pathlib import Path from typing import Any, Callable, List, Optional @@ -86,7 +87,7 @@ def joint_graph_builder( model_kwargs: dict, fw_compiler: Optional[Callable] = None, bw_compiler: Optional[Callable] = None, - joint_custom_pass: Optional[Callable] = None, + joint_custom_passes: Optional[List[Callable]] = None, dump_folder: str | None = None, ): """ @@ -98,7 +99,7 @@ def joint_graph_builder( model_kwargs: Dict of model input keyword arguments fw_compiler: Optional custom forward compiler function bw_compiler: Optional custom backward compiler function - joint_custom_pass: Optional custom pass to run on the joint graph + joint_custom_passes: list of custom passes to run on the joint graph dump_folder: Optional folder to dump the graph to """ assert isinstance(model_args, tuple) @@ -112,8 +113,11 @@ def joint_graph_builder( ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder) # Optional validation - if joint_custom_pass is not None: - joint_custom_pass(joint_with_descriptors) + if joint_custom_passes is not None: + for joint_custom_pass in joint_custom_passes: + joint_with_descriptors.graph_module = joint_custom_pass( + joint_with_descriptors.graph_module + ) with tracing(tracing_context): fn = aot_compile_joint_with_descriptors( @@ -283,20 +287,64 @@ def get_compiler_passes_from_config(job_config: JobConfig): Returns: List of compiler pass functions """ - from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_PASSES + from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_COMPILER_PASSES pass_names = getattr(job_config.compile, "passes", []) compiler_passes = [] for pass_name in pass_names: - if pass_name not in AVAILABLE_PASSES: + if pass_name not in AVAILABLE_COMPILER_PASSES: raise ValueError( f"Unknown compiler pass: {pass_name}. " - f"Available passes: {list(AVAILABLE_PASSES.keys())}" + f"Available compiler passes: {list(AVAILABLE_COMPILER_PASSES.keys())}" ) - compiler_passes.append(AVAILABLE_PASSES[pass_name]) + compiler_passes.append(AVAILABLE_COMPILER_PASSES[pass_name]) if pass_names: logger.info(f"Using compiler passes from config: {pass_names}") return compiler_passes + + +def get_joint_custom_passes_from_config( + parallel_dims: ParallelDims, job_config: JobConfig +): + """ + Extract and validate joint custom passes from job config. + + Args: + job_config: Job configuration containing parallelism.fsdp_reshard_after_forward + + Returns: + List of joint custom pass functions + """ + from torchtitan.experiments.compiler_toolkit.passes import ( + fsdp_reshard_after_fwd_pass, + validate_flex_attn_annotation_pass, + ) + + joint_custom_passes = [] + joint_custom_passes.append(validate_flex_attn_annotation_pass) + + match job_config.parallelism.fsdp_reshard_after_forward: + case "always": + fsdp_reshard_after_forward = True + case "never": + fsdp_reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + fsdp_reshard_after_forward = not parallel_dims.pp_enabled + case _: + raise ValueError( + f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." + ) + + joint_custom_passes.append( + functools.partial( + fsdp_reshard_after_fwd_pass, + reshard_after_forward=fsdp_reshard_after_forward, + ) + ) + + return joint_custom_passes diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index 62def3ef00..e746c24228 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -16,12 +16,12 @@ disable_compile, parallelize_inputs, register_blockmask_pytree_node, - validate_flex_attention_annotation, ) from torchtitan.experiments.compiler_toolkit.graph_utils import ( CompiledModule, get_compiler_passes_from_config, + get_joint_custom_passes_from_config, joint_graph_builder, make_compiler_with_passes, ) @@ -63,6 +63,9 @@ def parallelize_llama( with disable_compile(job_config): model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config) + # Get joint custom passes from config + joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config) + # Get compiler passes from config compiler_passes = get_compiler_passes_from_config(job_config) @@ -71,12 +74,12 @@ def parallelize_llama( compiler_passes, dump_folder=job_config.job.dump_folder ) - # Create custom joint_graph_builder with llama-specific compilers and validation + # Create custom joint_graph_builder with llama-specific compilers llama_joint_graph_builder = functools.partial( joint_graph_builder, fw_compiler=fw_compiler, bw_compiler=bw_compiler, - joint_custom_pass=validate_flex_attention_annotation, + joint_custom_passes=joint_custom_passes, dump_folder=job_config.job.dump_folder, ) diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 1c00fd5c1b..c0cec614a9 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -14,6 +14,9 @@ import torch from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing from torch.fx.passes.regional_inductor import regional_inductor +from torchtitan.experiments.simple_fsdp.reshard_after_forward import ( + annotate_fsdp_all_gather, +) def autobucketing_reordering_pass( @@ -39,8 +42,35 @@ def regional_inductor_pass( return regional_inductor(gm, example_inputs) +def validate_flex_attn_annotation_pass( + gm: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """Verify user annotations show up in the graph.""" + for node in gm.graph.nodes: + if node.target in { + torch.ops.higher_order.flex_attention, + torch.ops.higher_order.flex_attention_backward, + }: + assert "compile_with_inductor" in node.meta.get("custom", {}) + return gm + + +# Apply activation checkpointing on joint graph before partitioner +def fsdp_reshard_after_fwd_pass( + gm: torch.fx.GraphModule, reshard_after_forward: bool +) -> torch.fx.GraphModule: + # this pass implements simplefsdp's fsdp_reshard_after_forward behavior + # when fsdp_reshard_after_forward set to True, it will annotate simple_fsdp AG + # to CheckpointPolicy.MUST_RECOMPUTE. + # when fsdp_reshard_after_forward set to False, it will annotate simple_fsdp AG + # to CheckpointPolicy.MUST_SAVE. + gm = annotate_fsdp_all_gather(gm, reshard_after_forward) + gm.recompile() + return gm + + # Registry mapping pass names to pass functions -AVAILABLE_PASSES = { +AVAILABLE_COMPILER_PASSES = { "autobucketing_reordering": autobucketing_reordering_pass, "regional_inductor": regional_inductor_pass, } From bfdc974f94e033999362e2bf0ee0842457d43470 Mon Sep 17 00:00:00 2001 From: Yiming Zhou <61480007+yiming0416@users.noreply.github.com> Date: Tue, 18 Nov 2025 14:44:12 -0800 Subject: [PATCH 021/127] [compiler toolkit] Port manual bucketing from SimpleFSDP experiment (#2056) This PR integrates the manual bucketing pass (transformer block bucketing) added in SimpleFSDP experiment (#1881) to compiler toolkit So now compiler toolkit can also run manual bucketing pass by specifying the config ``` NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing ``` Also updated README and integration test to include the newly ported pass --- .../experiments/compiler_toolkit/README.md | 11 +++++++ .../deepseek_v3/parallelize.py | 2 +- .../compiler_toolkit/graph_utils.py | 31 ++++++++++++++--- .../compiler_toolkit/llama3/parallelize.py | 2 +- .../experiments/compiler_toolkit/passes.py | 17 +++++++++- .../tests/integration_tests.py | 33 ++++++++++++++++--- .../compiler_toolkit/tests/test_numerics.py | 23 +++++++++++-- 7 files changed, 106 insertions(+), 13 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index 61207fc63b..c223d1e658 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ b/torchtitan/experiments/compiler_toolkit/README.md @@ -34,6 +34,11 @@ NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./r NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering ``` +**SimpleFSDP + TP + transformer-block-bucketing** +```shell +NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing +``` + **SimpleFSDP + TP + FlexAttention** ```shell NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn @@ -44,3 +49,9 @@ NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./r ```shell NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor ``` + +**SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor** + +```shell +NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor +``` diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py index 982843bb24..011bbe402a 100644 --- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py @@ -80,7 +80,7 @@ def parallelize_deepseekv3( joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config) # Get compiler passes from config - compiler_passes = get_compiler_passes_from_config(job_config) + compiler_passes = get_compiler_passes_from_config(model, job_config) # Create compilers with specified passes (defaults to no passes) fw_compiler, bw_compiler = make_compiler_with_passes( diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index fa93b02b63..51ac8ba983 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -112,7 +112,7 @@ def joint_graph_builder( tracing_context, ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder) - # Optional validation + # run custom passes on joint-graph before partitioner if joint_custom_passes is not None: for joint_custom_pass in joint_custom_passes: joint_with_descriptors.graph_module = joint_custom_pass( @@ -240,7 +240,12 @@ def compiler( _dump_gm(dump_folder, gm, f"{name}_before_compiler") for pass_fn in passes: - logger.info(f"Applying pass: {pass_fn.__name__}") + pass_name = ( + pass_fn.func.__name__ + if isinstance(pass_fn, functools.partial) + else pass_fn.__name__ + ) + logger.info(f"Applying pass: {pass_name}") gm = pass_fn(gm, example_inputs) logger.debug(f"{name} after compiler:") @@ -277,7 +282,7 @@ def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return fw_compiler, bw_compiler -def get_compiler_passes_from_config(job_config: JobConfig): +def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig): """ Extract and validate compiler passes from job config. @@ -288,8 +293,18 @@ def get_compiler_passes_from_config(job_config: JobConfig): List of compiler pass functions """ from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_COMPILER_PASSES + from torchtitan.experiments.simple_fsdp.llama3.parallelize import ( + get_transformer_block_buckets, + ) pass_names = getattr(job_config.compile, "passes", []) + if ( + "autobucketing_reordering" in pass_names + and "transformer_block_bucketing" in pass_names + ): + raise ValueError( + "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!" + ) compiler_passes = [] for pass_name in pass_names: @@ -298,7 +313,15 @@ def get_compiler_passes_from_config(job_config: JobConfig): f"Unknown compiler pass: {pass_name}. " f"Available compiler passes: {list(AVAILABLE_COMPILER_PASSES.keys())}" ) - compiler_passes.append(AVAILABLE_COMPILER_PASSES[pass_name]) + if pass_name == "transformer_block_bucketing": + compiler_passes.append( + functools.partial( + AVAILABLE_COMPILER_PASSES[pass_name], + fsdp_manual_buckets=get_transformer_block_buckets(model), + ) + ) + else: + compiler_passes.append(AVAILABLE_COMPILER_PASSES[pass_name]) if pass_names: logger.info(f"Using compiler passes from config: {pass_names}") diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index e746c24228..68fa7443f4 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -67,7 +67,7 @@ def parallelize_llama( joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config) # Get compiler passes from config - compiler_passes = get_compiler_passes_from_config(job_config) + compiler_passes = get_compiler_passes_from_config(model, job_config) # Create compilers with specified passes (defaults to no passes) fw_compiler, bw_compiler = make_compiler_with_passes( diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index c0cec614a9..64276a91bc 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -12,6 +12,7 @@ """ import torch +from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing from torch.fx.passes.regional_inductor import regional_inductor from torchtitan.experiments.simple_fsdp.reshard_after_forward import ( @@ -26,13 +27,26 @@ def autobucketing_reordering_pass( Apply autobucketing and reordering optimization. This pass applies schedule_overlap_bucketing with collective_bucketing enabled - to optimize communication patterns in distributed training. + to optimize comm/compute overlap patterns in the graph. """ schedule_overlap_bucketing(gm, collective_bucketing=True) gm.recompile() return gm +def transformer_block_bucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs, fsdp_manual_buckets +) -> torch.fx.GraphModule: + """ + Apply aten-level manual bucketing and reordering optimization. + """ + manual_overlap_bucketing( + gm, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=False + ) + gm.recompile() + return gm + + def regional_inductor_pass( gm: torch.fx.GraphModule, example_inputs ) -> torch.fx.GraphModule: @@ -72,5 +86,6 @@ def fsdp_reshard_after_fwd_pass( # Registry mapping pass names to pass functions AVAILABLE_COMPILER_PASSES = { "autobucketing_reordering": autobucketing_reordering_pass, + "transformer_block_bucketing": transformer_block_bucketing_reordering_pass, "regional_inductor": regional_inductor_pass, } diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py index e33149fe2f..b0155a9f2a 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py +++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py @@ -24,7 +24,6 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "--model.name compiler_toolkit.llama3", "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", - "--activation_checkpoint.mode none", ], ], "llama3 FSDP+TP", @@ -37,7 +36,6 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "--model.name compiler_toolkit.llama3", "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", - "--activation_checkpoint.mode none", "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", "--compile.passes autobucketing_reordering", ], @@ -46,6 +44,20 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "llama3_fsdp_tp_autobucketing", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes transformer_block_bucketing", + ], + ], + "llama3 FSDP+TP manualbucketing", + "llama3_fsdp_tp_manualbucketing", + ngpu=4, + ), OverrideDefinitions( [ [ @@ -53,7 +65,6 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", "--model.flavor debugmodel_flex_attn", - "--activation_checkpoint.mode none", ], ], "llama3 FSDP+TP+FlexAttn", @@ -67,7 +78,6 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", "--model.flavor debugmodel_flex_attn", - "--activation_checkpoint.mode none", "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", "--compile.passes autobucketing_reordering,regional_inductor", ], @@ -76,6 +86,21 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--model.flavor debugmodel_flex_attn", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes transformer_block_bucketing,regional_inductor", + ], + ], + "llama3 FSDP+TP+FlexAttn manualbucketing regional_inductor", + "llama3_fsdp_tp_flexattn_manualbucketing_regional_inductor", + ngpu=4, + ), # deepseek_v3 tests OverrideDefinitions( [ diff --git a/torchtitan/experiments/compiler_toolkit/tests/test_numerics.py b/torchtitan/experiments/compiler_toolkit/tests/test_numerics.py index 3bf5650e55..1421ca3bca 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/test_numerics.py +++ b/torchtitan/experiments/compiler_toolkit/tests/test_numerics.py @@ -42,11 +42,30 @@ def test_llama3_fsdp_tp_autobucketing(self): ac_mode="selective", steps=10, seed=42, - eager_tb_folder="tb/test_llama3_fsdp_tp_eager", - compiled_tb_folder="tb/test_llama3_fsdp_tp_compiled", + eager_tb_folder="tb/test_llama3_fsdp_tp_autobucketing_eager", + compiled_tb_folder="tb/test_llama3_fsdp_tp_autobucketing_compiled", metrics=["loss_metrics/global_avg_loss", "grad_norm"], passes="autobucketing_reordering", ) + self.assertTrue(result, "Llama3 FSDP+TP+autobucketing numerics test failed") + + def test_llama3_fsdp_tp_manualbucketing(self): + result = run_numerics_test( + ngpu=4, + config_file="./torchtitan/models/llama3/train_configs/debug_model.toml", + dp_shard_degree=2, + tp_degree=2, + cp_degree=1, + ep_degree=1, + ac_mode="selective", + steps=10, + seed=42, + eager_tb_folder="tb/test_llama3_fsdp_tp_manualbucketing_eager", + compiled_tb_folder="tb/test_llama3_fsdp_tp_manualbucketing_compiled", + metrics=["loss_metrics/global_avg_loss", "grad_norm"], + passes="transformer_block_bucketing", + ) + self.assertTrue(result, "Llama3 FSDP+TP+manualbucketing numerics test failed") def test_deepseek_v3_fsdp_tp_ep(self): """Test DeepSeek V3 with FSDP + TP + EP configuration.""" From 4a5fa9950082505578b32cbc3886b88a37cb1d9e Mon Sep 17 00:00:00 2001 From: akashveramd Date: Wed, 19 Nov 2025 01:10:07 -0800 Subject: [PATCH 022/127] Re:Run Torchtitan ROCm workflow on cron schedule & push to Main branch only (#2018) Addressing following issues in this PR- Running Torchtitan ROCm workflow on cron schedule & only when push to Main branch. CUDA workflow will run as is. Refactor Torchtitan test run to address older PR comment https://github.com/pytorch/torchtitan/pull/1786#discussion_r2476279289 --- .../integration_test_8gpu_features.yaml | 52 ++++++++++++------- tests/integration_tests/run_tests.py | 11 ++-- 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index c6e8ed30d5..14e185b5e1 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -25,26 +25,43 @@ permissions: contents: read jobs: + # Step 1: Dynamically compute the matrix based on conditions + set-matrix: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set.outputs.matrix }} + steps: + - id: set + run: | + # Decide which matrix entries to include based on event type + if [[ "${{ github.event_name }}" == "push" && "${{ github.ref }}" == "refs/heads/main" ]] || [[ "${{ github.event_name }}" == "schedule" ]]; then + # Include both CUDA and ROCm + echo '{"include":[ + {"name":"cuda","runner":"linux.g5.48xlarge.nvidia.gpu","gpu-arch-type":"cuda","gpu-arch-version":"12.6","docker-image":"torchtitan-ubuntu-20.04-clang12","index-url":"https://download.pytorch.org/whl/nightly/cu126"}, + {"name":"rocm","runner":"linux.rocm.gpu.gfx942.8","gpu-arch-type":"rocm","gpu-arch-version":"7.0","docker-image":"torchtitan-rocm-ubuntu-22.04-clang12","index-url":"https://download.pytorch.org/whl/nightly/rocm7.0"} + ]}' > matrix.json + else + # Include only CUDA + echo '{"include":[ + {"name":"cuda","runner":"linux.g5.48xlarge.nvidia.gpu","gpu-arch-type":"cuda","gpu-arch-version":"12.6","docker-image":"torchtitan-ubuntu-20.04-clang12","index-url":"https://download.pytorch.org/whl/nightly/cu126"} + ]}' > matrix.json + fi + + # Export matrix to job outputs + { + echo 'matrix<> $GITHUB_OUTPUT + + + # Step 2: Use the dynamic matrix in the build-test job build-test: + needs: set-matrix uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main strategy: fail-fast: false - matrix: - include: - - name: cuda - runner: linux.g5.48xlarge.nvidia.gpu - gpu-arch-type: cuda - gpu-arch-version: "12.6" - # This image is faster to clone than the default, but it lacks CC needed by triton - # (1m25s vs 2m37s). - docker-image: torchtitan-ubuntu-20.04-clang12 - index-url: https://download.pytorch.org/whl/nightly/cu126 - - name: rocm - runner: linux.rocm.gpu.gfx942.8 - gpu-arch-type: rocm - gpu-arch-version: "7.0" - docker-image: torchtitan-rocm-ubuntu-22.04-clang12 - index-url: https://download.pytorch.org/whl/nightly/rocm7.0 + matrix: ${{ fromJSON(needs.set-matrix.outputs.matrix) }} with: runner: ${{ matrix.runner }} gpu-arch-type: ${{ matrix.gpu-arch-type }} @@ -73,8 +90,7 @@ jobs: sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded" sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded" - export TEST_WITH_ROCM=$([[ "${{ matrix.gpu-arch-type }}" == "rocm" ]] && echo 1 || echo 0) - python -m tests.integration_tests.run_tests --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 + python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint rm -rf artifacts-to-be-uploaded/*/checkpoint diff --git a/tests/integration_tests/run_tests.py b/tests/integration_tests/run_tests.py index 011fa25554..b2cb8ea503 100644 --- a/tests/integration_tests/run_tests.py +++ b/tests/integration_tests/run_tests.py @@ -25,9 +25,6 @@ } -TEST_WITH_ROCM = os.getenv("TEST_WITH_ROCM", "0") == "1" - - def _run_cmd(cmd): return subprocess.run([cmd], text=True, shell=True) @@ -92,7 +89,7 @@ def run_tests(args, test_list: list[OverrideDefinitions]): continue # Skip the test for ROCm - if TEST_WITH_ROCM and test_flavor.skip_rocm_test: + if args.gpu_arch_type == "rocm" and test_flavor.skip_rocm_test: continue # Check if we have enough GPUs @@ -110,6 +107,12 @@ def main(): parser.add_argument( "output_dir", help="Directory to dump results generated by tests" ) + parser.add_argument( + "--gpu_arch_type", + default="cuda", + choices=["cuda", "rocm"], + help="GPU architecture type. Must be specified as either 'cuda' or 'rocm'.", + ) parser.add_argument( "--test_suite", default="features", From c8ebd7a42552a473fe742c4b41d7cee679368abf Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 19 Nov 2025 10:25:03 -0800 Subject: [PATCH 023/127] Add a loss comparison script (#2029) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * #2049 * __->__ #2029 ## Summary This PR adds `scripts/loss_compare.py` for comparing training losses between different git commits and/or training configurations. ## Key Features - Commit Comparison: Compare losses between two different git commits with deterministic training - Configuration Comparison: Compare different training configurations on the same commit - Reproducibility: Automatically enables deterministic mode and seed checkpointing for reproducible comparisons - Real-time Output: Streams training output to both console and log files during execution - Statistical Analysis: Generates step-by-step loss comparisons and summary statistics - CI Testing: Includes --assert-equal flag for automated testing to verify identical losses ## Usage Examples #### Compare two commits ``` python3 ./scripts/loss_compare.py main my_branch ``` #### Compare two commits with custom configuration ``` python3 ./scripts/loss_compare.py main my_branch \ --baseline-config="./custom.toml" --baseline-options="--parallelism.tensor_parallel_degree=2" \ ``` #### Compare different parallelization strategies on same commit ``` python3 ./scripts/loss_compare.py . . \ --baseline-config="./llama3_8b.toml" --baseline-options="--parallelism.tensor_parallel_degree=2" \ --test-options="--parallelism.tensor_parallel_degree=1" \ ``` #### Assert equality for CI testing ``` python3 ./scripts/loss_compare.py main my_branch --assert-equal ``` ## Real Use Cases Compare full dtensor simple fsdp with fsdp2: ``` python3 scripts/loss_compare.py . . \ --baseline-options='--activation_checkpoint.mode="none"' \ --test-train-file='torchtitan.experiments.full_dtensor.train' \ --test-options='--model.name full_dtensor.llama3 --activation_checkpoint.mode="none"' \ --assert-equal --no-seed-checkpoint [LOSS_COMPARE] [LOSS_COMPARE] Asserting losses are equal... [LOSS_COMPARE] Baseline log: /tmp/baseline_training.log [LOSS_COMPARE] Test log: /tmp/test_training.log [LOSS_COMPARE] Extracted 100 steps from baseline log [LOSS_COMPARE] Extracted 100 steps from test log test_losses_equal (__main__.assert_losses_equal..LossEqualityTest.test_losses_equal) ... ok ``` --- .../integration_test_8gpu_features.yaml | 7 + scripts/loss_compare.py | 889 ++++++++++++++++++ 2 files changed, 896 insertions(+) create mode 100644 scripts/loss_compare.py diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index 14e185b5e1..de0672eeef 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -92,5 +92,12 @@ jobs: python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 + # Verify the accuracy. + echo "Checking FSDP4 v.s. HSDP2FSDP2TP2 accuracy parity" + export baseline_options="--parallelism.data_parallel_replicate_degree=1" + export test_options="--parallelism.data_parallel_replicate_degree=2 --parallelism.tensor_parallel_degree=2" + python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --baseline-ngpus=4 --test-ngpus=8 --steps=1 + + # Cleanup the checkpoints so that we don't waste network bandwidth and time. rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint rm -rf artifacts-to-be-uploaded/*/checkpoint diff --git a/scripts/loss_compare.py b/scripts/loss_compare.py new file mode 100644 index 0000000000..31573c3bce --- /dev/null +++ b/scripts/loss_compare.py @@ -0,0 +1,889 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script compares training losses between different git commits +and/or different training configurations. --debug.deterministic is +always enabled and seed checkpoint is also enabled by default for +reproducible comparisons. You can disable seed checkpoint with +--no-seed-checkpoint if you don't need it to speed up comparisons. +If --output-folder is specified, all outputs are organized in that +folder with detailed analysis and statistical summaries. + +The --assert-equal flag can be used for CI testing to verify that +losses are identical between runs. If losses differ, the script will +exit with a non-zero status code. + +Example usages: +1. Compare losses between two different git commits with default config: + loss_compare.py main my_branch + +2. Compare losses between two commits with custom config and options: + loss_compare.py main my_branch \ + --baseline-config='./custom.toml' \ + --baseline-options='--parallelism.tensor_parallel_degree=2' \ + --output-folder=my_comparison + +3. Compare commits with the same command but skip seed checkpoint for + faster execution: + loss_compare.py main my_branch --no-seed-checkpoint + +4. Compare the same commit with different training configurations: + loss_compare.py . . \ + --baseline-options='--parallelism.dp=1' \ + --test-options='--parallelism.dp=2' + +5. Compare with different train files: + loss_compare.py main my_branch \ + --baseline-train-file='torchtitan.train' \ + --test-train-file='torchtitan.custom_train' + +6. Assert that losses are equal (for CI testing): + loss_compare.py main my_branch --assert-equal +""" + +import argparse +import os +import re +import subprocess +import sys +import unittest +from typing import Any + +# ============================================================================= +# GLOBAL CONFIGURATION +# ============================================================================= + +LOG_PREFIX = "[LOSS_COMPARE]" + +# Fixed options that are always appended +FIXED_OPTIONS = "--debug.deterministic --debug.seed=42" + + +# ============================================================================= +# UTILITY FUNCTIONS +# ============================================================================= + + +def log_print(message: str = "") -> None: + """Print message with LOG_PREFIX.""" + if message: + print(f"{LOG_PREFIX} {message}") + else: + print(f"{LOG_PREFIX}") + + +def get_log_path(scenario: str, output_folder: str | None) -> str: + """Get log file path for a scenario.""" + if output_folder: + return f"{output_folder}/{scenario}_training.log" + return f"/tmp/{scenario}_training.log" + + +def get_loss_file_path(scenario: str, output_folder: str) -> str: + """Get loss file path for a scenario.""" + return f"{output_folder}/{scenario}_losses.txt" + + +def get_clean_log_path(scenario: str, output_folder: str) -> str: + """Get cleaned log file path for a scenario.""" + return f"{output_folder}/{scenario}_training_clean.log" + + +def build_base_command( + config_file: str, train_file: str, options: str, job_dump_folder: str +) -> str: + """Build the base command from config file, train file, and options.""" + cmd = f"TRAIN_FILE='{train_file}' CONFIG_FILE='{config_file}' ./run_train.sh" + cmd += f" --job.dump_folder={job_dump_folder}" + if options: + cmd += f" {options}" + return cmd + + +def strip_ansi_codes(input_file: str, output_file: str) -> None: + """Strip ANSI escape codes from log files.""" + ansi_escape = re.compile(r"\x1b\[[0-9;]*m") + with open(input_file, "r") as f_in: + with open(output_file, "w") as f_out: + for line in f_in: + f_out.write(ansi_escape.sub("", line)) + + +def run_with_realtime_output(cmd: str, logfile: str, env: dict[str, Any]) -> None: + """Run command with real-time output to both console and log file.""" + log_print(f"Executing: {cmd}") + + # Set PYTHONUNBUFFERED for better output handling + env["PYTHONUNBUFFERED"] = "1" + + # Run command and tee output to both stdout and log file + with open(logfile, "w") as log_f: + process = subprocess.Popen( + cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + bufsize=1, + ) + + for line in process.stdout: + print(line, end="") + log_f.write(line) + log_f.flush() + + process.wait() + + if process.returncode != 0: + raise subprocess.CalledProcessError(process.returncode, cmd) + + +def log_and_save(message: str, stats_file: str | None) -> None: + """Output message to both stdout and stats file if provided.""" + print(message) + if stats_file: + with open(stats_file, "a") as f: + f.write(message + "\n") + + +# ============================================================================= +# VALIDATION FUNCTIONS +# ============================================================================= + + +def validate_arguments( + baseline_commit: str, + test_commit: str, + baseline_config: str, + baseline_train_file: str, + baseline_options: str, + test_config: str, + test_train_file: str, + test_options: str, + steps: int, +) -> None: + """Validate command line arguments.""" + # Validate commit arguments - if one is ".", both must be "." + if (baseline_commit == "." and test_commit != ".") or ( + baseline_commit != "." and test_commit == "." + ): + log_print("Error: If one commit is '.', both commits must be '.'") + log_print(f" Got baseline: '{baseline_commit}', test: '{test_commit}'") + log_print( + " Use '.' for both commits to compare different " + "configurations on current working directory" + ) + sys.exit(1) + + # Validate that we are comparing different settings + commits_differ = baseline_commit != test_commit + configs_differ = baseline_config != test_config + train_files_differ = baseline_train_file != test_train_file + options_differ = baseline_options != test_options + + if not (commits_differ or configs_differ or train_files_differ or options_differ): + log_print("Error: All settings are identical") + log_print(" Cannot compare identical configurations") + log_print( + " Please provide different commits, configs, train files, or options" + ) + sys.exit(1) + + # Validate steps is a positive integer + if steps <= 0: + log_print(f"Error: --steps must be a positive integer, got: {steps}") + sys.exit(1) + + +# ============================================================================= +# SETUP FUNCTIONS +# ============================================================================= + + +def setup_output_directory(output_folder: str | None) -> str | None: + """Setup output directory and return stats file path. + Returns None if no output folder specified. + """ + if not output_folder: + return None + + # Check if output folder already exists + if os.path.exists(output_folder): + log_print(f"Error: Output folder '{output_folder}' already exists") + log_print(f"Please delete it first: rm -rf {output_folder}") + sys.exit(1) + + # Create the output folder + log_print(f"Creating output folder: {output_folder}") + os.makedirs(output_folder) + + # Set statistics file path + stats_file = os.path.join(output_folder, "comparison_statistics.txt") + return stats_file + + +def build_training_command( + config_file: str, + train_file: str, + options: str, + steps: int, + enable_seed_checkpoint: bool, + job_dump_folder: str, +) -> str: + """Build the final training command with all options.""" + base_cmd = build_base_command(config_file, train_file, options, job_dump_folder) + cmd = f"{base_cmd} {FIXED_OPTIONS} --training.steps={steps}" + if enable_seed_checkpoint: + cmd += ( + " --checkpoint.enable --checkpoint.export_dtype=bfloat16" + " --checkpoint.load_only" + ) + return cmd + + +def print_configuration( + baseline_commit: str, + test_commit: str, + baseline_config: str, + baseline_train_file: str, + baseline_options: str, + test_config: str, + test_train_file: str, + test_options: str, + steps: int, + enable_seed_checkpoint: bool, + job_dump_folder: str, +) -> None: + """Print configuration summary.""" + log_print( + f"Starting loss comparison between baseline commit: " + f"{baseline_commit} and test commit: {test_commit}" + ) + log_print(f"Training steps: {steps}") + log_print(f"Seed checkpoint enabled: {enable_seed_checkpoint}") + log_print() + + # Build and display final commands + baseline_final_cmd = build_training_command( + baseline_config, + baseline_train_file, + baseline_options, + steps, + enable_seed_checkpoint, + job_dump_folder, + ) + test_final_cmd = build_training_command( + test_config, + test_train_file, + test_options, + steps, + enable_seed_checkpoint, + job_dump_folder, + ) + + log_print("Baseline command:") + log_print(f" {baseline_final_cmd}") + log_print() + log_print("Test command:") + log_print(f" {test_final_cmd}") + log_print() + + +# ============================================================================= +# GIT OPERATIONS +# ============================================================================= + + +def checkout_commit(commit: str, commit_name: str) -> None: + """Checkout git commit.""" + if commit != ".": + log_print(f"Checking out {commit_name} commit: {commit}") + subprocess.run(["git", "checkout", commit], check=True) + else: + log_print(f"Using current working directory for {commit_name} (commit: '.')") + + +# ============================================================================= +# TRAINING OPERATIONS +# ============================================================================= + + +def create_seed_checkpoint( + enable_seed_checkpoint: bool, + config_file: str, + train_file: str, + output_folder: str | None, + job_dump_folder: str, +) -> None: + """Create seed checkpoint.""" + if enable_seed_checkpoint: + log_file = get_log_path("seed", output_folder) + log_print(f"Creating seed checkpoint and logging output to {log_file}") + + # Build seed checkpoint command + seed_cmd = ( + f"TRAIN_FILE='{train_file}' CONFIG_FILE='{config_file}' " + f"./run_train.sh --job.dump_folder={job_dump_folder} " + f"--checkpoint.create_seed_checkpoint " + f"--checkpoint.enable {FIXED_OPTIONS}" + ) + + env = os.environ.copy() + env["NGPU"] = "1" + + run_with_realtime_output(seed_cmd, log_file, env) + + +def run_training( + scenario: str, + config_file: str, + train_file: str, + options: str, + steps: int, + enable_seed_checkpoint: bool, + output_folder: str | None, + job_dump_folder: str, + ngpus: int, +) -> str: + """Run training for a specific scenario. Returns the log file path.""" + log_file = get_log_path(scenario, output_folder) + log_print( + f"Running training with {scenario} commit and logging output " f"to {log_file}" + ) + + # Build the final command + full_cmd = build_training_command( + config_file, train_file, options, steps, enable_seed_checkpoint, job_dump_folder + ) + + env = os.environ.copy() + env["NGPU"] = str(ngpus) + + run_with_realtime_output(full_cmd, log_file, env) + + return log_file + + +# ============================================================================= +# LOG PROCESSING AND ANALYSIS +# ============================================================================= + + +def extract_losses_from_log(log_file: str) -> dict[int, float]: + """Extract step and loss pairs from a log file.""" + losses = {} + step_loss_pattern = re.compile(r"step:\s*(\d+)\s*loss:\s*(\d+\.\d+)") + ansi_escape = re.compile(r"\x1b\[[0-9;]*m") + + with open(log_file, "r") as f: + for line in f: + # Strip ANSI codes before matching + clean_line = ansi_escape.sub("", line) + match = step_loss_pattern.search(clean_line) + if match: + step, loss = match.groups() + losses[int(step)] = float(loss) + + return losses + + +def read_losses_from_file(loss_file: str) -> dict[int, float]: + """Read losses from a processed loss file.""" + losses = {} + with open(loss_file, "r") as f: + for line in f: + step, loss = line.strip().split() + losses[int(step)] = float(loss) + return losses + + +def extract_loss_data(output_folder: str | None) -> None: + """Extract loss data from logs.""" + if not output_folder: + return + + log_print("Cleaning ANSI escape codes from log files...") + + # Strip ANSI escape codes from log files before processing + scenarios = ["baseline", "test"] + for scenario in scenarios: + strip_ansi_codes( + get_log_path(scenario, output_folder), + get_clean_log_path(scenario, output_folder), + ) + + # Extract step and loss from cleaned logs + step_loss_pattern = re.compile(r"step:\s*(\d+)\s*loss:\s*(\d+\.\d+)") + + for scenario in scenarios: + with open(get_clean_log_path(scenario, output_folder), "r") as f_in: + with open(get_loss_file_path(scenario, output_folder), "w") as f_out: + for line in f_in: + match = step_loss_pattern.search(line) + if match: + step, loss = match.groups() + f_out.write(f"{step} {loss}\n") + + +def generate_step_comparison( + baseline_losses: dict[int, float], + test_losses: dict[int, float], + stats_file: str | None, +) -> None: + """Generate step-by-step comparison.""" + log_and_save("", stats_file) + log_and_save(f"{LOG_PREFIX} Step-by-step loss comparison:", stats_file) + log_and_save( + f"{LOG_PREFIX} Step Baseline Loss Test Loss Difference", + stats_file, + ) + log_and_save( + f"{LOG_PREFIX} ---- ------------- --------- ----------", + stats_file, + ) + + # Generate comparison for common steps + for step in sorted(set(baseline_losses.keys()) & set(test_losses.keys())): + baseline_loss = baseline_losses[step] + test_loss = test_losses[step] + diff = test_loss - baseline_loss + + formatted_line = ( + f"{LOG_PREFIX} {step:<6} {baseline_loss:<13} " + f"{test_loss:<14} {diff:.6f}" + ) + log_and_save(formatted_line, stats_file) + + +def generate_summary_statistics( + baseline_losses: dict[int, float], + test_losses: dict[int, float], + stats_file: str | None, +) -> None: + """Generate summary statistics.""" + log_and_save(f"{LOG_PREFIX}", stats_file) + log_and_save(f"{LOG_PREFIX} Summary statistics:", stats_file) + + # Calculate average losses + def calculate_average(losses: dict[int, float]) -> float | None: + """Calculate average loss from losses dict.""" + if not losses: + return None + return sum(losses.values()) / len(losses) + + baseline_avg = calculate_average(baseline_losses) + test_avg = calculate_average(test_losses) + + baseline_avg_str = f"{baseline_avg}" if baseline_avg is not None else "N/A" + test_avg_str = f"{test_avg}" if test_avg is not None else "N/A" + + log_and_save(f"{LOG_PREFIX} Average baseline loss: {baseline_avg_str}", stats_file) + log_and_save(f"{LOG_PREFIX} Average test loss: {test_avg_str}", stats_file) + + # Calculate overall difference if both averages are available + if baseline_avg is not None and test_avg is not None: + avg_diff = test_avg - baseline_avg + log_and_save(f"{LOG_PREFIX} Average difference: {avg_diff:.6f}", stats_file) + + +def perform_loss_analysis( + baseline_log: str, test_log: str, stats_file: str | None +) -> None: + """Perform loss comparison analysis.""" + # Initialize stats file and add header + log_and_save(f"{LOG_PREFIX} ==========================================", stats_file) + log_and_save(f"{LOG_PREFIX} LOSS COMPARISON ANALYSIS", stats_file) + log_and_save(f"{LOG_PREFIX} ==========================================", stats_file) + + # Extract losses directly from log files + baseline_losses = extract_losses_from_log(baseline_log) + test_losses = extract_losses_from_log(test_log) + + # Check if losses were extracted successfully + name_losses = [("baseline", baseline_losses), ("test", test_losses)] + for name, losses in name_losses: + if not losses: + log_and_save( + f"{LOG_PREFIX} Warning: Could not extract loss data from " + f"{name} training log.", + stats_file, + ) + log_and_save( + f"{LOG_PREFIX} Please check that the training completed " + "successfully.", + stats_file, + ) + return + + # Generate comparison outputs + generate_step_comparison(baseline_losses, test_losses, stats_file) + generate_summary_statistics(baseline_losses, test_losses, stats_file) + + +def assert_losses_equal(baseline_log: str, test_log: str) -> None: + """Assert that losses are equal between baseline and test using + unittest. + """ + log_print("Asserting losses are equal...") + log_print(f"Baseline log: {baseline_log}") + log_print(f"Test log: {test_log}") + + # Extract losses from both logs + baseline_losses = extract_losses_from_log(baseline_log) + test_losses = extract_losses_from_log(test_log) + + log_print(f"Extracted {len(baseline_losses)} steps from baseline log") + log_print(f"Extracted {len(test_losses)} steps from test log") + + if not baseline_losses: + log_print("Error: No losses found in baseline log") + sys.exit(1) + + if not test_losses: + log_print("Error: No losses found in test log") + sys.exit(1) + + # Create a test case + class LossEqualityTest(unittest.TestCase): + def test_losses_equal(self): + # Check that both have the same steps + baseline_steps = set(baseline_losses.keys()) + test_steps = set(test_losses.keys()) + + self.assertEqual( + baseline_steps, + test_steps, + f"Steps mismatch: baseline has {len(baseline_steps)} steps, " + f"test has {len(test_steps)} steps", + ) + + # Check that losses are equal for each step + for step in sorted(baseline_steps): + baseline_loss = baseline_losses[step] + test_loss = test_losses[step] + self.assertEqual( + baseline_loss, + test_loss, + f"Loss mismatch at step {step}: " + f"baseline={baseline_loss}, test={test_loss}", + ) + + # Run the test + suite = unittest.TestLoader().loadTestsFromTestCase(LossEqualityTest) + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + if not result.wasSuccessful(): + log_print("Loss assertion failed!") + sys.exit(1) + else: + log_print("All losses are equal. Assertion passed!") + + +def cleanup_temp_files(output_folder: str | None) -> None: + """Cleanup temporary files.""" + if not output_folder: + return + + scenarios = ["baseline", "test"] + for scenario in scenarios: + for temp_file in [ + get_loss_file_path(scenario, output_folder), + get_clean_log_path(scenario, output_folder), + ]: + if os.path.exists(temp_file): + os.remove(temp_file) + + +# ============================================================================= +# OUTPUT FUNCTIONS +# ============================================================================= + + +def print_completion_summary( + output_folder: str | None, enable_seed_checkpoint: bool +) -> None: + """Print completion summary.""" + log_print() + if output_folder: + log_print(f"Loss comparison complete. Results saved in {output_folder}/:") + log_print(" - baseline_outputs/") + log_print(" - test_outputs/") + if enable_seed_checkpoint: + log_print(" - seed_checkpoint_outputs/") + log_print() + log_print(f"Training logs saved in {output_folder}/:") + if enable_seed_checkpoint: + log_print(" - seed_checkpoint.log") + log_print(" - baseline_training.log") + log_print(" - test_training.log") + log_print() + log_print(f"All outputs organized in: {output_folder}/") + else: + log_print( + "Loss comparison complete. No results saved " + "(no output folder specified)." + ) + + +# ============================================================================= +# MAIN EXECUTION +# ============================================================================= + + +def parse_arguments() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description=( + "Compare training losses between different git commits " + "and/or different training configurations." + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s abc123 def456 + %(prog)s abc123 def456 --steps=200 + %(prog)s abc123 def456 --baseline-config='./custom.toml' \\ + --baseline-options='--parallelism.tensor_parallel_degree=2' --steps=50 + %(prog)s abc123 def456 --no-seed-checkpoint + %(prog)s . . --baseline-options='--parallelism.dp=1' \\ + --test-options='--parallelism.dp=2' --steps=30 + """, + ) + + parser.add_argument("baseline_commit", help="Git commit hash for baseline") + parser.add_argument("test_commit", help="Git commit hash for test") + parser.add_argument( + "--baseline-config", + default="./torchtitan/models/llama3/train_configs/debug_model.toml", + help=( + "Config file for baseline run " + "(default: ./torchtitan/models/llama3/train_configs/" + "llama3_debug.toml)" + ), + ) + parser.add_argument( + "--test-config", + default="", + help="Config file for test run (default: uses baseline-config)", + ) + parser.add_argument( + "--baseline-options", + default="", + help="Additional CLI arguments for baseline run (default: empty)", + ) + parser.add_argument( + "--test-options", + default="", + help="Additional CLI arguments for test run (default: empty)", + ) + parser.add_argument( + "--baseline-train-file", + default="torchtitan.train", + help=( + "Train file (Python module path) for baseline run " + "(default: torchtitan.train)" + ), + ) + parser.add_argument( + "--test-train-file", + default="", + help=( + "Train file (Python module path) for test run " + "(default: uses baseline-train-file)" + ), + ) + parser.add_argument( + "--steps", + type=int, + default=100, + help="Number of training steps (default: 100)", + ) + parser.add_argument( + "--no-seed-checkpoint", + action="store_true", + help=("Disable seed checkpoint creation and checkpoint functionality"), + ) + parser.add_argument( + "--output-folder", + default="", + help=( + "Output folder for results (optional, if not specified, " + "results will not be saved)" + ), + ) + parser.add_argument( + "--assert-equal", + action="store_true", + help=( + "Assert that all losses are equal (for CI testing). " + "Script exits with error if losses differ." + ), + ) + parser.add_argument( + "--job-dump-folder", + default="outputs", + help="Job dump folder path (default: outputs)", + ) + parser.add_argument( + "--baseline-ngpus", + type=int, + default=8, + help="Number of GPUs for baseline run (default: 8)", + ) + parser.add_argument( + "--test-ngpus", + type=int, + default=8, + help="Number of GPUs for test run (default: 8)", + ) + + args = parser.parse_args() + + # Set default values if not provided + if not args.test_config: + args.test_config = args.baseline_config + + if not args.test_train_file: + args.test_train_file = args.baseline_train_file + + # Convert empty output_folder to None + if not args.output_folder: + args.output_folder = None + + return args + + +def run_scenario( + scenario: str, + commit: str, + config_file: str, + train_file: str, + options: str, + steps: int, + enable_seed_checkpoint: bool, + output_folder: str | None, + job_dump_folder: str, + ngpus: int, +) -> str: + """Run training for a specific scenario (baseline or test). + + Args: + scenario: Name of the scenario ("baseline" or "test") + commit: Git commit to checkout + config_file: Config file path + train_file: Train file (Python module path) + options: Additional CLI options + steps: Number of training steps + enable_seed_checkpoint: Whether to use seed checkpoint + output_folder: Output folder for results + job_dump_folder: Job dump folder path + ngpus: Number of GPUs to use + + Returns: + Path to the log file + """ + checkout_commit(commit, scenario) + + log_file = run_training( + scenario, + config_file, + train_file, + options, + steps, + enable_seed_checkpoint, + output_folder, + job_dump_folder, + ngpus, + ) + + return log_file + + +def main() -> None: + """Main function that orchestrates the entire comparison process.""" + # Parse and validate arguments + args = parse_arguments() + validate_arguments( + args.baseline_commit, + args.test_commit, + args.baseline_config, + args.baseline_train_file, + args.baseline_options, + args.test_config, + args.test_train_file, + args.test_options, + args.steps, + ) + + # Setup environment + stats_file = setup_output_directory(args.output_folder) + enable_seed_checkpoint = not args.no_seed_checkpoint + print_configuration( + args.baseline_commit, + args.test_commit, + args.baseline_config, + args.baseline_train_file, + args.baseline_options, + args.test_config, + args.test_train_file, + args.test_options, + args.steps, + enable_seed_checkpoint, + args.job_dump_folder, + ) + + create_seed_checkpoint( + enable_seed_checkpoint, + args.baseline_config, + args.baseline_train_file, + args.output_folder, + args.job_dump_folder, + ) + # Run baseline and test training + baseline_log = run_scenario( + "baseline", + args.baseline_commit, + args.baseline_config, + args.baseline_train_file, + args.baseline_options, + args.steps, + enable_seed_checkpoint, + args.output_folder, + args.job_dump_folder, + args.baseline_ngpus, + ) + + test_log = run_scenario( + "test", + args.test_commit, + args.test_config, + args.test_train_file, + args.test_options, + args.steps, + enable_seed_checkpoint, + args.output_folder, + args.job_dump_folder, + args.test_ngpus, + ) + log_print() + + # Assert losses are equal if requested + if args.assert_equal: + assert_losses_equal(baseline_log, test_log) + + # Analysis and reporting + perform_loss_analysis(baseline_log, test_log, stats_file) + cleanup_temp_files(args.output_folder) + print_completion_summary(args.output_folder, enable_seed_checkpoint) + + +if __name__ == "__main__": + main() From 605a9a129f3e14896ebdf2c2a35e841ba42cf9e9 Mon Sep 17 00:00:00 2001 From: Yiming Zhou <61480007+yiming0416@users.noreply.github.com> Date: Wed, 19 Nov 2025 12:37:56 -0800 Subject: [PATCH 024/127] Fix integration test gpu_arch_type field (#2060) All tests in experiments are broken due to the `gpu_arch_type` field added in #2018. --- tests/integration_tests/run_tests.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/integration_tests/run_tests.py b/tests/integration_tests/run_tests.py index b2cb8ea503..c233904165 100644 --- a/tests/integration_tests/run_tests.py +++ b/tests/integration_tests/run_tests.py @@ -89,7 +89,10 @@ def run_tests(args, test_list: list[OverrideDefinitions]): continue # Skip the test for ROCm - if args.gpu_arch_type == "rocm" and test_flavor.skip_rocm_test: + if ( + getattr(args, "gpu_arch_type", "cuda") == "rocm" + and test_flavor.skip_rocm_test + ): continue # Check if we have enough GPUs From f541d91cfd902f7fc2a48bd32ba67dcb9514a097 Mon Sep 17 00:00:00 2001 From: Yiming Zhou <61480007+yiming0416@users.noreply.github.com> Date: Wed, 19 Nov 2025 15:50:12 -0800 Subject: [PATCH 025/127] [compiler toolkit] Add Trainer subclass for compiler toolkit (#2064) Adding CudaGraph pass (https://github.com/pytorch/torchtitan/pull/2050) would require some custom logic in Trainer's close() method. So we create a Trainer subclass in compiler toolkit --- .../integration_test_8gpu_compiler_toolkit.yaml | 2 +- .../experiments/compiler_toolkit/README.md | 16 ++++++++-------- torchtitan/experiments/compiler_toolkit/train.py | 15 +++++++++++++++ 3 files changed, 24 insertions(+), 9 deletions(-) create mode 100644 torchtitan/experiments/compiler_toolkit/train.py diff --git a/.github/workflows/integration_test_8gpu_compiler_toolkit.yaml b/.github/workflows/integration_test_8gpu_compiler_toolkit.yaml index 1aee67c093..815476e82c 100644 --- a/.github/workflows/integration_test_8gpu_compiler_toolkit.yaml +++ b/.github/workflows/integration_test_8gpu_compiler_toolkit.yaml @@ -50,4 +50,4 @@ jobs: python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 mkdir artifacts-to-be-uploaded - python -m torchtitan.experiments.compiler_toolkit.tests.integration_tests artifacts-to-be-uploaded --ngpu 4 + TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train python -m torchtitan.experiments.compiler_toolkit.tests.integration_tests artifacts-to-be-uploaded --ngpu 4 diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index c223d1e658..7d00e1f48b 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ b/torchtitan/experiments/compiler_toolkit/README.md @@ -14,44 +14,44 @@ Joint Graph based Training Prototype: **SimpleFSDP + TP + EP** ```shell -NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none +NGPU=4 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none ``` **SimpleFSDP + TP + EP + FlexAttention** ```shell -NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --model.flavor=debugmodel_flex_attn +NGPU=4 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --model.flavor=debugmodel_flex_attn ``` ## llama3 **SimpleFSDP + TP** ```shell -NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 ``` **SimpleFSDP + TP + auto-bucketing** ```shell -NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering ``` **SimpleFSDP + TP + transformer-block-bucketing** ```shell -NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing ``` **SimpleFSDP + TP + FlexAttention** ```shell -NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn ``` **SimpleFSDP + TP + FlexAttention + auto-bucketing + regional-inductor** ```shell -NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor ``` **SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor** ```shell -NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor ``` diff --git a/torchtitan/experiments/compiler_toolkit/train.py b/torchtitan/experiments/compiler_toolkit/train.py new file mode 100644 index 0000000000..26e3245b2b --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/train.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.train import main, Trainer + + +class CompilerToolkitTrainer(Trainer): + pass + + +if __name__ == "__main__": + main(CompilerToolkitTrainer) From 8bf2265e9d55cbf10eec23d9603de7763701c901 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 19 Nov 2025 17:59:25 -0800 Subject: [PATCH 026/127] Let loss_compare.py check the repo cleaness (#2062) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * #2063 * __->__ #2062 This will prevent errors when later doing git checkout --- scripts/loss_compare.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/scripts/loss_compare.py b/scripts/loss_compare.py index 31573c3bce..42ad3a81be 100644 --- a/scripts/loss_compare.py +++ b/scripts/loss_compare.py @@ -301,6 +301,35 @@ def print_configuration( # ============================================================================= +def check_git_clean_state() -> None: + """Check if git working directory is clean before switching commits. + + Raises SystemExit if there are uncommitted changes or untracked files. + """ + result = subprocess.run( + ["git", "status", "--porcelain"], + capture_output=True, + text=True, + check=True, + ) + + if result.stdout.strip(): + log_print("Error: Git working directory is not clean") + log_print(" Cannot switch commits with uncommitted changes") + log_print("") + log_print("Modified/untracked files:") + for line in result.stdout.strip().split("\n"): + log_print(f" {line}") + log_print("") + log_print( + "Please commit, stash, or discard your changes before running this script" + ) + log_print(" - To commit: git add -A && git commit -m 'message'") + log_print(" - To stash: git stash") + log_print(" - To discard: git checkout -- . && git clean -fd") + sys.exit(1) + + def checkout_commit(commit: str, commit_name: str) -> None: """Checkout git commit.""" if commit != ".": @@ -840,6 +869,12 @@ def main() -> None: args.job_dump_folder, ) + # Check if git working directory is clean before switching commits + # Skip check if both commits are "." (comparing configs on same commit) + needs_git_checkout = args.baseline_commit != "." or args.test_commit != "." + if needs_git_checkout: + check_git_clean_state() + create_seed_checkpoint( enable_seed_checkpoint, args.baseline_config, From f5e3a84eeede490ebb684f8f66c35315f3263b40 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 19 Nov 2025 21:15:37 -0800 Subject: [PATCH 027/127] CUDAGraph support for SimpleFSDP and TP (#2050) ## Features - [x] Support SimpleFSDP and TP - [x] Support static input indices to reduce copy - [x] Support memory reuse to reduce memory consumption - [x] Cleanup cudagraph when training finishes to avoid nccl hang from destroy_process_group Command: ``` NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes cudagraph ``` Note: we use `NCCL_GRAPH_REGISTER=0` due to a known issue that nccl + cudagraphs + expandable segments result in IMA. https://github.com/pytorch/pytorch/issues/158029 [trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces%2Ftree%2Fshared_trace%2Fboyuan_e1ef464b-ee61-4c61-82e5-f7a485e561bf_rank0_trace.json) ## Result **Numerics:** Achieved bitwise equivalence w/ and w/o cudagraph pass on llama3.1-8B AND llama3.1-70B. **Performance:** image Raw log: [llama3-8b](https://www.internalfb.com/phabricator/paste/view/P2045444190), [llama3-70b](https://www.internalfb.com/phabricator/paste/view/P2045567416) **Memory:** On llama3.1-70b, cudagraph takes 6% more memory consumption (143 GiB vs 153 GiB). A few tricks to reduce memory consumption (use llama3.1-70b w/ cudagraph as an example): - Start: 161 GiB - \+ use the same stream for warmup and graph capture of both fwd and bwd: 160 GiB - \+ warmup in cudagraph memory pool instead of eager memory pool: 153 GiB **static input copy:** On llama3.1-70B, for forward, we copy 1 tensor of 128 bytes; for backward, we copy 1 tensor of 0.98 GB. This shows static input indices is handled correctly. ## Followup PR In the followup PR, I will enable fx graph partition for deepseek v3 https://github.com/pytorch/pytorch/pull/165945. --- .../experiments/compiler_toolkit/README.md | 6 + .../compiler_toolkit/common_utils.py | 9 + .../experiments/compiler_toolkit/cudagraph.py | 169 ++++++++++++++++++ .../compiler_toolkit/graph_utils.py | 50 +++++- .../experiments/compiler_toolkit/passes.py | 24 +++ .../tests/integration_tests.py | 29 +++ .../experiments/compiler_toolkit/train.py | 15 +- 7 files changed, 292 insertions(+), 10 deletions(-) create mode 100644 torchtitan/experiments/compiler_toolkit/cudagraph.py diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index 7d00e1f48b..620911ce60 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ b/torchtitan/experiments/compiler_toolkit/README.md @@ -55,3 +55,9 @@ NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./to ```shell NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor ``` + +**SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor + cudagraph** + +```shell +NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor,cudagraph +``` diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py index 965e027bdb..997af9a2c4 100644 --- a/torchtitan/experiments/compiler_toolkit/common_utils.py +++ b/torchtitan/experiments/compiler_toolkit/common_utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from contextlib import contextmanager +from typing import Callable import torch from torch.distributed.tensor import DTensor, Replicate @@ -53,3 +54,11 @@ def register_blockmask_pytree_node(): flatten_with_keys_fn=BlockMask._flatten_with_keys, serialized_type_name="torch.nn.attention.flex_attention.BlockMask", ) + + +def end_with_pass(passes: list[Callable], names: list[str]) -> bool: + return ( + len(passes) > 0 + and (last_pass_name := getattr(passes[-1], "__name__", None)) + and (last_pass_name in names) + ) diff --git a/torchtitan/experiments/compiler_toolkit/cudagraph.py b/torchtitan/experiments/compiler_toolkit/cudagraph.py new file mode 100644 index 0000000000..cd6e4cfc22 --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/cudagraph.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +CUDAGraph pass for the compiler toolkit. + +This module provides a cudagraph pass that can be applied to graph modules +during compilation. +""" + +import warnings +from typing import Any, Callable, Optional, Sequence + +import torch +from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager +from torch.utils._ordered_set import OrderedSet + + +def init_global_graph_pool() -> tuple[ + torch.cuda.CUDAGraph, torch.cuda._POOL_HANDLE, torch.cuda.Stream +]: + dummy_graph = torch.cuda.CUDAGraph() + + # create a global cudagraph memory pool to allow memory reuse across cudagraphs. + graph_pool = torch.cuda.graph_pool_handle() + + # create a global cuda stream for graph capture. we need to use a single stream + # for all allocations to the memory pool, otherwise the allocations to separate streams + # will not be used. + graph_capture_stream = torch.cuda.Stream() + + # use a dummy graph to keep the global graph pool alive + with ( + # suppress an empty cudagraph warning, since we intentionally create + # an empty cudagraph here + warnings.catch_warnings(record=True), + torch.cuda.graph( + dummy_graph, + pool=graph_pool, + stream=graph_capture_stream, + capture_error_mode="thread_local", + ), + ): + pass + + return dummy_graph, graph_pool, graph_capture_stream + + +( + _global_dummy_graph, + _global_graph_pool, + _global_graph_capture_stream, +) = init_global_graph_pool() + + +class CUDAGraphWrapper: + def __init__( + self, + runnable: Callable, + example_inputs: Sequence[Any], + static_input_indices: Optional[tuple[int]] = None, + should_check_address: bool = False, + ): + self.runnable = runnable + self.graph_pool = _global_graph_pool + self.stream = _global_graph_capture_stream + self.static_input_indices = OrderedSet( + static_input_indices if static_input_indices is not None else [] + ) + self.input_indices_to_copy = [ + i + for i, inp in enumerate(example_inputs) + if isinstance(inp, torch.Tensor) and i not in self.static_input_indices + ] + self.cudagraph: Optional[torch.cuda.CUDAGraph] = None + self.has_warmup = False + + self.args = None + self.output = None + + # (debug only) whether check static input tensor addresses during runtime + self.should_check_address = should_check_address + + def copy_non_static_inputs(self, *args): + for i in self.input_indices_to_copy: + self.args[i].copy_(args[i]) + + def check_input_types(self, inputs) -> None: + for inp in inputs: + assert isinstance(inp, (torch.Tensor, int, torch._C.Generator)), ( + "args must be tensor, integer (for dynamic shapes), " + "or Generator (for random number generator), " + f"but found {type(inp)}" + ) + + def check_static_inputs_address(self) -> None: + for i in self.static_input_indices: + actual = args[i].data_ptr() + expected = self.input_addresses[i] + assert expected == actual, ( + "Expected the same static tensor address but found " + f"{expected} != {actual}" + ) + + def __call__(self, *args): + if not self.has_warmup: + self.has_warmup = True + device = torch.cuda.current_device() + + # warmup in cudagraph memory pool to avoid fragmentation + # across eager memory pool and cudagraph memory pool. + with _use_cuda_memory_pool_manager(device, self.graph_pool, self.stream): + out = self.runnable(*args) + return out + + if self.cudagraph is None: + self.check_input_types(args) + self.args = args + self.input_addresses = [ + x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args + ] + + self.cudagraph = torch.cuda.CUDAGraph() + + with torch.cuda.graph( + self.cudagraph, pool=self.graph_pool, stream=self.stream + ): + # `output` is managed by pytorch's cudagraph pool + self.output = self.runnable(*args) + + if self.should_check_address: + self.check_static_inputs_address() + + self.copy_non_static_inputs(*args) + self.cudagraph.replay() + return self.output + + +def get_static_input_indices(gm: torch.fx.GraphModule, is_forward: bool) -> list[int]: + """ + Get indices of gm inputs that are static input tensors whose tensor addresses do not + change across runs. Example of static input tensors include weights, buffers, and + outputs of previous cudagraph wrapped functions. + """ + from torch._inductor.utils import count_tangents + + static_input_indices = [] + if ( + is_forward + and (tracing_context := torch._guards.TracingContext.try_get()) + and hasattr(tracing_context, "fw_metadata") + ): + # for forward, we rely on graph capture (i.e., dynamo or export) to provide + # the correct static input indices stored in tracing context. Typical examples + # include weights and buffers. + static_input_indices = tracing_context.fw_metadata.static_input_indices + + elif not is_forward: + # for backward, we identify saved tensors as static inputs, since saved tensors + # are outputs of cudagraph-wrapped forward run. In PT2-generated backward gm, + # saved tensors are always the leading args. So we can get the number of saved + # tensors and generate static input indices. + fixed = count_tangents(gm) + static_input_indices = list(range(fixed)) + + return static_input_indices diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 51ac8ba983..e097579cc0 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -20,6 +20,7 @@ from torch.distributed.tensor import DTensor from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims +from torchtitan.experiments.compiler_toolkit.common_utils import end_with_pass from torchtitan.tools.logging import logger @@ -217,6 +218,7 @@ def compiler( example_inputs, passes: List[Callable] = None, dump_folder: str | None = None, + is_forward: bool = True, ): """ Compile a graph module by applying a sequence of compiler passes. @@ -239,6 +241,17 @@ def compiler( ) _dump_gm(dump_folder, gm, f"{name}_before_compiler") + if end_with_pass(passes, ["cudagraph_pass"]): + # cudagraph pass is always the last pass if it is applied + cg_pass = passes[-1] + + # to identify static input indices, cudagraph passes behaves differently for + # forward and backward pass. so we explicitly pass the info. + _cg_pass = functools.partial(cg_pass, is_forward=is_forward) + + # keep the function name for debug log + passes[-1] = functools.wraps(cg_pass)(_cg_pass) + for pass_fn in passes: pass_name = ( pass_fn.func.__name__ @@ -271,17 +284,42 @@ def make_compiler_with_passes( def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return compiler( - "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + "fwd_gm", + gm, + example_inputs, + passes=passes, + dump_folder=dump_folder, + is_forward=True, ) def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return compiler( - "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder + "bwd_gm", + gm, + example_inputs, + passes=passes, + dump_folder=dump_folder, + is_forward=False, ) return fw_compiler, bw_compiler +def validate_pass_names(pass_names: list[str]) -> None: + if "cudagraph" in pass_names: + assert ( + pass_names[-1] == "cudagraph" + ), "cudagraph has to be the last pass to apply" + + if ( + "autobucketing_reordering" in pass_names + and "transformer_block_bucketing" in pass_names + ): + raise ValueError( + "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!" + ) + + def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig): """ Extract and validate compiler passes from job config. @@ -298,13 +336,7 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi ) pass_names = getattr(job_config.compile, "passes", []) - if ( - "autobucketing_reordering" in pass_names - and "transformer_block_bucketing" in pass_names - ): - raise ValueError( - "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!" - ) + validate_pass_names(pass_names) compiler_passes = [] for pass_name in pass_names: diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 64276a91bc..5657eb2b2b 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -11,10 +11,16 @@ during compilation. Passes can be selected and configured via job config. """ +from typing import Any, Sequence + import torch from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing from torch.fx.passes.regional_inductor import regional_inductor +from torchtitan.experiments.compiler_toolkit.cudagraph import ( + CUDAGraphWrapper, + get_static_input_indices, +) from torchtitan.experiments.simple_fsdp.reshard_after_forward import ( annotate_fsdp_all_gather, ) @@ -56,6 +62,23 @@ def regional_inductor_pass( return regional_inductor(gm, example_inputs) +def cudagraph_pass( + gm: torch.fx.GraphModule, example_inputs: Sequence[Any], is_forward: bool +) -> torch.fx.GraphModule: + """ + Apply cudagraph. + + This pass wraps the forward function with cudagraph during compilation and does + not record cudagraph until runtime. + - For the first run, it will warm up operators such as nccl. + - For the second run, it will record cudagraph and replay cudagraph. + - For the following runs, it will replay cudagraph. + """ + static_input_indices = get_static_input_indices(gm, is_forward) + gm.forward = CUDAGraphWrapper(gm.forward, example_inputs, static_input_indices) + return gm + + def validate_flex_attn_annotation_pass( gm: torch.fx.GraphModule, ) -> torch.fx.GraphModule: @@ -88,4 +111,5 @@ def fsdp_reshard_after_fwd_pass( "autobucketing_reordering": autobucketing_reordering_pass, "transformer_block_bucketing": transformer_block_bucketing_reordering_pass, "regional_inductor": regional_inductor_pass, + "cudagraph": cudagraph_pass, } diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py index b0155a9f2a..f01a1c4380 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py +++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py @@ -58,6 +58,20 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "llama3_fsdp_tp_manualbucketing", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes cudagraph", + ], + ], + "llama3 FSDP+TP+cudagraph", + "llama3_fsdp_tp_cudagraph", + ngpu=4, + ), OverrideDefinitions( [ [ @@ -86,6 +100,21 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--model.flavor debugmodel_flex_attn", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes autobucketing_reordering,regional_inductor,cudagraph", + ], + ], + "llama3 FSDP+TP+FlexAttn autobucketing regional_inductor+cudagraph", + "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor_cudagraph", + ngpu=4, + ), OverrideDefinitions( [ [ diff --git a/torchtitan/experiments/compiler_toolkit/train.py b/torchtitan/experiments/compiler_toolkit/train.py index 26e3245b2b..7b0d58aa5a 100644 --- a/torchtitan/experiments/compiler_toolkit/train.py +++ b/torchtitan/experiments/compiler_toolkit/train.py @@ -4,11 +4,24 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import gc + from torchtitan.train import main, Trainer class CompilerToolkitTrainer(Trainer): - pass + def close(self) -> None: + super().close() + + # Note [explicit cudagraph close] + # cudagraph holds reference to nccl which prevents destroy nccl + # group. so we need to explicitly delete cudagraph which is held + # in joint_graph_module. An explicit gc.collect() is necessary + # to clean up reference cycles. + for part in self.model_parts: + if hasattr(part, "joint_graph_module"): + part.joint_graph_module = None + gc.collect() if __name__ == "__main__": From d167a20e4ee8c9913810811509eee722d25b3204 Mon Sep 17 00:00:00 2001 From: Masaki Date: Thu, 20 Nov 2025 09:47:46 -0800 Subject: [PATCH 028/127] compiler_toolkit: fix args access (#2067) This PR fixes access to args; it's an attribute, not a variable in the scope. The method itself though would not be used because `should_check_address` seems to be always `False` and there doesn't seem to be a command line argument for it. Signed-off-by: Masaki Kozuki --- torchtitan/experiments/compiler_toolkit/cudagraph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/experiments/compiler_toolkit/cudagraph.py b/torchtitan/experiments/compiler_toolkit/cudagraph.py index cd6e4cfc22..d008e5d455 100644 --- a/torchtitan/experiments/compiler_toolkit/cudagraph.py +++ b/torchtitan/experiments/compiler_toolkit/cudagraph.py @@ -98,7 +98,7 @@ def check_input_types(self, inputs) -> None: def check_static_inputs_address(self) -> None: for i in self.static_input_indices: - actual = args[i].data_ptr() + actual = self.args[i].data_ptr() expected = self.input_addresses[i] assert expected == actual, ( "Expected the same static tensor address but found " From 58fa181ed3543e19c1cff3014f1b61b919d38cd1 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom <47445085+3outeille@users.noreply.github.com> Date: Thu, 20 Nov 2025 19:47:31 +0100 Subject: [PATCH 029/127] 3outeille/transformers backend (Dense model only) (#2048) # Context Reference PR: https://github.com/huggingface/torchtitan/pull/1 This PR enables: - Llama-like HF models to work with 4D parallelism: FSDP, CP, TP, PP (and the combinations between them). The following models were tested: - `meta-llama/Llama-3.2-1B` - `microsoft/phi-2` - `Qwen/Qwen2.5-7B` - `mistralai/Mistral-7B-v0.1` - `ByteDance-Seed/Seed-Coder-8B-Instruct` - `Qwen/Qwen3-4B-Instruct-2507` - `arcee-ai/AFM-4.5B` - `ibm-granite/granite-3b-code-base-2k` - `baidu/ERNIE-4.5-0.3B-Base-PT` - `kyutai/helium-1-preview-2b` - `allenai/OLMo-7B-hf` - `mistralai/Ministral-8B-Instruct-2410` - Patching HF models weights initialisation. Without this, the the `loss` and `grad_norm` starts very high # Usage - Requirements `transformers==4.57.1` - Config: `torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3.toml` ```diff ... [model] - name = "llama3" + name = "transformers_backend" flavor = "debugmodel" hf_assets_path = "./tests/assets/tokenizer" +[hf_transformers] +model = "Qwen/Qwen3-4B-Instruct-2507" ... ``` - Train: `LOG_RANK=7 CONFIG_FILE=/torchtitan/experiments/transformers_backend/configs/qwen3.toml ./run_train.sh --job.custom_config_module=torchtitan.experiments.transformers_backend.job_config --compile.enable` image # Testing methodology image - Following the [converging.md](https://github.com/pytorch/torchtitan/blob/main/docs/converging.md) guidelines, I am comparing the baseline `FSDP=2` vs `FSDP=2 & ` - More precisely, the `test_hf_integration.py`is going to do: ```bash results/ |_ meta-llama |_ Llama-3.2-1B |_ debugmodel/ |_ seed_checkpoint/ |_ config.toml |_ seed.slurm |_ step-0/ |_ .... |_ fsdp2_tp1_cp1_pp1/ |_ config.toml |_ nd_parallelism.slurm |_ nd_parallelism.log |_ fsdp2_tp2_cp1_pp1/ |_ config.toml |_ nd_parallelism.slurm |_ nd_parallelism.log |_ diff_baseline_vs_nd_parallelism.log |_ fsdp2_tp1_cp1_pp2/ |_ config.toml |_ nd_parallelism.slurm |_ nd_parallelism.log |_ diff_baseline_vs_nd_parallelism.log |_ fsdp2_tp1_cp2_pp1/ |_ config.toml |_ nd_parallelism.slurm |_ nd_parallelism.log |_ diff_baseline_vs_nd_parallelism.log |_ fsdp2_tp1_cp2_pp2/ |_ config.toml |_ nd_parallelism.slurm |_ nd_parallelism.log |_ diff_baseline_vs_nd_parallelism.log` |_ full/ ... ``` - Here is the grid search to test the HF modelling ```shell #!/usr/bin/bash model_names=( "meta-llama/Llama-3.2-1B" "microsoft/phi-2" "Qwen/Qwen2.5-7B" "mistralai/Mistral-7B-v0.1" "ByteDance-Seed/Seed-Coder-8B-Instruct" "Qwen/Qwen3-4B-Instruct-2507" "arcee-ai/AFM-4.5B" "ibm-granite/granite-3b-code-base-2k" "baidu/ERNIE-4.5-0.3B-Base-PT" "kyutai/helium-1-preview-2b" "allenai/OLMo-7B-hf" "mistralai/Ministral-8B-Instruct-2410" ) for model_name in "${model_names[@]}"; do rm -rf slurm_results/${model_name} python test_hf_integration.py create_configs --model_name "$model_name" --out_dir slurm_results --flavor debugmodel python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel/seed_checkpoint --qos high while [ ! -f slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt ] || [ "$(cat slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt)" != "completed" ]; do echo "Waiting for seed checkpoint from ${model_name} to complete ..." sleep 1 done python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel --qos high echo "================" done ``` # Further tasks - Moe (handle in PR https://github.com/huggingface/torchtitan/pull/3) - Missing `build_optimizers_with_moe_load_balancing` support for MoE - Missing TP/PP/EP supports for MoE - When using HF modeling, the test `FSDP=2 vs FSDP=2 + PP=2`, the `loss` and `grad_norm` not bitwise matching (but converging) while it is the case with Torchtitan modeling. (issue is tracked in https://github.com/huggingface/torchtitan/pull/4) - Add convergence tests to CI by doing tiny model + gloo backend (once PP is bitwise matching) - the HF modeling has lower MFU than Torchtitan MFU - NOTE: `import torch._dynamo.config; torch._dynamo.config.cache_size_limit = 128` to avoid recomputation for graph when using `torch.compile` and `activation checkpointing` --- .ci/docker/common/install_conda.sh | 1 + .../requirements-transformers-backend.txt | 1 + .ci/docker/ubuntu/Dockerfile | 1 + ...ration_test_8gpu_transformers_backend.yaml | 53 ++ torchtitan/experiments/README.md | 1 + torchtitan/experiments/__init__.py | 1 + .../transformers_backend/README.md | 52 ++ .../transformers_backend/__init__.py | 51 ++ .../configs/debug_model.toml | 88 ++++ .../transformers_backend/configs/full.toml | 87 ++++ .../transformers_backend/infra/parallelize.py | 435 ++++++++++++++++ .../transformers_backend/infra/pipeline.py | 391 ++++++++++++++ .../transformers_backend/job_config.py | 18 + .../transformers_backend/model/args.py | 199 ++++++++ .../transformers_backend/model/model.py | 477 ++++++++++++++++++ .../tests/integration_tests.py | 72 +++ torchtitan/train.py | 10 +- 17 files changed, 1936 insertions(+), 2 deletions(-) create mode 100644 .ci/docker/requirements-transformers-backend.txt create mode 100644 .github/workflows/integration_test_8gpu_transformers_backend.yaml create mode 100644 torchtitan/experiments/transformers_backend/README.md create mode 100644 torchtitan/experiments/transformers_backend/__init__.py create mode 100644 torchtitan/experiments/transformers_backend/configs/debug_model.toml create mode 100644 torchtitan/experiments/transformers_backend/configs/full.toml create mode 100644 torchtitan/experiments/transformers_backend/infra/parallelize.py create mode 100644 torchtitan/experiments/transformers_backend/infra/pipeline.py create mode 100644 torchtitan/experiments/transformers_backend/job_config.py create mode 100644 torchtitan/experiments/transformers_backend/model/args.py create mode 100644 torchtitan/experiments/transformers_backend/model/model.py create mode 100644 torchtitan/experiments/transformers_backend/tests/integration_tests.py diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index c2f316b04b..d3cb20e7a3 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -43,6 +43,7 @@ install_pip_dependencies() { pip_install -r /opt/conda/requirements.txt pip_install -r /opt/conda/requirements-flux.txt pip_install -r /opt/conda/requirements-vlm.txt + pip_install -r /opt/conda/requirements-transformers-backend.txt popd } diff --git a/.ci/docker/requirements-transformers-backend.txt b/.ci/docker/requirements-transformers-backend.txt new file mode 100644 index 0000000000..76e8886ed0 --- /dev/null +++ b/.ci/docker/requirements-transformers-backend.txt @@ -0,0 +1 @@ +transformers==4.57.1 diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index baaca85824..b8123099b9 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -33,6 +33,7 @@ COPY requirements-dev.txt /opt/conda/ COPY requirements.txt /opt/conda/ COPY requirements-flux.txt /opt/conda/ COPY requirements-vlm.txt /opt/conda/ +COPY requirements-transformers-backend.txt /opt/conda/ COPY conda-env-ci.txt /opt/conda/ COPY ./common/install_conda.sh install_conda.sh COPY ./common/utils.sh utils.sh diff --git a/.github/workflows/integration_test_8gpu_transformers_backend.yaml b/.github/workflows/integration_test_8gpu_transformers_backend.yaml new file mode 100644 index 0000000000..aea5189d81 --- /dev/null +++ b/.github/workflows/integration_test_8gpu_transformers_backend.yaml @@ -0,0 +1,53 @@ +name: Transformers Backend 8 GPU Integration Tests + +on: + push: + branches: [ main ] + paths: + - 'torchtitan/experiments/transformers_backend/**' + pull_request: + paths: + - 'torchtitan/experiments/transformers_backend/**' + schedule: + # Runs every 12 hours + - cron: '0 */12 * * *' + +concurrency: + group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} + cancel-in-progress: true + +defaults: + run: + shell: bash -l -eo pipefail {0} + +jobs: + build-test: + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + runner: linux.g5.48xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: "12.6" + # This image is faster to clone than the default, but it lacks CC needed by triton + # (1m25s vs 2m37s). + docker-image: torchtitan-ubuntu-20.04-clang12 + repository: pytorch/torchtitan + upload-artifact: outputs + script: | + set -eux + + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + # Log CUDA driver version for debugging. + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1 || true) + echo "CUDA driver version: ${DRIVER_VERSION}" + + pip config --user set global.progress_bar off + + python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + + USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 + + mkdir artifacts-to-be-uploaded + python -m torchtitan.experiments.transformers_backend.tests.integration_tests artifacts-to-be-uploaded --ngpu 8 diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 14f8ba6544..08dc692bf9 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -31,3 +31,4 @@ We provide this `experiments/` folder to host experiments that add significant v | [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) | | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) | +| [transformers_backend](./transformers_backend/) | [![Transformers backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index f6f813bfae..db3a44a824 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -12,5 +12,6 @@ "vlm", "compiler_toolkit.deepseek_v3", "compiler_toolkit.llama3", + "transformers_backend", ] ) diff --git a/torchtitan/experiments/transformers_backend/README.md b/torchtitan/experiments/transformers_backend/README.md new file mode 100644 index 0000000000..805afb9ab9 --- /dev/null +++ b/torchtitan/experiments/transformers_backend/README.md @@ -0,0 +1,52 @@ +# Huggingface Transformers backend + +## Quick start + +- Requirements `transformers==4.57.1` + +- Config: `torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3.toml` +```diff +... +[model] +- name = "llama3" ++ name = "transformers_backend" +flavor = "debugmodel" +hf_assets_path = "./tests/assets/tokenizer" + ++[hf_transformers] ++model = "Qwen/Qwen3-4B-Instruct-2507" +... +``` +- Train: `LOG_RANK=7 CONFIG_FILE=/torchtitan/experiments/transformers_backend/configs/qwen3.toml ./run_train.sh --job.custom_config_module=torchtitan.experiments.transformers_backend.job_config --compile.enable` + - Make sure you have created the tokenizers beforehand +image + +## Supported Features + +- The following models were tested: + - Dense (FSDP/CP/TP/PP/`torch.compile`) + - `meta-llama/Llama-3.2-1B` + - `microsoft/phi-2` + - `Qwen/Qwen2.5-7B` + - `mistralai/Mistral-7B-v0.1` + - `ByteDance-Seed/Seed-Coder-8B-Instruct` + - `Qwen/Qwen3-4B-Instruct-2507` + - `arcee-ai/AFM-4.5B` + - `ibm-granite/granite-3b-code-base-2k` + - `baidu/ERNIE-4.5-0.3B-Base-PT` + - `kyutai/helium-1-preview-2b` + - `allenai/OLMo-7B-hf` + - `mistralai/Ministral-8B-Instruct-2410` + - MoE (upcoming) + +## Known issues to address later + +- When using HF modeling, the test `FSDP=2 vs FSDP=2 + PP=2`, the `loss` and `grad_norm` not bitwise matching (but converging) while it is the case with Torchtitan modeling. This will be addressed in another PR but the culprit is probably `register_buffer` when loading `seed_checkpoint` +- the HF modeling has lower MFU than Torchtitan MFU + +## Further work + +- Missing `build_optimizers_with_moe_load_balancing` support for MoE +- Missing TP/PP/EP supports for MoE +- Load HF weights +- Add LORA support diff --git a/torchtitan/experiments/transformers_backend/__init__.py b/torchtitan/experiments/transformers_backend/__init__.py new file mode 100644 index 0000000000..aec28a0bdd --- /dev/null +++ b/torchtitan/experiments/transformers_backend/__init__.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.protocols.train_spec import TrainSpec + +from .infra.parallelize import parallelize_hf_transformers + +from .infra.pipeline import pipeline_hf_transformers +from .model.args import HFTransformerModelArgs, TitanDenseModelArgs +from .model.model import HFTransformerModel + +__all__ = [ + "HFTransformerModelArgs", + "HFTransformerModel", +] + + +flavors = { + "debugmodel": HFTransformerModelArgs( + titan_dense_args=TitanDenseModelArgs( + dim=256, + n_layers=2, + n_heads=16, + n_kv_heads=16, + ), + ), + "full": HFTransformerModelArgs( + titan_dense_args=TitanDenseModelArgs(), + ), +} + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=HFTransformerModel, + model_args=flavors, + parallelize_fn=parallelize_hf_transformers, + pipelining_fn=pipeline_hf_transformers, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) diff --git a/torchtitan/experiments/transformers_backend/configs/debug_model.toml b/torchtitan/experiments/transformers_backend/configs/debug_model.toml new file mode 100644 index 0000000000..7b3de04b87 --- /dev/null +++ b/torchtitan/experiments/transformers_backend/configs/debug_model.toml @@ -0,0 +1,88 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "Qwen 3 debug training" +print_config = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 5 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "transformers_backend" +flavor = "debugmodel" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[hf_transformers] +model = "Qwen/Qwen3-4B-Instruct-2507" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 2 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) +dataset_path = "./tests/assets/c4_test" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/experiments/transformers_backend/configs/full.toml b/torchtitan/experiments/transformers_backend/configs/full.toml new file mode 100644 index 0000000000..45eaa785de --- /dev/null +++ b/torchtitan/experiments/transformers_backend/configs/full.toml @@ -0,0 +1,87 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "Qwen 3 full training" +print_config = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 5 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "transformers_backend" +flavor = "full" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[hf_transformers] +model = "Qwen/Qwen3-4B-Instruct-2507" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 2 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/experiments/transformers_backend/infra/parallelize.py b/torchtitan/experiments/transformers_backend/infra/parallelize.py new file mode 100644 index 0000000000..b2ae3f02a1 --- /dev/null +++ b/torchtitan/experiments/transformers_backend/infra/parallelize.py @@ -0,0 +1,435 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) +from torchtitan.config import TORCH_DTYPE_MAP +from torchtitan.distributed import NoParallel, ParallelDims + +from torchtitan.distributed.activation_checkpoint import apply_ac + +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.experiments.transformers_backend.job_config import JobConfig +from torchtitan.models.llama3.infra.parallelize import apply_compile, apply_ddp +from torchtitan.tools.logging import logger + + +def parallelize_hf_transformers( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + + if parallel_dims.tp_enabled: + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + # For now, float8 all-gather with TP is only supported for tensorwise + # float8 scaling recipes. For rowwise recipes, we use regular TP and + # all-gather happens in high precision. + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + + apply_non_moe_tp( + model, + world_mesh["tp"], + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, + ) + maybe_enable_async_tp(job_config, world_mesh["tp"]) + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + + # turn on per-TransformerBlock compile after AC wrapping and before FSDP + if model_compile_enabled: + apply_compile(model, job_config.compile) + + if parallel_dims.fsdp_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + + apply_fsdp( + model, + world_mesh[tuple(dp_mesh_dim_names)], + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + model.set_cp_mesh(world_mesh["cp"]) + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + apply_ddp( + model, + world_mesh, + enable_compile=model_compile_enabled, + ) + + return model + + +def apply_non_moe_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + + # skipping nn.Identity modules (which are added by pipeline parallelism for unused modules) + root_plan = {} + + if hasattr(model, "tok_embeddings"): + if isinstance(model.tok_embeddings, nn.Identity): + root_plan["tok_embeddings"] = NoParallel() + else: + root_plan["tok_embeddings"] = RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ) + + if hasattr(model, "norm"): + if isinstance(model.norm, nn.Identity): + root_plan["norm"] = NoParallel() + else: + root_plan["norm"] = SequenceParallel() + + if hasattr(model, "output"): + if isinstance(model.output, nn.Identity): + root_plan["output"] = NoParallel() + else: + root_plan["output"] = ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ) + if root_plan: # Only call if there's something to parallelize + parallelize_module(model, tp_mesh, root_plan) + + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears with tensorwise scaling. + if enable_float8_tensorwise_tp: + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + for transformer_block in model.layers: + layer_plan = { + "input_layernorm": SequenceParallel(), + "self_attn": prepare_module_input( + input_kwarg_layouts={"hidden_states": Shard(1)}, + desired_input_kwarg_layouts={"hidden_states": Replicate()}, + ), + "post_attention_layernorm": SequenceParallel(), + } + + if getattr(transformer_block.self_attn, "q_lora_rank", None) is None: + layer_plan.update( + { + "self_attn.q_proj": colwise_parallel(), + "self_attn.k_proj": colwise_parallel(), + "self_attn.v_proj": colwise_parallel(), + } + ) + else: + layer_plan.update( + { + "self_attn.q_a_proj": NoParallel(), + "self_attn.q_a_layernorm": NoParallel(), + "self_attn.q_b_proj": colwise_parallel(), + "self_attn.kv_a_proj_with_mqa": NoParallel(), + "self_attn.kv_a_layernorm": NoParallel(), + "self_attn.kv_b_proj": colwise_parallel(), + } + ) + + # Handle different names for the output projection layer, e.g. o_proj vs dense + o_proj_name = ( + "o_proj" if hasattr(transformer_block.self_attn, "o_proj") else "dense" + ) + layer_plan[f"self_attn.{o_proj_name}"] = rowwise_parallel( + output_layouts=Shard(1) + ) + # For model that uses RMSNorm on Q and K (i.e. Qwen3) + if hasattr(transformer_block.self_attn, "q_norm") and hasattr( + transformer_block.self_attn, "k_norm" + ): + layer_plan["self_attn.q_norm"] = SequenceParallel( + sequence_dim=2, use_local_output=True + ) + layer_plan["self_attn.k_norm"] = SequenceParallel( + sequence_dim=2, use_local_output=True + ) + + if not transformer_block.moe_enabled: + mlp_plan = { + "mlp": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + } + # Handle different names for MLP layers, e.g. gate_proj vs fc1 + gate_proj_name = ( + "gate_proj" if hasattr(transformer_block.mlp, "gate_proj") else "fc1" + ) + mlp_plan[f"mlp.{gate_proj_name}"] = colwise_parallel() + + if hasattr(transformer_block.mlp, "up_proj"): + mlp_plan["mlp.up_proj"] = colwise_parallel() + + down_proj_name = ( + "down_proj" if hasattr(transformer_block.mlp, "down_proj") else "fc2" + ) + mlp_plan[f"mlp.{down_proj_name}"] = rowwise_parallel( + output_layouts=Shard(1) + ) + layer_plan.update(mlp_plan) + + # Some models like Phi-2 don't have post_attention_layernorm + if not hasattr(transformer_block, "post_attention_layernorm"): + layer_plan.pop("post_attention_layernorm") + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}" + "Tensor Parallelism to the model" + ) + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", + ep_degree: int = 1, + dp_mod_ep_mesh: DeviceMesh | None = None, + gradient_divide_factor: int | None = None, +): + """ + Apply data parallelism (via FSDP2) to the model. + + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. + + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + match reshard_after_forward_policy: + case "always": + reshard_after_forward = True + case "never": + reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = not pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) + + if model.tok_embeddings is not None: + fully_shard( + model.tok_embeddings, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + + for transformer_block in model.layers: + # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping + # - the router and the shared experts are sharded together with the TransformerBlock + # - the routed experts are sharded with the remaining dp_mod_ep_mesh + if ( + hasattr(transformer_block, "moe_enabled") + and transformer_block.moe_enabled + and ep_degree > 1 + ): + fsdp_mod_ep_config = fsdp_config.copy() + fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh + moe_block = transformer_block.mlp + # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts). + # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding + # causes inefficiency, so we choose to do FSDP sharding on dim-1. + # Even when EP is not used, we may still want to shard the experts + # on non-0 dim. For now it may not be worth the complexity to support + # shard_placement_fn on the outer TransformerBlock-level FSDP. + _experts_shard_placement_fn = None + assert dp_mod_ep_mesh is not None + if dp_mod_ep_mesh.size() * ep_degree > moe_block.experts.num_experts: + _experts_shard_placement_fn = lambda param: Shard(1) + + fully_shard( + moe_block.experts, + **fsdp_mod_ep_config, + reshard_after_forward=reshard_after_forward, + shard_placement_fn=_experts_shard_placement_fn, + ) + + # NOTE: # Although the FSDP sharding of experts is done on a mesh of + # a different size than other parameters, the gradient division + # factor should be consistent with data. + moe_block.experts.set_gradient_divide_factor( + gradient_divide_factor, + ) + + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + + # As an optimization, do not reshard_after_forward the last layers by default + # since FSDP would prefetch them immediately after the forward pass + if model.norm is not None and model.output is not None: + fully_shard( + [model.norm, model.output], + **fsdp_config, + reshard_after_forward=reshard_after_forward_policy == "always", + ) + + fully_shard(model, **fsdp_config) + + # NOTE: set up explicit prefetching when EP is enabled, as D2H syncs + # in EP could interfere with implicit prefetching in FSDP + if ep_degree == 1: + return + + # forward + transformer_blocks = list(model.layers.values()) + next_transformer_blocks = transformer_blocks[1:] + [None] + + if model.tok_embeddings is not None and model.layers is not None: + model.tok_embeddings.set_modules_to_forward_prefetch([transformer_blocks[0]]) + + for transformer_block, next_transformer_block in zip( + transformer_blocks, next_transformer_blocks + ): + if next_transformer_block is not None: + if next_transformer_block.moe_enabled: + transformer_block.set_modules_to_forward_prefetch( + [next_transformer_block, next_transformer_block.mlp.experts] + ) + else: + transformer_block.set_modules_to_forward_prefetch( + [next_transformer_block] + ) + elif model.norm is not None and model.output is not None: + transformer_block.set_modules_to_forward_prefetch( + [model.norm, model.output] + ) + + # backward + reversed_transformer_blocks = list(reversed(model.layers.values())) + prev_transformer_blocks = reversed_transformer_blocks[1:] + [None] + + if model.norm is not None and model.output is not None and model.layers is not None: + model.output.set_modules_to_backward_prefetch([reversed_transformer_blocks[0]]) + + for transformer_block, prev_transformer_block in zip( + reversed_transformer_blocks, prev_transformer_blocks + ): + if prev_transformer_block is not None: + if prev_transformer_block.moe_enabled: + transformer_block.set_modules_to_backward_prefetch( + [prev_transformer_block, prev_transformer_block.mlp.experts] + ) + else: + transformer_block.set_modules_to_backward_prefetch( + [prev_transformer_block] + ) + elif model.tok_embeddings is not None: + transformer_block.set_modules_to_backward_prefetch([model.tok_embeddings]) diff --git a/torchtitan/experiments/transformers_backend/infra/pipeline.py b/torchtitan/experiments/transformers_backend/infra/pipeline.py new file mode 100644 index 0000000000..04452c5ede --- /dev/null +++ b/torchtitan/experiments/transformers_backend/infra/pipeline.py @@ -0,0 +1,391 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import copy +import math + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, + get_schedule_class, + PipelineScheduleSingle, + ScheduleDualPipeV, + ScheduleZBVZeroBubble, +) + +from torchtitan.components.loss import LossFunction +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.pipeline_parallel import build_pipeline_schedule +from torchtitan.experiments.transformers_backend.job_config import JobConfig +from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction +from torchtitan.tools.logging import logger + +# NOTE(3outeille): the only modifications comes from replacing None to nn.Identity and adding rotary_emb per model_part + + +def generate_llm_fqn_per_model_part( + num_stages: int, + num_layers: int, + input_weight: int = 1, + output_weight: int = 1, +) -> list[list[str]]: + """ + Programmatically generates module names model part, focused on LLMs models. + Args: + num_stages: Number of pipeline stages + num_layers: Total number of transformer layers in the model + input_weight: Weight for input modules (embed_tokens) in layer calculation + output_weight: Weight for output modules (norm + output) in layer calculation + Returns: + List of lists containing module names for each model part + Example: + generate_llm_fqn_per_model_part(2, 3, input_weight=2, output_weight=2) + treats embeddings as 2 layers and norm+output as 2 layers for distribution + """ + if num_stages < 1: + raise ValueError("Number of stages must be at least 1") + + if num_stages == 1: + # Single stage gets everything + layer_names = [f"layers.{i}" for i in range(num_layers)] + return [["tok_embeddings"] + layer_names + ["norm", "output", "rotary_emb"]] + + # Calculate effective layers including weights + num_effective_layers = num_layers + input_weight + output_weight + + if num_stages > num_effective_layers: + raise ValueError( + f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})" + ) + + # Calculate layers per stage (distribute evenly) + layers_per_stage = num_effective_layers // num_stages + extra_layers = num_effective_layers % num_stages + + # Feasibility check: Ensure at least 1 layer in each PP stage + if layers_per_stage == 0: + raise ValueError( + f"Configuration would result in empty stages. " + f"With {num_stages} stages and {num_effective_layers} effective layers " + f"(num_layers={num_layers} + input_weight={input_weight} + output_weight={output_weight}), " + f"each stage would get {layers_per_stage} layers on average. " + f"Reduce num_stages or increase num_layers/weights." + ) + + # Balance check: Ensure weights don't exceed minimum layers per stage + if input_weight > layers_per_stage: + raise ValueError( + f"input_weight ({input_weight}) exceeds minimum layers per stage ({layers_per_stage})." + ) + if output_weight > layers_per_stage: + raise ValueError( + f"output_weight ({output_weight}) exceeds minimum layers per stage ({layers_per_stage})." + ) + + module_names_per_stage = [] + current_layer = 0 + + for stage_idx in range(num_stages): + stage_modules = [] + + # Calculate effective layers for this stage + effective_layers_for_stage = layers_per_stage + if stage_idx < extra_layers: + effective_layers_for_stage += 1 + + # First stage: handle input modules with weighting + if stage_idx == 0: + stage_modules.append("tok_embeddings") + # Account for input weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - input_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Last stage: handle output modules with weighting + elif stage_idx == num_stages - 1: + # Account for output weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - output_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Add output modules + stage_modules.extend(["norm", "output"]) + + # Middle stages: only transformer layers + else: + for _ in range(effective_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + stage_modules.append("rotary_emb") + module_names_per_stage.append(stage_modules) + + return module_names_per_stage + + +def pipeline_module_split( + whole_model: nn.Module, + pp_mesh: DeviceMesh, + pp_schedule: str, + device: torch.device, + module_names_per_stage: list[list[str]], +) -> tuple[list[PipelineStage], list[nn.Module]]: + """ + This API creates pipeline stages based on specified module names for each stage. + + Some model restrictions include: + - forward() method should tolerate deleted layers + - weight initialization methods should tolerate deleted layers + - Does not support nested moduledict and modulelist structures + + Args: + whole_model: The complete model to be split + pp_mesh: Pipeline parallel device mesh + pp_schedule: Name of pipeline parallelism schedule + device: Device + module_names_per_stage: List of lists, where each inner list contains the module names + that should be included in that stage. Module names should be + dot-separated paths. Examples: + - "tok_embeddings" for token embeddings + - "layers.0", "layers.1" for specific transformer layers + - "norm" for the final normalization layer + - "output" for the output projection layer + + Returns: + Tuple of (stages, models) where stages are PipelineStage objects and models are the + corresponding model chunks + + Example usage: + module_names_per_stage = [ + ["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer + ["layers.1", "layers.2"], # Stage 1: middle layers + ["norm", "output"] # Stage 2: final norm + output + ] + """ + pp_rank = pp_mesh.get_local_rank() + pp_degree = pp_mesh.size() + + def _build_stage_from_modules( + stage_idx: int, module_names: list[str], num_stages: int + ) -> tuple[PipelineStage, nn.Module]: + model = copy.deepcopy(whole_model) + + # Create a set of modules to keep for faster lookup + modules_to_keep = set(module_names) + for module_name, module_value in model.named_children(): + # Handle layer-like structures (e.g., "layers.0", "layers.1") + if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): + layers_to_keep = { + name.split(".", 1)[1] + for name in modules_to_keep + if name.startswith(f"{module_name}.") + } + if layers_to_keep: + # Keep only specified layers + if isinstance(module_value, nn.ModuleDict): + for layer_name in list(module_value.keys()): + if layer_name not in layers_to_keep: + del module_value[layer_name] + elif isinstance(module_value, nn.ModuleList): + indices_to_keep = { + int(idx) for idx in layers_to_keep if idx.isdigit() + } + new_layers = nn.ModuleList( + [ + layer + for i, layer in enumerate(module_value) + if i in indices_to_keep + ] + ) + setattr(model, module_name, new_layers) + else: + # No layers from this structure needed, set to empty structure + if isinstance(module_value, nn.ModuleDict): + setattr(model, module_name, nn.ModuleDict()) + elif isinstance(module_value, nn.ModuleList): + setattr(model, module_name, nn.ModuleList()) + # Handle simple module attributes (e.g., "linear", "norm") + elif module_name not in modules_to_keep: + # Replace with Identity + setattr(model, module_name, nn.Identity()) + + stage = PipelineStage( + model, + stage_idx, + num_stages, + device, + group=pp_mesh.get_group("pp"), + ) + return stage, model + + num_stages = len(module_names_per_stage) + stages = [] + models = [] + + schedule_class = get_schedule_class(pp_schedule) + style = ( + "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop" + ) + + def _get_stage_indices() -> tuple[int]: + """ + Compute the stage ids for the stages that will run on this pp rank + for either a looped or V style schedule + """ + assert ( + num_stages % pp_degree == 0 + ), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}" + stages_per_rank = num_stages // pp_degree + if style == "loop": + return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank)) + elif style == "v": + assert ( + stages_per_rank == 2 + ), f"v schedules assume 2 stages per rank, got {stages_per_rank}" + stage_v_pairs = list( + zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1)) + ) + return stage_v_pairs[pp_rank] + + for stage_idx in _get_stage_indices(): + module_names = module_names_per_stage[stage_idx] + stage, model_chunk = _build_stage_from_modules( + stage_idx, + module_names, + num_stages, + ) + logger.info( + f"PP rank {pp_rank} is building stage_idx {stage_idx} " + f"with modules {module_names}" + ) + stages.append(stage) + models.append(model_chunk) + + return stages, models + + +def pipeline_hf_transformers( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: torch.device, + model_args: BaseModelArgs, + parallelize_fn: ParallelizeFunction, + loss_fn: LossFunction, +) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: + pp_mesh = parallel_dims.world_mesh["pp"] + + # Determine the number of virtual stages based on schedule type + schedule_class = get_schedule_class( + job_config.parallelism.pipeline_parallel_schedule + ) + is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) + layers_per_stage = job_config.parallelism.pipeline_parallel_layers_per_stage + if hasattr(model_args, "n_layers"): + num_layers = model_args.n_layers + else: + raise ValueError("Model does not have n_layers attribute.") + + # You can adjust these weights based on the computational cost of embeddings and output layers + # Higher weights mean these modules are treated as "heavier" in the distribution + input_weight = job_config.parallelism.pipeline_parallel_first_stage_less_layers + output_weight = job_config.parallelism.pipeline_parallel_last_stage_less_layers + + # Calculate number of virtual stages + if layers_per_stage is not None: + + # Calculate number of virtual stages needed (using ceiling division) + # This allows for unequal distribution where stages can differ by at most 1 layer + num_virtual_stages = math.ceil( + (num_layers + input_weight + output_weight) / layers_per_stage + ) + + # Validation: check stages per rank based on schedule type + model_config_info = f"Model has {num_layers} layers with pipeline_parallel_layers_per_stage={layers_per_stage}" + stage_distribution_info = ( + f"resulting in {num_virtual_stages=} across {parallel_dims.pp} PP ranks" + ) + + if num_virtual_stages % parallel_dims.pp != 0: + raise ValueError( + f"Number of virtual stages ({num_virtual_stages}) must be divisible by " + f"pipeline parallel size ({parallel_dims.pp}). " + f"{model_config_info}. " + f"Please adjust pipeline_parallel_layers_per_stage to a value that results in a number of stages " + f"divisible by {parallel_dims.pp}." + ) + + stages_per_rank = num_virtual_stages // parallel_dims.pp + + if is_single_stage_schedule and stages_per_rank != 1: + raise ValueError( + f"Single stage schedule requires exactly 1 stage per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please increase pipeline_parallel_layers_per_stage to {num_layers // parallel_dims.pp} or higher " + f"to achieve 1 stage per rank." + ) + + if not is_single_stage_schedule and stages_per_rank < 2: + raise ValueError( + f"Multi-stage schedule requires at least 2 stages per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please decrease pipeline_parallel_layers_per_stage to achieve at least 2 stages per rank." + ) + else: + # Fallback to default behavior when layers_per_stage is not provided + # For multi-stage schedules, default is 2 virtual stages per rank + # For single-stage schedules, default is 1 virtual stage per rank + stages_per_rank = 1 if is_single_stage_schedule else 2 + num_virtual_stages = parallel_dims.pp * stages_per_rank + + module_names_per_stage = job_config.parallelism.module_fqns_per_model_part + if module_names_per_stage is None: + module_names_per_stage = generate_llm_fqn_per_model_part( + num_virtual_stages, num_layers, input_weight, output_weight + ) + + stages, model_parts = pipeline_module_split( + model, + pp_mesh, + job_config.parallelism.pipeline_parallel_schedule, + device, + module_names_per_stage, + ) + + # For PP with looped schedules, each item in model_parts is one stage-model-chunk. + # We need to iterate through model_parts to apply SPMD parallelisms, compilation, + # optimizer, and checkpointing + for i, m in enumerate(model_parts): + # apply SPMD-style PT-D techniques + m = parallelize_fn(m, parallel_dims, job_config) + model_parts[i] = m + # NOTE: this is to update the model in the stage + # in case the model is modified e.g. by torch.compile + stages[i].submod = m + + pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) + + # This is used in the train loop to determine whether to pass in the input_ids and labels + has_first_stage = False + has_last_stage = False + for stage in stages: + if stage.is_first: + has_first_stage = True + if stage.is_last: + has_last_stage = True + + return pp_schedule, model_parts, has_first_stage, has_last_stage diff --git a/torchtitan/experiments/transformers_backend/job_config.py b/torchtitan/experiments/transformers_backend/job_config.py new file mode 100644 index 0000000000..f3b1667798 --- /dev/null +++ b/torchtitan/experiments/transformers_backend/job_config.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +@dataclass +class HFTransformers: + model: str = "" + """HuggingFace model ID (e.g., 'Qwen/Qwen3-4B-Instruct-2507')""" + + +@dataclass +class JobConfig: + hf_transformers: HFTransformers = field(default_factory=HFTransformers) diff --git a/torchtitan/experiments/transformers_backend/model/args.py b/torchtitan/experiments/transformers_backend/model/args.py new file mode 100644 index 0000000000..25ab328f15 --- /dev/null +++ b/torchtitan/experiments/transformers_backend/model/args.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + +from torch import nn +from torchtitan.config.job_config import JobConfig +from torchtitan.models.utils import get_dense_model_nparams_and_flops +from torchtitan.protocols import BaseModelArgs +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.integrations.sdpa_attention import sdpa_attention_forward +from transformers.modeling_utils import AttentionInterface + + +@dataclass +class TitanDenseModelArgs: + """Arguments for the base TorchTitan model.""" + + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: int | None = None + vocab_size: int | None = None + multiple_of: int = 256 + ffn_dim_multiplier: float | None = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + max_seq_len: int = 2048 + depth_init: bool = True + use_flex_attn: bool = False + attn_mask_type: str = "causal" + + +@dataclass +class HFTransformerModelArgs(PretrainedConfig, BaseModelArgs): + """ + Configuration class that bridges TorchTitan and HuggingFace Transformers naming conventions. + + Uses properties to provide TorchTitan-style access while maintaining HuggingFace compatibility. + Properties are created dynamically based on which arguments are provided. + """ + + # Define all possible mappings organized by argument type + _TT_TO_HF_MAPPINGS = { + "dense": { + # TorchTitan dense model mappings (always available) + "dim": "hidden_size", + "n_layers": "num_hidden_layers", + "n_heads": "num_attention_heads", + "n_kv_heads": "num_key_value_heads", + "norm_eps": "rms_norm_eps", + "max_seq_len": "max_position_embeddings", + "eos_id": "eos_token_id", + } + } + + # Declarative list of TorchTitan-only attributes (no HF equivalent) + _TT_SPECIFIC_ATTRIBUTES = [ + "multiple_of", + "ffn_dim_multiplier", + "depth_init", + "use_flex_attn", + "attn_mask_type", + ] + + def __init__( + self, + titan_dense_args, + # HuggingFace specific args + attn_implementation: str = "sdpa_torchtitan", + **kwargs, + ): + super().__init__(attn_implementation=attn_implementation, **kwargs) + assert titan_dense_args is not None, "titan_dense_args is required" + + # Create getter/setter dynamically for TT <-> HF attribute mappings + self._create_getter_setter_dynamically(has_moe=False) + + self._titan_injected_model_args = {} + self._configure_hf_attention(attn_implementation) + + self._initialize_dense_attributes(titan_dense_args) + + def _initialize_dense_attributes(self, titan_dense_args): + """Initialize all dense model attributes.""" + # Set mapped attributes (TorchTitan <-> HuggingFace) + for titan_name, hf_name in self._tt_to_hf_attribute_map.items(): + if hasattr(titan_dense_args, titan_name): + value = getattr(titan_dense_args, titan_name) + setattr(self, hf_name, value) + + # Set TorchTitan-only attributes + for attr_name in self._TT_SPECIFIC_ATTRIBUTES: + if hasattr(titan_dense_args, attr_name): + setattr(self, attr_name, getattr(titan_dense_args, attr_name)) + + # Update passed_args + self._titan_injected_model_args.update(titan_dense_args.__dict__) + + def _configure_hf_attention(self, attn_implementation: str): + """Configure HuggingFace attention settings.""" + self._titan_injected_model_args["attn_implementation"] = attn_implementation + self.attn_implementation = attn_implementation + # NOTE:(3outeille):This will force create_causal_mask to return None + AttentionInterface._global_mapping[attn_implementation] = sdpa_attention_forward + + def _create_getter_setter_dynamically(self, has_moe: bool): + """ + Create properties dynamically based on tt and hf attribute mappings. + For example, creates a property 'dim' that reads/writes to 'hidden_size'. + """ + + def _create_property(hf_name: str) -> property: + def getter(self): + return getattr(self, hf_name) + + def setter(self, value): + setattr(self, hf_name, value) + + return property(getter, setter) + + # Setup attribute mappings + self._tt_to_hf_attribute_map = dict(self._TT_TO_HF_MAPPINGS["dense"]) + if has_moe: + self._tt_to_hf_attribute_map.update(self._TT_TO_HF_MAPPINGS["moe"]) + + for titan_name, hf_name in self._tt_to_hf_attribute_map.items(): + # Create getter/setter for attribute that don't already exist + if not hasattr(self.__class__, titan_name): + setattr(self.__class__, titan_name, _create_property(hf_name)) + + def __repr__(self) -> str: + # HFTransformerModelArgs is a dataclass that also inherits from PretrainedConfig. + # PretrainedConfig has a __repr__ that serializes the object to JSON, but it + # doesn't work well with how HFTransformerModelArgs is initialized. + # This custom __repr__ provides a dataclass-like representation that correctly + # displays the arguments passed during initialization. + args_lines = [ + f"{k}={getattr(self, k)!r}" + for k in sorted(self._titan_injected_model_args.keys()) + if hasattr(self, k) + ] + args_str = "\n".join(args_lines) + return f"{self.__class__.__name__}(\n{args_str}\n)" + + def update_from_config(self, job_config: JobConfig): + # Load HF config (overwrites our HF attributes) + hf_model_config = AutoConfig.from_pretrained( + job_config.hf_transformers.model, + attn_implementation=self.attn_implementation, + trust_remote_code=True, + ) + + # Explicitly update attributes based on mappings + for titan_name, hf_name in self._tt_to_hf_attribute_map.items(): + if hasattr(hf_model_config, hf_name): + setattr(self, titan_name, getattr(hf_model_config, hf_name)) + + # Copy any other attributes that might not be in the mapping + for key, value in hf_model_config.to_dict().items(): + setattr(self, key, value) + + # Update our attributes with the passed args from flavors + for key, value in self._titan_injected_model_args.items(): + if hasattr(self, key) and value is not None: + setattr(self, key, value) + + self.max_seq_len = job_config.training.seq_len + + self.deterministic = job_config.debug.deterministic + + # Configure HF-specific settings to match TorchTitan settings + # TODO: false ? + self.attention_bias = False + self.mlp_bias = False + self.use_cache = False + self.initializer_range = 1.0 # use as std for normal init in embedding + + if not hasattr(self, "inter_dim"): # Only for llama model + ffn_hidden_size = 4 * self.dim + ffn_hidden_size = int(2 * ffn_hidden_size / 3) + if self.ffn_dim_multiplier is not None: + ffn_hidden_size = int(self.ffn_dim_multiplier * ffn_hidden_size) + self.intermediate_size = self.multiple_of * ( + (ffn_hidden_size + self.multiple_of - 1) // self.multiple_of + ) + + self.head_dim = self.dim // self.num_attention_heads + + return self + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + return get_dense_model_nparams_and_flops( + self, model, head_dims=self.head_dim, seq_len=seq_len + ) diff --git a/torchtitan/experiments/transformers_backend/model/model.py b/torchtitan/experiments/transformers_backend/model/model.py new file mode 100644 index 0000000000..b88fffc54b --- /dev/null +++ b/torchtitan/experiments/transformers_backend/model/model.py @@ -0,0 +1,477 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import math + +import torch +from torch import nn +from torch.nn import init +from torchtitan.tools.logging import logger +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PreTrainedModel + +from .args import HFTransformerModelArgs + + +class SliceableModuleDict(nn.ModuleDict): + """ + A ModuleDict that supports slicing like ModuleList. + Keys are expected to be string representations of integers (e.g., "0", "1", "2"). + """ + + def __getitem__(self, key): + if isinstance(key, slice): + # Handle slicing: convert slice to list of keys + keys = sorted( + self.keys(), key=lambda x: int(x) if x.isdigit() else float("inf") + ) + sliced_keys = keys[key] + # Return a new SliceableModuleDict with the sliced items + return SliceableModuleDict({k: self[k] for k in sliced_keys}) + return super().__getitem__(key) + + def __iter__(self): + # Iterate over values in sorted order by key (as integers) + keys = sorted( + self.keys(), key=lambda x: int(x) if x.isdigit() else float("inf") + ) + for key in keys: + yield self[key] + + def __len__(self): + return len(self._modules) + + +class HFTransformerModel(nn.Module): + def __init__(self, model_args: HFTransformerModelArgs): + super().__init__() + + # NOTE(3outeille): This prevents Hugging Face modeling from initializing ROPE (inv_freq) buffers to NaN. + # Needed when loading from seed checkpoint. + if hasattr(model_args, "deterministic") and model_args.deterministic: + torch.utils.deterministic.fill_uninitialized_memory = False + + # Try to import the model class dynamically from the transformers library if not found in globals + model_class_name = model_args.architectures[0] + model_cls = globals().get(model_class_name, None) + if model_cls is None: + try: + transformers_mod = importlib.import_module("transformers") + model_cls = getattr(transformers_mod, model_class_name) + except (ImportError, AttributeError) as e: + raise ImportError( + f"Could not find model class '{model_class_name}' in globals or transformers. " + f"Make sure the class is available. Original error: {e}" + ) from e + + # Attempt to patch model weight initialization based on architecture type + try: + model_name_prefix = model_class_name.replace("ForCausalLM", "") + model_module = importlib.import_module(model_cls.__module__) + + attention_cls = getattr(model_module, f"{model_name_prefix}Attention", None) + mlp_cls = getattr(model_module, f"{model_name_prefix}MLP", None) + decoder_layer_cls = getattr( + model_module, f"{model_name_prefix}DecoderLayer", None + ) + + required_classes = { + "Attention": attention_cls, + "DecoderLayer": decoder_layer_cls, + } + + if all(required_classes.values()): + logger.info(f"Applying Llama-like patch for {model_name_prefix}") + self._patch_hf_llama_like( + decoder_layer_cls=decoder_layer_cls, + attention_cls=attention_cls, + mlp_cls=mlp_cls, # mlp_cls can be None + ) + else: + missing = [name for name, cls in required_classes.items() if not cls] + logger.warning( + f"Could not find required classes ({', '.join(missing)}) for {model_name_prefix}. " + "Skipping Llama-like patch." + ) + + except Exception as e: + logger.warning( + f"Failed to apply agnostic patch for {model_class_name} due to: {e}. " + "Weight initialization might not match TorchTitan." + ) + + self.model = model_cls(config=model_args) + self.max_seq_len = model_args.max_seq_len + self.cp_mesh = None + + # Convert ModuleList to ModuleDict to preserve original indices + # This ensures state dict keys match checkpoint keys + if isinstance(self.model.model.layers, nn.ModuleList): + self.model.model.layers = SliceableModuleDict( + {str(i): layer for i, layer in enumerate(self.model.model.layers)} + ) + + for layer in self.model.model.layers.values(): + layer.moe_enabled = False + + def set_cp_mesh(self, mesh): + self.cp_mesh = mesh + + def _patch_hf_llama_like(self, decoder_layer_cls, attention_cls, mlp_cls=None): + """ + This patch modifies a Hugging Face Llama-like model's weight initialization to match + the initialization scheme used in TorchTitan. This is crucial for ensuring + bit-for-bit reproducibility when converting checkpoints between the native + TorchTitan format and the Hugging Face format. + + The patch targets the following aspects of the model: + - `PreTrainedModel._initialize_weights`: Handles meta device initialization correctly. + - `PreTrainedModel._init_weights`: Implements TorchTitan's specific initialization + for attention, MLP, embedding, and layer norm layers. This includes depth-dependent + initialization for attention and MLP layers. + - `DecoderLayer.__init__`: Adds `layer_idx` to attention and MLP modules within + each decoder layer, which is required for the depth-dependent initialization. + """ + + _original_decoder_layer_init = decoder_layer_cls.__init__ + + def _decoder_layer_init_patched(self, config: PretrainedConfig, layer_idx: int): + _original_decoder_layer_init(self, config, layer_idx) + self.layer_idx = layer_idx + # Ensure both attention and mlp modules have layer_idx for depth-based init + if hasattr(self, "self_attn"): + self.self_attn.layer_idx = layer_idx + # some models might not have mlp in each layer + if hasattr(self, "mlp") and self.mlp is not None: + self.mlp.layer_idx = layer_idx + + def _initialize_weights_patched(self, module): + # NOTE(3outeille): monkey-patch PreTrainedModel to handle meta device initialization correctly + # The default _initialize_weights sets _is_hf_initialized = True even on a meta device, + # which prevents subsequent proper initialization. + if getattr(module, "_is_hf_initialized", False): + return + + for param in module.parameters(recurse=True): + if param.device.type == "meta": + return + + # If not on a meta device, call the original weight initialization + self._init_weights(module) + module._is_hf_initialized = True + + def _init_weights_patched(self, module): + """ + Patched version of _init_weights to match TorchTitan's initialization for Llama-like models. + `self` is a PreTrainedModel instance. + """ + config = self.config + # Build tuple of classes to check for layer_idx-based init_std calculation + layer_idx_classes = [attention_cls] + if mlp_cls: + layer_idx_classes.append(mlp_cls) + layer_idx_classes = tuple(layer_idx_classes) + + if isinstance(module, layer_idx_classes): + if not hasattr(module, "layer_idx"): + raise ValueError( + f"Module {module} does not have a layer_idx attribute" + ) + + layer_idx = module.layer_idx + + if hasattr(config, "depth_init") and config.depth_init: + init_std = 0.02 / (2 * (layer_idx + 1)) ** 0.5 + else: + init_std = 0.02 / (2 * config.num_hidden_layers) ** 0.5 + + if isinstance(module, attention_cls): + # Initialize weights and biases for q, k, v projections + for proj_name in ["q_proj", "k_proj", "v_proj"]: + proj = getattr(module, proj_name) + nn.init.trunc_normal_(proj.weight, mean=0.0, std=0.02) + if proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(proj.bias, -bound, bound) + + # Handle different names for the output projection layer + o_proj = getattr(module, "o_proj", getattr(module, "dense", None)) + if o_proj is not None: + nn.init.trunc_normal_(o_proj.weight, mean=0.0, std=init_std) + if o_proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(o_proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(o_proj.bias, -bound, bound) + + elif mlp_cls and isinstance(module, mlp_cls): + # Handle different names for MLP layers + gate_proj = getattr(module, "gate_proj", getattr(module, "fc1", None)) + up_proj = getattr(module, "up_proj", None) + down_proj = getattr(module, "down_proj", getattr(module, "fc2", None)) + + # gate_proj (or fc1) should always use std=0.02 for numerical stability. + if gate_proj is not None: + nn.init.trunc_normal_(gate_proj.weight, mean=0.0, std=0.02) + if gate_proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(gate_proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(gate_proj.bias, -bound, bound) + # up_proj and down_proj (or fc2) use the depth-dependent init_std. + if up_proj is not None: + nn.init.trunc_normal_(up_proj.weight, mean=0.0, std=init_std) + if up_proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(up_proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(up_proj.bias, -bound, bound) + if down_proj is not None: + nn.init.trunc_normal_(down_proj.weight, mean=0.0, std=init_std) + if down_proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(down_proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(down_proj.bias, -bound, bound) + + elif module is getattr( + self, "lm_head", None + ): # TODO(3outeille): find a better way to detect lm_head + final_out_std = config.hidden_size**-0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + if module.bias is not None: + module.bias.data.zero_() + + elif isinstance(module, nn.Embedding): + # When tie_word_embeddings is True, use lm_head initialization + if ( + hasattr(config, "tie_word_embeddings") + and config.tie_word_embeddings + ): + final_out_std = config.hidden_size**-0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + else: + std = config.initializer_range + module.weight.data.normal_(mean=0.0, std=std) + + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + elif ( + isinstance( + module, + (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d), + ) + or "LayerNorm" in module.__class__.__name__ + or "RMSNorm" in module.__class__.__name__ + ): + # Norms can exist without weights (in which case they are None from torch primitives) + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(1.0) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() + + decoder_layer_cls.__init__ = _decoder_layer_init_patched + PreTrainedModel._init_weights = _init_weights_patched + PreTrainedModel._initialize_weights = _initialize_weights_patched + + @property + def tok_embeddings(self): + """Returns the model's embed_tokens, handling different Hugging Face model structures.""" + if hasattr(self.model, "model") and hasattr( + self.model.model, "embed_tokens" + ): # Llama-like + return self.model.model.embed_tokens + else: + raise AttributeError( + "Could not find embed_tokens in the model. Please check the model structure." + ) + + @tok_embeddings.setter + def tok_embeddings(self, value): + if hasattr(self.model, "model") and hasattr( + self.model.model, "embed_tokens" + ): # Llama-like + self.model.model.embed_tokens = value + else: + raise AttributeError( + "Could not find embed_tokens in the model. Please check the model structure." + ) + + @property + def layers(self): + """Returns the model's layers, handling different Hugging Face model structures.""" + if hasattr(self.model, "model") and hasattr( + self.model.model, "layers" + ): # Llama-like + return self.model.model.layers + else: + # Add more cases here if needed for other model architectures + raise AttributeError( + "Could not find layers in the model. Please check the model structure." + ) + + @layers.setter + def layers(self, value): + if hasattr(self.model, "model") and hasattr( + self.model.model, "layers" + ): # Llama-like + self.model.model.layers = value + else: + raise AttributeError( + "Could not find layers in the model. Please check the model structure." + ) + + @property + def norm(self): + """Returns the model's norm, handling different Hugging Face model structures.""" + if hasattr(self.model, "model") and hasattr( + self.model.model, "norm" + ): # Llama-like + return self.model.model.norm + elif hasattr(self.model, "model") and hasattr( + self.model.model, "final_layernorm" + ): # Phi-like + return self.model.model.final_layernorm + else: + raise AttributeError( + "Could not find norm in the model. Please check the model structure." + ) + + @norm.setter + def norm(self, value): + if hasattr(self.model, "model") and hasattr( + self.model.model, "norm" + ): # Llama-like + self.model.model.norm = value + elif hasattr(self.model, "model") and hasattr( + self.model.model, "final_layernorm" + ): # Phi-like + self.model.model.final_layernorm = value + else: + raise AttributeError( + "Could not find norm in the model. Please check the model structure." + ) + + @property + def output(self): + """Returns the model's output layer, handling different Hugging Face model structures.""" + if hasattr(self.model, "lm_head"): # For models like LlamaForCausalLM + return self.model.lm_head + else: + # Add more cases here if needed for other model architectures + raise AttributeError( + "Could not find output (lm_head) in the model. Please check the model structure." + ) + + @output.setter + def output(self, value): + if hasattr(self.model, "lm_head"): # For models like LlamaForCausalLM + self.model.lm_head = value + else: + raise AttributeError( + "Could not find output (lm_head) in the model. Please check the model structure." + ) + + @property + def rotary_emb(self): + """Returns the model's rotary_emb, handling different Hugging Face model structures.""" + if hasattr(self.model, "model") and hasattr( + self.model.model, "rotary_emb" + ): # Llama-like + return self.model.model.rotary_emb + else: + raise AttributeError( + "Could not find rotary_emb in the model. Please check the model structure." + ) + + @rotary_emb.setter + def rotary_emb(self, value): + if hasattr(self.model, "model") and hasattr( + self.model.model, "rotary_emb" + ): # Llama-like + self.model.model.rotary_emb = value + else: + raise AttributeError( + "Could not find rotary_emb in the model. Please check the model structure." + ) + + def forward(self, *args, **kwargs): + local_seq_len = self.max_seq_len + local_seq_len //= ( + self.cp_mesh.size() + if self.cp_mesh is not None and self.cp_mesh.size() > 1 + else 1 + ) + kwargs["position_ids"] = torch.arange( + local_seq_len, device=args[0].device + ).unsqueeze(0) + output = self.model.model(*args, **kwargs) + output = self.model.lm_head(output.last_hidden_state) + return output + + def init_weights(self, *args, **kwargs): + # This method replicates the behavior of the original PreTrainedModel.init_weights, + # but with a custom weight initialization function that skips nn.Identity modules (when PP is enabled) + + if self.model.config.pruned_heads: + logger.info("Pruning heads as per model configuration.") + self.model.prune_heads(self.model.config.pruned_heads) + + original_init_weights_fn = self.model._init_weights + + def selective_init(module): + # For pipeline parallel, we need to skip nn.Identity modules + if not isinstance(module, nn.Identity): + original_init_weights_fn(module) + else: + logger.info("Skipping nn.Identity module during weight initialization.") + + self.model.apply(selective_init) + + # TODO(3outeille): For pipeline parallel, only tie weights if both input and output embeddings are on the same device + # Maybe better way of handling this? + if not isinstance(self.tok_embeddings, nn.Identity) and not isinstance( + self.output, nn.Identity + ): + self.model.tie_weights() + + def named_children(self): + """ + Provides a flattened view of the model's main components, + making it compatible with TorchTitan's expectations. + """ + yield "tok_embeddings", self.tok_embeddings + yield "layers", self.layers + yield "norm", self.norm + yield "output", self.output + yield "rotary_emb", self.rotary_emb + + def __setattr__(self, name, value): + # If a property with a setter exists for this name, use it. + # This is to bypass the nn.Module.__setattr__ logic that + # directly registers modules and skips property setters. + cls = self.__class__ + if hasattr(cls, name): + prop = getattr(cls, name) + if isinstance(prop, property) and prop.fset is not None: + prop.fset(self, value) + return + + # Otherwise, fall back to the default nn.Module behavior. + super().__setattr__(name, value) diff --git a/torchtitan/experiments/transformers_backend/tests/integration_tests.py b/torchtitan/experiments/transformers_backend/tests/integration_tests.py new file mode 100644 index 0000000000..35d09d6a94 --- /dev/null +++ b/torchtitan/experiments/transformers_backend/tests/integration_tests.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +from tests.integration_tests import OverrideDefinitions +from tests.integration_tests.run_tests import run_tests + + +def build_transformers_backend_test_list() -> list[OverrideDefinitions]: + """ + key is the config file name and value is a list of OverrideDefinitions + that is used to generate variations of integration tests based on the + same root config file. + """ + integration_tests_flavors = [ + OverrideDefinitions( + [ + [ + "--model.name transformers_backend", + "--job.custom_config_module=torchtitan.experiments.transformers_backend.job_config", + "--hf_transformers.model Qwen/Qwen2.5-7B", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.pipeline_parallel_schedule 1F1B", + ], + ], + "Transformers Backend FSDP+TP+PP", + "transformers_backend_fsdp+tp+pp", + ngpu=8, + ), + ] + return integration_tests_flavors + + +_TEST_SUITES_FUNCTION = { + "transformers_backend": build_transformers_backend_test_list, +} + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("output_dir") + parser.add_argument( + "--config_path", + default="./tests/integration_tests/base_config.toml", + help="Base config path for integration tests. This is the config that will be used as a base for all tests.", + ) + parser.add_argument( + "--test_name", + default="all", + help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)", + ) + parser.add_argument("--ngpu", default=8, type=int) + args = parser.parse_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + if os.listdir(args.output_dir): + raise RuntimeError("Please provide an empty output directory.") + + test_list = _TEST_SUITES_FUNCTION["transformers_backend"]() + run_tests(args, test_list) + + +if __name__ == "__main__": + main() diff --git a/torchtitan/train.py b/torchtitan/train.py index 5cfab998b2..d157a3a307 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -474,11 +474,17 @@ def forward_backward_step( ) # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage + cp_buffers = [inputs, labels] + cp_seq_dims = [1, 1] + if hasattr(model_parts[0], "freqs_cis"): + cp_buffers += [m.freqs_cis for m in model_parts] + cp_seq_dims += [0 for _ in model_parts] + optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( cp_mesh=parallel_dims.world_mesh["cp"], - cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], - cp_seq_dims=[1, 1] + [0 for _ in model_parts], + cp_buffers=cp_buffers, + cp_seq_dims=cp_seq_dims, cp_no_restore_buffers={inputs, labels}, cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, ) From f8fa21e1d6fcd9d3d0cf3d490c265a9a490340ae Mon Sep 17 00:00:00 2001 From: liangel-02 Date: Fri, 21 Nov 2025 17:46:45 -0500 Subject: [PATCH 030/127] adding variable length attention to llama3 8b (#2000) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Summary** This PR adds variable length attention (varlen) support to the Llama 3 8b model in torchtitan. We replace `use_flex_attn` with `attn_type` (either "sdpa", "varlen", "flex"). If `attn_type = "varlen"`, the attention module calls a compiled `varlen_attn` defined [here](https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/varlen.py). **Testing** Ran loss and performance tests against flex attention. Loss is on par. Screenshot 2025-11-19 at 3 24 26 PM Varlen is slightly slower than Flex due to the cuda kernel speeds (varlen calls into `flash_attention_forward`/`flash_attention_backward` today). | | Varlen | Flex | | :---: | :------ | :---: | | Forward | 774us 357ns | 722us 317ns | | Backward | 1ms 955us 916ns | 1ms 558us 747ns | --- tests/integration_tests/features.py | 12 ++ torchtitan/experiments/forge/example_train.py | 2 +- .../experiments/gpt_oss/infra/parallelize.py | 6 +- torchtitan/experiments/gpt_oss/model/args.py | 4 +- .../simple_fsdp/deepseek_v3/parallelize.py | 6 +- .../simple_fsdp/llama3/parallelize.py | 3 +- .../experiments/vlm/infra/parallelize.py | 7 +- torchtitan/experiments/vlm/model/args.py | 2 +- torchtitan/models/attention.py | 111 +++++++++++++++++- torchtitan/models/deepseek_v3/__init__.py | 8 +- .../models/deepseek_v3/infra/parallelize.py | 26 ++-- torchtitan/models/deepseek_v3/model/args.py | 10 +- torchtitan/models/deepseek_v3/model/model.py | 28 +++-- torchtitan/models/llama3/__init__.py | 33 +++++- torchtitan/models/llama3/infra/parallelize.py | 6 +- torchtitan/models/llama3/model/args.py | 8 +- torchtitan/models/llama3/model/model.py | 74 +++++++++--- torchtitan/models/llama4/__init__.py | 6 +- torchtitan/models/llama4/infra/parallelize.py | 6 +- torchtitan/models/llama4/model/args.py | 11 +- torchtitan/models/llama4/model/model.py | 15 +-- torchtitan/models/qwen3/infra/parallelize.py | 8 +- torchtitan/models/qwen3/model/args.py | 2 +- torchtitan/models/qwen3/model/model.py | 24 ++-- torchtitan/protocols/model.py | 3 +- torchtitan/train.py | 3 +- 26 files changed, 304 insertions(+), 120 deletions(-) diff --git a/tests/integration_tests/features.py b/tests/integration_tests/features.py index 8bf3a0249f..6fafa29871 100755 --- a/tests/integration_tests/features.py +++ b/tests/integration_tests/features.py @@ -346,6 +346,18 @@ def build_features_test_list() -> list[OverrideDefinitions]: "fsdp+flex_attn+per_op_sac", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--parallelism.data_parallel_shard_degree=4", + "--activation_checkpoint.mode='full'", + "--model.flavor=debugmodel_varlen_attn", + ] + ], + "FSDP+VARLEN_ATTN", + "fsdp+varlen_attn", + ngpu=4, + ), OverrideDefinitions( [ [ diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index 7b0b0c81e9..66ad151dd0 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -161,7 +161,7 @@ def forward_backward_step( inputs = input_dict["input"] extra_kwargs = {} - if getattr(self.model_args, "use_flex_attn", False): + if getattr(self.model_args, "attn_type", "sdpa") == "flex": extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 7714d497e4..1070f58aad 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -62,10 +62,6 @@ def parallelize_gptoss( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") - if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel @@ -111,6 +107,8 @@ def parallelize_gptoss( job_config.compile.enable and "model" in job_config.compile.components ) + attn_type = getattr(model.model_args, "attn_type", "sdpa") + use_flex_attn = attn_type == "flex" if job_config.activation_checkpoint.mode != "none": apply_ac( model, diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/experiments/gpt_oss/model/args.py index e78eac4d74..af4c51eadc 100644 --- a/torchtitan/experiments/gpt_oss/model/args.py +++ b/torchtitan/experiments/gpt_oss/model/args.py @@ -39,7 +39,7 @@ class GptOssModelArgs(BaseModelArgs): n_kv_heads (int): Number of key-value heads. sliding_window_size (int): Size of the sliding attention window. attn_mask_type (str): Type of basic attention mask. - use_flex_attn (bool): Whether to use FlexAttention. Only supports True. + attn_type (bool): Attention type, only supports Flex. original_seq_len (int): Original sequence length. rope_theta (float): Base for rotary positional encoding. rope_factor (float): Scaling factor for extended sequence lengths. @@ -64,7 +64,7 @@ class GptOssModelArgs(BaseModelArgs): n_kv_heads: int = 8 sliding_window_size: int = 128 attn_mask_type: str = "causal" - use_flex_attn: bool = True # NOTE: gpt-oss only support FlexAttention + attn_type: str = "flex" # NOTE: gpt-oss only support FlexAttention # yarn original_seq_len: int = 4096 rope_theta: float = 150000.0 diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 6d415004cc..83e24d7dc1 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -67,9 +67,9 @@ def parallelize_deepseekv3( if ( job_config.parallelism.context_parallel_degree > 1 - and model.model_args.use_flex_attn + and model.model_args.attn_type != "sdpa" ): - raise NotImplementedError("CP support for FlexAttention is still in progress.") + raise NotImplementedError("CP support is only supported for SDPA.") if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters @@ -85,13 +85,11 @@ def parallelize_deepseekv3( "Currently, float8 tensorwise TP is not tested for deepseekv3" ) - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) apply_non_moe_tp( model, world_mesh["tp"], loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, - use_flex_attn=use_flex_attn, ) maybe_enable_async_tp(job_config, world_mesh["tp"]) diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index 67a012a3f7..fb07ef617a 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -102,7 +102,8 @@ def parallelize_llama( maybe_enable_async_tp(job_config, tp_mesh) if job_config.activation_checkpoint.mode != "none": - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + attn_type = getattr(model.model_args, "attn_type", "sdpa") + use_flex_attn = attn_type == "flex" model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py index 6a97e4ece1..d418ad6edd 100644 --- a/torchtitan/experiments/vlm/infra/parallelize.py +++ b/torchtitan/experiments/vlm/infra/parallelize.py @@ -48,9 +48,9 @@ def parallelize_vlm( Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") + attn_type = getattr(model.model_args, "attn_type", "sdpa") + if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + raise NotImplementedError("CP support is only supported for SDPA.") if parallel_dims.tp_enabled: raise NotImplementedError("TP support for VLM training is still in progress.") @@ -58,6 +58,7 @@ def parallelize_vlm( model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) + use_flex_attn = attn_type == "flex" if job_config.activation_checkpoint.mode != "none": apply_ac( model, diff --git a/torchtitan/experiments/vlm/model/args.py b/torchtitan/experiments/vlm/model/args.py index 11b6439ddd..49ba31246b 100644 --- a/torchtitan/experiments/vlm/model/args.py +++ b/torchtitan/experiments/vlm/model/args.py @@ -53,7 +53,7 @@ class Siglip2ModelArgs: spatial_merge_size: int = 1 layer_norm_eps: float = 1e-6 - use_flex_attn: bool = True + attn_type: str = "flex" attn_mask_type: str = "causal" diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 85115fef2b..cc7b87cb20 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -8,7 +8,7 @@ import functools from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, NamedTuple import torch import torch.nn.functional as F @@ -20,10 +20,14 @@ flex_attention, ) +from torch.nn.attention.varlen import varlen_attn + __all__ = [ "FlexAttentionWrapper", "ScaledDotProductAttentionWrapper", + "VarlenAttentionWrapper", + "VarlenMetadata", "get_causal_mask_mod", "get_document_mask_mod", "get_sliding_window_mask_mod", @@ -32,6 +36,53 @@ ] +class VarlenMetadata(NamedTuple): + """ + Cumulative sequence positions for queries and keys/values. + + """ + + cu_seq_q: torch.Tensor + cu_seq_k: torch.Tensor + max_q: int + max_k: int + + +class VarlenAttentionWrapper(torch.nn.Module): + _compiled_varlen_attn: ClassVar[Callable] = torch.compile( + varlen_attn, mode="max-autotune-no-cudagraphs" + ) + + def forward( + self, + xq: torch.Tensor, + xk: torch.Tensor, + xv: torch.Tensor, + head_dim: torch.Tensor, + attention_masks: VarlenMetadata, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + cu_seq_q = attention_masks.cu_seq_q + cu_seq_k = attention_masks.cu_seq_k + max_q = attention_masks.max_q + max_k = attention_masks.max_k + + n_local_heads = xq.shape[1] + xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + + return VarlenAttentionWrapper._compiled_varlen_attn( + xq_packed, + xk_packed, + xv_packed, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + is_causal=True, + ) + + class FlexAttentionWrapper(torch.nn.Module): """Wrapper around `flex_attention` to make it torch.compile and CP compatible. @@ -66,7 +117,6 @@ def forward( # `FlexAttentionWrapper._compiled_flex_attn` is correct. # 3. Used `return_lse` instead of `return_aux` because of easier TP module notation # to convert `lse` to be DTensor. - return FlexAttentionWrapper._compiled_flex_attn( q, k, @@ -226,3 +276,60 @@ def create_attention_mask(*args, **kwargs): arguments. """ return _compiled_create_block_mask(*args, **kwargs) + + +def create_varlen_metadata_for_document( + input_batch: torch.Tensor, eos_id: int +) -> VarlenMetadata: + """ + Creates cumulative sequence length indices needed for variable length attention + + Args: + input_batch + eos_id: the EOS id marker + + Returns: + VarlenMetadata containing cumulative sequence length indices for q, k, and max_seq_len + """ + batch_size, seq_len = input_batch.shape + device = input_batch.device + cu_seqlens_list, all_seq_lengths = [], [] + offset = 0 + max_seqlen = 0 + + for b in range(batch_size): + tokens = input_batch[b] + eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32) + sample_cu_seqlens = torch.cat( + [ + torch.tensor([0], dtype=torch.int32, device=device), + eos_positions + 1, + torch.tensor([seq_len], dtype=torch.int32, device=device), + ] + ) + sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens) + + seq_lengths = torch.diff(sample_cu_seqlens) + all_seq_lengths.append(seq_lengths) + + cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset + cu_seqlens_list.append(cu_seqlens_adjusted) + + offset += seq_len + + packed_cu_seqlens = torch.cat( + cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)] + ) + + max_seqlen = 0 + if len(all_seq_lengths) > 0: + all_seq_lengths = torch.cat(all_seq_lengths) + # device to host sync but only done once per model forward + max_seqlen = all_seq_lengths.max().item() + + return VarlenMetadata( + cu_seq_q=packed_cu_seqlens, + cu_seq_k=packed_cu_seqlens, + max_q=max_seqlen, + max_k=max_seqlen, + ) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 525bd96c13..7e2d35a5d9 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -72,7 +72,7 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "16B": DeepSeekV3ModelArgs( @@ -97,7 +97,7 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "236B": DeepSeekV3ModelArgs( @@ -124,7 +124,7 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "671B": DeepSeekV3ModelArgs( @@ -151,7 +151,7 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), } diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 0793820ffd..a7e1ee0dc5 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -61,9 +61,10 @@ def parallelize_deepseekv3( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") + attn_type = getattr(model.model_args, "attn_type", "sdpa") + use_flex_attn = attn_type == "flex" + if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + raise NotImplementedError("CP support is only supported for SDPA.") if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters @@ -84,7 +85,6 @@ def parallelize_deepseekv3( world_mesh["tp"], loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, - use_flex_attn=use_flex_attn, ) maybe_enable_async_tp(job_config, world_mesh["tp"]) @@ -181,7 +181,6 @@ def apply_non_moe_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, - use_flex_attn: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -211,18 +210,11 @@ def apply_non_moe_tp( PrepareModuleInput, ) - if use_flex_attn: - attention_kernel_plan = prepare_module_input( - input_layouts=(Shard(1), Shard(1), Shard(1)), - desired_input_layouts=(Shard(1), Shard(1), Shard(1)), - use_local_output=True, - ) - else: - attention_kernel_plan = prepare_module_input( - input_layouts=(Shard(1), Shard(1), Shard(1)), - desired_input_layouts=(Shard(1), Shard(1), Shard(1)), - use_local_output=True, - ) + attention_kernel_plan = prepare_module_input( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(1), Shard(1), Shard(1)), + use_local_output=True, + ) # Apply tensor + sequence parallelism to every transformer block # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 48d4b5ece1..e683905878 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -44,7 +44,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs): qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. v_head_dim (int): Dimension for value projections. - use_flex_attn (bool): Whether to use FlexAttention. + attn_type (str): Attention type. attn_mask_type (str): Type of attention mask. original_seq_len (int): Original sequence length. rope_theta (float): Base for rotary positional encoding. @@ -76,7 +76,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs): qk_nope_head_dim: int = 128 qk_rope_head_dim: int = 64 v_head_dim: int = 128 - use_flex_attn: bool = False + attn_type: str = "sdpa" attn_mask_type: str = "causal" # yarn @@ -101,10 +101,8 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args.use_grouped_mm = False - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: - raise NotImplementedError( - "CP support for FlexAttention is still in progress." - ) + if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + raise NotImplementedError("CP support is only supported for SDPA.") self.moe_args._debug_force_load_balance = ( job_config.debug.moe_force_load_balance diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 3cf56eb1b2..7d7635a4ad 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -184,11 +184,12 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): mscale = 0.1 * model_args.mscale * math.log(model_args.rope_factor) + 1.0 self.softmax_scale = self.softmax_scale * mscale * mscale - self.use_flex_attn = model_args.use_flex_attn - if self.use_flex_attn: - self.inner_attention = FlexAttentionWrapper() - else: - self.inner_attention = ScaledDotProductAttentionWrapper() + self.attn_type = model_args.attn_type + match self.attn_type: + case "flex": + self.inner_attention = FlexAttentionWrapper() + case _: + self.inner_attention = ScaledDotProductAttentionWrapper() def forward( self, @@ -245,14 +246,15 @@ def forward( k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim) - if self.use_flex_attn: - assert isinstance(attention_masks, BlockMask) - output = self.inner_attention( - q, k, v, block_mask=attention_masks, scale=self.softmax_scale - ) - else: - assert attention_masks is None - output = self.inner_attention(q, k, v, scale=self.softmax_scale) + match self.attn_type: + case "flex": + assert isinstance(attention_masks, BlockMask) + output = self.inner_attention( + q, k, v, block_mask=attention_masks, scale=self.softmax_scale + ) + case _: + assert attention_masks is None + output = self.inner_attention(q, k, v, scale=self.softmax_scale) # Reshape and project output output = output.transpose( diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 191588ad9e..75ab234ebc 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -36,7 +36,16 @@ n_heads=16, vocab_size=2048, rope_theta=500000, - use_flex_attn=True, + attn_type="flex", + attn_mask_type="block_causal", + ), + "debugmodel_varlen_attn": TransformerModelArgs( + dim=256, + n_layers=6, + n_heads=16, + vocab_size=2048, + rope_theta=500000, + attn_type="varlen", attn_mask_type="block_causal", ), "8B": TransformerModelArgs( @@ -48,6 +57,28 @@ multiple_of=1024, rope_theta=500000, ), + "8B_flex": TransformerModelArgs( + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=1024, + rope_theta=500000, + attn_type="flex", + attn_mask_type="block_causal", + ), + "8B_varlen": TransformerModelArgs( + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=1024, + rope_theta=500000, + attn_type="varlen", + attn_mask_type="block_causal", + ), "70B": TransformerModelArgs( dim=8192, n_layers=80, diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 86ac3a6dfe..b517e5c15f 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -67,10 +67,6 @@ def parallelize_llama( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") - if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( @@ -95,6 +91,8 @@ def parallelize_llama( job_config.compile.enable and "model" in job_config.compile.components ) + attn_type = getattr(model.model_args, "attn_type", "sdpa") + use_flex_attn = attn_type == "flex" if job_config.activation_checkpoint.mode != "none": apply_ac( model, diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index d83fb83102..81680074eb 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -10,7 +10,6 @@ from dataclasses import dataclass, field from torch import nn - from torchtitan.config import JobConfig from torchtitan.models.utils import get_dense_model_nparams_and_flops from torchtitan.protocols.model import BaseModelArgs @@ -43,7 +42,7 @@ class TransformerModelArgs(BaseModelArgs): # `False`, each uses the total number of transformer blocks depth_init: bool = True - use_flex_attn: bool = False + attn_type: str = "sdpa" attn_mask_type: str = "causal" eos_id: int = 0 @@ -55,7 +54,10 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.max_seq_len = seq_len - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + if ( + job_config.parallelism.context_parallel_degree > 1 + and self.attn_type != "sdpa" + ): raise NotImplementedError( "CP support for FlexAttention is still in progress." ) diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 124153f14c..74b862bf76 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -16,10 +16,13 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.attention import ( create_attention_mask, + create_varlen_metadata_for_document, FlexAttentionWrapper, get_causal_mask_mod, get_document_mask_mod, ScaledDotProductAttentionWrapper, + VarlenAttentionWrapper, + VarlenMetadata, ) from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol @@ -191,11 +194,14 @@ def __init__(self, model_args: TransformerModelArgs): model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - self.use_flex_attn = model_args.use_flex_attn - if self.use_flex_attn: - self.inner_attention = FlexAttentionWrapper() - else: - self.inner_attention = ScaledDotProductAttentionWrapper() + self.attn_type = model_args.attn_type + match self.attn_type: + case "flex": + self.inner_attention = FlexAttentionWrapper() + case "varlen": + self.inner_attention = VarlenAttentionWrapper() + case _: + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -240,16 +246,24 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - assert ( - isinstance(attention_masks, BlockMask) or attention_masks is None - ), attention_masks - - if self.use_flex_attn: - assert isinstance(attention_masks, BlockMask), attention_masks - output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) - else: - assert attention_masks is None - output = self.inner_attention(xq, xk, xv) + match self.attn_type: + case "flex": + assert isinstance(attention_masks, BlockMask), attention_masks + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + case "varlen": + assert isinstance(attention_masks, VarlenMetadata), attention_masks + output = self.inner_attention( + xq, + xk, + xv, + self.head_dim, + attention_masks, + ) + case "sdpa": + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) + case _: + raise ValueError(f"Unknown attention type: {self.attn_type}") output = output.transpose( 1, 2 @@ -453,13 +467,14 @@ def _precompute_freqs_cis(self) -> torch.Tensor: self.model_args.rope_scaling_args, ) - def get_attention_masks( + def _get_flex_attention_masks( self, input_batch: torch.Tensor, tokenizer: BaseTokenizer, extra_inputs: dict[str, torch.Tensor] | None = None, ) -> AttentionMasksType: mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: case "causal": B = 1 @@ -470,10 +485,36 @@ def get_attention_masks( raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" ) + return create_attention_mask( and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + match self.model_args.attn_type: + case "flex": + return self._get_flex_attention_masks( + input_batch, tokenizer, extra_inputs + ) + case "varlen": + if self.model_args.attn_mask_type != "block_causal": + raise ValueError( + f"varlen attention is only supported with block_causal \ + attention mask type, got {self.model_args.attn_mask_type}" + ) + return create_varlen_metadata_for_document( + input_batch, tokenizer.eos_id + ) + case _: + raise NotImplementedError( + "Only varlen and flex attn masks are supported" + ) + def forward( self, tokens: torch.Tensor, @@ -497,7 +538,6 @@ def forward( for layer in self.layers.values(): h = layer(h, self.freqs_cis, attention_masks=attention_masks) - h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/llama4/__init__.py b/torchtitan/models/llama4/__init__.py index 24196c2326..b8bd9a4484 100644 --- a/torchtitan/models/llama4/__init__.py +++ b/torchtitan/models/llama4/__init__.py @@ -67,7 +67,7 @@ rope_scaling_args=RoPEScalingArgs(), every_n_layers_nope=4, fixed_attn_block_size=256, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "17bx16e_irope": TransformerModelArgs( @@ -83,7 +83,7 @@ moe_args=MoEArgs(num_experts=16), interleave_moe_layer_step=1, every_n_layers_nope=4, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "17bx128e_irope": TransformerModelArgs( @@ -96,7 +96,7 @@ rope_theta=500000, moe_args=MoEArgs(num_experts=128), every_n_layers_nope=4, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), } diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 9911ecdfd0..01ce9d543b 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -75,10 +75,6 @@ def parallelize_llama( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") - if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( @@ -117,6 +113,8 @@ def parallelize_llama( model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) + attn_type = getattr(model.model_args, "attn_type", "sdpa") + use_flex_attn = attn_type == "flex" if job_config.activation_checkpoint.mode != "none": apply_ac( model, diff --git a/torchtitan/models/llama4/model/args.py b/torchtitan/models/llama4/model/args.py index 7fcc9871f5..a277ca382e 100644 --- a/torchtitan/models/llama4/model/args.py +++ b/torchtitan/models/llama4/model/args.py @@ -44,7 +44,7 @@ class TransformerModelArgs(BaseModelArgs): # `False`, each uses the total number of transformer blocks depth_init: bool = True - use_flex_attn: bool = False + attn_type: str = "sdpa" attn_mask_type: str = "causal" # iRoPE settings # When ``every_n_layers_nope`` is specified, NoPE (no positional embedding) is @@ -76,10 +76,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args.use_grouped_mm = False - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: - raise NotImplementedError( - "CP support for FlexAttention is still in progress." - ) + if ( + job_config.parallelism.context_parallel_degree > 1 + and self.attn_type != "sdpa" + ): + raise NotImplementedError("CP support is only supported for SDPA.") self.moe_args._debug_force_load_balance = ( job_config.debug.moe_force_load_balance diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index c8241b84de..6b9d2d2d9e 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -202,11 +202,12 @@ def __init__( # values of these two variables. self.use_rope = use_rope - self.use_flex_attn = model_args.use_flex_attn - if self.use_flex_attn: - self.inner_attention = FlexAttentionWrapper() - else: - self.inner_attention = ScaledDotProductAttentionWrapper() + self.attn_type = model_args.attn_type + match self.attn_type: + case "flex": + self.inner_attention = FlexAttentionWrapper() + case _: + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -217,7 +218,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, - attention_masks: AttentionMasksType | None, + attention_masks: AttentionMasksType, ): """ Forward pass of the attention module. @@ -252,7 +253,7 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - if self.use_flex_attn: + if self.attn_type == "flex": assert isinstance(attention_masks, dict), attention_masks attention_mask = attention_masks["rope" if self.use_rope else "nope"] output = self.inner_attention(xq, xk, xv, block_mask=attention_mask) diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 6b8dc3d5a6..74254081b6 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -59,9 +59,9 @@ def parallelize_qwen3( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") + attn_type = getattr(model.model_args, "attn_type", "sdpa") + if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + raise NotImplementedError("CP support is only supported for SDPA.") model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components @@ -112,7 +112,7 @@ def parallelize_qwen3( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, + use_flex_attn=attn_type == "flex", op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) diff --git a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py index 0c700ce2e0..2def3a949a 100644 --- a/torchtitan/models/qwen3/model/args.py +++ b/torchtitan/models/qwen3/model/args.py @@ -36,7 +36,7 @@ class Qwen3ModelArgs(BaseModelArgs): max_seq_len: int = 4096 depth_init: bool = True - use_flex_attn: bool = False + attn_type: str = "sdpa" attn_mask_type: str = "causal" eos_id: int = 151645 diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index a4f0a59844..89296ed98d 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -143,7 +143,7 @@ def __init__(self, model_args: Qwen3ModelArgs): self.n_rep = self.n_heads // self.n_kv_heads self.head_dim = model_args.head_dim self.scaling = self.head_dim**-0.5 - self.use_flex_attn = getattr(model_args, "use_flex_attn", False) + self.attn_type = getattr(model_args, "attn_type", "sdpa") # RMSNorm added here to the here to include the q-k norm # This is one of the main differences between Llama3 and Qwen3 @@ -167,10 +167,11 @@ def __init__(self, model_args: Qwen3ModelArgs): model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - if self.use_flex_attn: - self.inner_attention = FlexAttentionWrapper() - else: - self.inner_attention = ScaledDotProductAttentionWrapper() + match self.attn_type: + case "flex": + self.inner_attention = FlexAttentionWrapper() + case _: + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -226,12 +227,13 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - if self.use_flex_attn: - assert isinstance(attention_masks, BlockMask), attention_masks - output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) - else: - assert attention_masks is None - output = self.inner_attention(xq, xk, xv) + match self.attn_type: + case "flex": + assert isinstance(attention_masks, BlockMask), attention_masks + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + case _: + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) output = output.transpose( 1, 2 diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py index a713bec65b..4cb193c31a 100644 --- a/torchtitan/protocols/model.py +++ b/torchtitan/protocols/model.py @@ -16,9 +16,10 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config import JobConfig +from torchtitan.models.attention import VarlenMetadata -AttentionMasksType = dict[str, BlockMask] | BlockMask +AttentionMasksType = dict[str, BlockMask] | BlockMask | VarlenMetadata @dataclass diff --git a/torchtitan/train.py b/torchtitan/train.py index d157a3a307..6f039b3c04 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -454,7 +454,8 @@ def post_dataloading_process( # extra_kwargs are. extra_kwargs: dict[str, Any] = {} - if getattr(self.model_args, "use_flex_attn", False): + attn_type = getattr(self.model_args, "attn_type", "sdpa") + if attn_type in ["flex", "varlen"]: extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, From e1f7f318a36d7143929e6a7f3b6c47314a83d985 Mon Sep 17 00:00:00 2001 From: Garrett Goon <44747910+garrett361@users.noreply.github.com> Date: Fri, 21 Nov 2025 19:55:04 -0500 Subject: [PATCH 031/127] remove scatter_add in MoE implementation (#1974) PR for removing `scatter_add` in the MoE implementation. `scatter_add` is somewhat problematic as it is non-deterministic due to the necessity of [atomic adds](https://discuss.pytorch.org/t/why-does-index-add-and-scatter-add-induce-non-deterministic-behavior-on-the-cuda-backend/45544/2) for correctness. Determinism, correctness, and performance tests using scripts under `torchtitan/moe_bench_and_test`: ``` # Determinism: run same forward 100x and compute standard deviations pytest -rsfP torchtitan/moe_bench_and_test/test_moe.py -k test_determinism out_old_std=tensor(0.0297, device='cuda:0', dtype=torch.bfloat16) out_std=tensor(0., device='cuda:0', dtype=torch.bfloat16) out_old_std/out_moe_old.abs().mean()=tensor(0.0006, device='cuda:0', dtype=torch.bfloat16) out_std/out_moe.abs().mean()=tensor(0., device='cuda:0', dtype=torch.bfloat16) ``` ``` # Accuracy: compare MoE outputs to FFN outputs, with weights set such that outputs should be the same # Relative error decreased by 3x pytest -rsfP torchtitan/moe_bench_and_test/test_moe.py -k test_moe_ffn_equivalence moe_old_rel_err=0.009754068047048696 moe_rel_err=0.002507858727736454 moe_old_rel_err/moe_rel_err=3.8894009216589858 ``` ``` # Timing: triton do_bench for DSv3 16B layer fwd + bwd. ~3% faster runtime python torchtitan/moe_bench_and_test/moe_timing.py moe_old && python torchtitan/moe_bench_and_test/moe_timing.py moe args=Namespace(cls='moe_old', perf_reps=1000, perf_warmups=100, seqlen=4096, bsz=4) moe_time_ms=19.712812881469727 args=Namespace(cls='moe', perf_reps=1000, perf_warmups=100, seqlen=4096, bsz=4) moe_time_ms=19.03301840562087 ``` ``` # Memory: for DSv3 16B layer fwd + bwd. ~15% reduction in active mem, ~18% in reserved mem. python torchtitan/moe_bench_and_test/moe_memory.py moe_old && python torchtitan/moe_bench_and_test/moe_memory.py moe args=Namespace(cls='moe_old', iters=1, seqlen=4096, bsz=4) peak_stats.max_active_gib=5.926029682159424 peak_stats.max_reserved_gib=7.224609375 args=Namespace(cls='moe', iters=1, seqlen=4096, bsz=4) peak_stats.max_active_gib=5.051033020019531 peak_stats.max_reserved_gib=5.91015625 ``` Testing fwd + bwd correctness for `tp_degree=ep_degree=world_size=8` and `etp=1` ``` # Similar relative errors torchrun --nproc-per-node 8 torchtitan/moe_bench_and_test/test_tp.py args=Namespace(seqlen=256, bsz=4, tol=0.01), world_size=8, tp=8, ep=8, etp=1 err_ratio_fsdp_ep_old=0.0028211805268959435 err_ratio_fsdp_ep=0.002805679534989922 err_ratio_ep_ep_old=0.0022941468020912068 kl_fsdp_ep_old=tensor(2.4915e-05, device='cuda:0', dtype=torch.bfloat16) kl_fsdp_ep=tensor(2.0981e-05, device='cuda:0', dtype=torch.bfloat16) kl_ep_ep_old=tensor(2.1458e-05, device='cuda:0', dtype=torch.bfloat16) ``` Everything under `torchtitan/moe_bench_and_test` is temporary testing utilities and is to be deleted prior to merging. --- torchtitan/distributed/expert_parallel.py | 9 ++--- torchtitan/models/moe/moe.py | 49 +++++++++++++---------- 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index e9986b9974..b78019e057 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -264,12 +264,9 @@ def _prepare_output_fn(self, mod, outputs, device_mesh): # NOTE: As we shard routed tokens along bs*slen dim across the TP ranks, # the MoE gather and scatter still require global token indices. local_rank = device_mesh.get_local_rank() - # fact: top_scores.shape[0] // mod.top_k = batch_size * seq_len // ep_degree - if not hasattr(mod, "top_k"): - raise ValueError( - "TokenReorderer class in MoE should always have top_k attribute." - ) - token_indices_experts_sorted += top_scores.shape[0] // mod.top_k * local_rank + token_indices_experts_sorted = ( + token_indices_experts_sorted + top_scores.shape[0] * local_rank + ) return top_scores, token_indices_experts_sorted, num_tokens_per_expert diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 295e2193a5..741c908eab 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -345,7 +345,6 @@ def forward( ) top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted] - token_indices_experts_sorted = token_indices_experts_sorted // self.top_k return ( top_scores_experts_sorted, @@ -414,7 +413,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: bs, slen, dim = x.shape x = x.view(-1, dim) - # top_scores and selected_experts_indices shape (bs*slen*top_k,) + # top_scores and selected_experts_indices shape (bs*slen, top_k) # num_tokens_per_expert shape (num_experts,) ( top_scores, @@ -430,7 +429,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) - # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) + # top_scores_experts_sorted and token_indices_experts_sorted shape (bs*slen*top_k,) # num_tokens_per_expert shape (num_experts,) # NOTE: the reason we need to compute num_tokens_per_expert again is: # 1st computation in router is to update self.tokens_per_expert @@ -445,12 +444,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) = self.reorderer(top_scores, selected_experts_indices) # shape (bs*slen*top_k, dim) - token_indices_experts_sorted = token_indices_experts_sorted.reshape( - -1, 1 - ).expand(-1, dim) - - # shape (bs*slen*top_k, dim) - routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) + routed_input = x[token_indices_experts_sorted // self.router.top_k] if self.score_before_experts: routed_input = ( @@ -464,22 +458,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # shared expert # Note: we execute the shared expert before scoring the output of the routed expert # to "implicitly" overlap the shared expert compute with token combine communication - if self.shared_experts is not None: - out = self.shared_experts(x) - else: - out = torch.zeros_like(x) + out = self.shared_experts(x) if self.shared_experts is not None else None + # Unsort routed outputs + routed_output_unsorted = torch.zeros( + (bs * slen * self.router.top_k, dim), + dtype=routed_output.dtype, + device=routed_output.device, + ) + routed_output_unsorted[token_indices_experts_sorted] = routed_output + routed_output_unsorted = routed_output_unsorted.reshape( + -1, self.router.top_k, dim + ) if not self.score_before_experts: - routed_output = ( - routed_output.to(torch.float32) - * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) + out_experts = ( + torch.bmm( + top_scores.reshape(-1, 1, self.router.top_k), + routed_output_unsorted.float(), + ) + .to(x.dtype) + .squeeze(1) + ) + else: + out_experts = routed_output_unsorted.sum(dim=1) - out = out.scatter_add( - dim=0, index=token_indices_experts_sorted, src=routed_output - ) - out = out.reshape(bs, slen, dim) - return out + if out is None: + return out_experts.reshape(bs, slen, dim) + return (out + out_experts).reshape(bs, slen, dim) def init_weights( self, From ad9f188abe816de08f68a2aaf8b97bed81b83e64 Mon Sep 17 00:00:00 2001 From: Ferdinand Mom <47445085+3outeille@users.noreply.github.com> Date: Sun, 23 Nov 2025 06:42:29 +0100 Subject: [PATCH 032/127] Update transformers backend name (#2075) following Huggingface efforts in VLLM (cf https://github.com/vllm-project/vllm/pull/28725), we would like to uniformize the naming and make sure that people think we use the HF models only --- .ci/docker/common/install_conda.sh | 2 +- ...> requirements-transformers-modeling-backend.txt} | 0 .ci/docker/ubuntu/Dockerfile | 2 +- ...ion_test_8gpu_transformers_modeling_backend.yaml} | 8 ++++---- torchtitan/experiments/README.md | 2 +- torchtitan/experiments/__init__.py | 2 +- .../README.md | 10 ++++++---- .../__init__.py | 0 .../configs/debug_model.toml | 2 +- .../configs/full.toml | 2 +- .../infra/parallelize.py | 2 +- .../infra/pipeline.py | 2 +- .../job_config.py | 0 .../model/args.py | 0 .../model/model.py | 0 .../tests/integration_tests.py | 12 ++++++------ 16 files changed, 24 insertions(+), 22 deletions(-) rename .ci/docker/{requirements-transformers-backend.txt => requirements-transformers-modeling-backend.txt} (100%) rename .github/workflows/{integration_test_8gpu_transformers_backend.yaml => integration_test_8gpu_transformers_modeling_backend.yaml} (82%) rename torchtitan/experiments/{transformers_backend => transformers_modeling_backend}/README.md (78%) rename torchtitan/experiments/{transformers_backend => transformers_modeling_backend}/__init__.py (100%) rename torchtitan/experiments/{transformers_backend => transformers_modeling_backend}/configs/debug_model.toml (98%) rename torchtitan/experiments/{transformers_backend => transformers_modeling_backend}/configs/full.toml (98%) rename torchtitan/experiments/{transformers_backend => transformers_modeling_backend}/infra/parallelize.py (99%) rename torchtitan/experiments/{transformers_backend => transformers_modeling_backend}/infra/pipeline.py (99%) rename torchtitan/experiments/{transformers_backend => transformers_modeling_backend}/job_config.py (100%) rename torchtitan/experiments/{transformers_backend => transformers_modeling_backend}/model/args.py (100%) rename torchtitan/experiments/{transformers_backend => transformers_modeling_backend}/model/model.py (100%) rename torchtitan/experiments/{transformers_backend => transformers_modeling_backend}/tests/integration_tests.py (83%) diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index d3cb20e7a3..f47dc4cb8f 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -43,7 +43,7 @@ install_pip_dependencies() { pip_install -r /opt/conda/requirements.txt pip_install -r /opt/conda/requirements-flux.txt pip_install -r /opt/conda/requirements-vlm.txt - pip_install -r /opt/conda/requirements-transformers-backend.txt + pip_install -r /opt/conda/requirements-transformers-modeling-backend.txt popd } diff --git a/.ci/docker/requirements-transformers-backend.txt b/.ci/docker/requirements-transformers-modeling-backend.txt similarity index 100% rename from .ci/docker/requirements-transformers-backend.txt rename to .ci/docker/requirements-transformers-modeling-backend.txt diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index b8123099b9..dfb753c1f4 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -33,7 +33,7 @@ COPY requirements-dev.txt /opt/conda/ COPY requirements.txt /opt/conda/ COPY requirements-flux.txt /opt/conda/ COPY requirements-vlm.txt /opt/conda/ -COPY requirements-transformers-backend.txt /opt/conda/ +COPY requirements-transformers-modeling-backend.txt /opt/conda/ COPY conda-env-ci.txt /opt/conda/ COPY ./common/install_conda.sh install_conda.sh COPY ./common/utils.sh utils.sh diff --git a/.github/workflows/integration_test_8gpu_transformers_backend.yaml b/.github/workflows/integration_test_8gpu_transformers_modeling_backend.yaml similarity index 82% rename from .github/workflows/integration_test_8gpu_transformers_backend.yaml rename to .github/workflows/integration_test_8gpu_transformers_modeling_backend.yaml index aea5189d81..83ba588c68 100644 --- a/.github/workflows/integration_test_8gpu_transformers_backend.yaml +++ b/.github/workflows/integration_test_8gpu_transformers_modeling_backend.yaml @@ -1,13 +1,13 @@ -name: Transformers Backend 8 GPU Integration Tests +name: Transformers Modeling Backend 8 GPU Integration Tests on: push: branches: [ main ] paths: - - 'torchtitan/experiments/transformers_backend/**' + - 'torchtitan/experiments/transformers_modeling_backend/**' pull_request: paths: - - 'torchtitan/experiments/transformers_backend/**' + - 'torchtitan/experiments/transformers_modeling_backend/**' schedule: # Runs every 12 hours - cron: '0 */12 * * *' @@ -50,4 +50,4 @@ jobs: USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 mkdir artifacts-to-be-uploaded - python -m torchtitan.experiments.transformers_backend.tests.integration_tests artifacts-to-be-uploaded --ngpu 8 + python -m torchtitan.experiments.transformers_modeling_backend.tests.integration_tests artifacts-to-be-uploaded --ngpu 8 diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 08dc692bf9..52d68ba784 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -31,4 +31,4 @@ We provide this `experiments/` folder to host experiments that add significant v | [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) | | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) | -| [transformers_backend](./transformers_backend/) | [![Transformers backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | +| [transformers_modeling_backend](./transformers_modeling_backend/) | [![Transformers modeling backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index db3a44a824..ec42919ab0 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -12,6 +12,6 @@ "vlm", "compiler_toolkit.deepseek_v3", "compiler_toolkit.llama3", - "transformers_backend", + "transformers_modeling_backend", ] ) diff --git a/torchtitan/experiments/transformers_backend/README.md b/torchtitan/experiments/transformers_modeling_backend/README.md similarity index 78% rename from torchtitan/experiments/transformers_backend/README.md rename to torchtitan/experiments/transformers_modeling_backend/README.md index 805afb9ab9..fb70d03a1f 100644 --- a/torchtitan/experiments/transformers_backend/README.md +++ b/torchtitan/experiments/transformers_modeling_backend/README.md @@ -1,15 +1,17 @@ -# Huggingface Transformers backend +# Huggingface Transformers Modeling backend + +This enables HF transformers models to be trained with `4D parallelism + torch.compile` ## Quick start - Requirements `transformers==4.57.1` -- Config: `torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3.toml` +- Config: `torchtitan/torchtitan/experiments/transformers_modeling_backend/configs/qwen3.toml` ```diff ... [model] - name = "llama3" -+ name = "transformers_backend" ++ name = "transformers_modeling_backend" flavor = "debugmodel" hf_assets_path = "./tests/assets/tokenizer" @@ -17,7 +19,7 @@ hf_assets_path = "./tests/assets/tokenizer" +model = "Qwen/Qwen3-4B-Instruct-2507" ... ``` -- Train: `LOG_RANK=7 CONFIG_FILE=/torchtitan/experiments/transformers_backend/configs/qwen3.toml ./run_train.sh --job.custom_config_module=torchtitan.experiments.transformers_backend.job_config --compile.enable` +- Train: `LOG_RANK=7 CONFIG_FILE=/torchtitan/experiments/transformers_modeling_backend/configs/qwen3.toml ./run_train.sh --job.custom_config_module=torchtitan.experiments.transformers_modeling_backend.job_config --compile.enable` - Make sure you have created the tokenizers beforehand image diff --git a/torchtitan/experiments/transformers_backend/__init__.py b/torchtitan/experiments/transformers_modeling_backend/__init__.py similarity index 100% rename from torchtitan/experiments/transformers_backend/__init__.py rename to torchtitan/experiments/transformers_modeling_backend/__init__.py diff --git a/torchtitan/experiments/transformers_backend/configs/debug_model.toml b/torchtitan/experiments/transformers_modeling_backend/configs/debug_model.toml similarity index 98% rename from torchtitan/experiments/transformers_backend/configs/debug_model.toml rename to torchtitan/experiments/transformers_modeling_backend/configs/debug_model.toml index 7b3de04b87..0775ead39b 100644 --- a/torchtitan/experiments/transformers_backend/configs/debug_model.toml +++ b/torchtitan/experiments/transformers_modeling_backend/configs/debug_model.toml @@ -20,7 +20,7 @@ save_tb_folder = "tb" enable_wandb = false [model] -name = "transformers_backend" +name = "transformers_modeling_backend" flavor = "debugmodel" # test folder with tokenizer.json, for debug purpose only hf_assets_path = "./tests/assets/tokenizer" diff --git a/torchtitan/experiments/transformers_backend/configs/full.toml b/torchtitan/experiments/transformers_modeling_backend/configs/full.toml similarity index 98% rename from torchtitan/experiments/transformers_backend/configs/full.toml rename to torchtitan/experiments/transformers_modeling_backend/configs/full.toml index 45eaa785de..34ec994fb1 100644 --- a/torchtitan/experiments/transformers_backend/configs/full.toml +++ b/torchtitan/experiments/transformers_modeling_backend/configs/full.toml @@ -20,7 +20,7 @@ save_tb_folder = "tb" enable_wandb = false [model] -name = "transformers_backend" +name = "transformers_modeling_backend" flavor = "full" # test folder with tokenizer.json, for debug purpose only hf_assets_path = "./tests/assets/tokenizer" diff --git a/torchtitan/experiments/transformers_backend/infra/parallelize.py b/torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py similarity index 99% rename from torchtitan/experiments/transformers_backend/infra/parallelize.py rename to torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py index b2ae3f02a1..a049d88d76 100644 --- a/torchtitan/experiments/transformers_backend/infra/parallelize.py +++ b/torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py @@ -22,7 +22,7 @@ from torchtitan.distributed.activation_checkpoint import apply_ac from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp -from torchtitan.experiments.transformers_backend.job_config import JobConfig +from torchtitan.experiments.transformers_modeling_backend.job_config import JobConfig from torchtitan.models.llama3.infra.parallelize import apply_compile, apply_ddp from torchtitan.tools.logging import logger diff --git a/torchtitan/experiments/transformers_backend/infra/pipeline.py b/torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py similarity index 99% rename from torchtitan/experiments/transformers_backend/infra/pipeline.py rename to torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py index 04452c5ede..f05caf9abf 100644 --- a/torchtitan/experiments/transformers_backend/infra/pipeline.py +++ b/torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py @@ -21,7 +21,7 @@ from torchtitan.components.loss import LossFunction from torchtitan.distributed import ParallelDims from torchtitan.distributed.pipeline_parallel import build_pipeline_schedule -from torchtitan.experiments.transformers_backend.job_config import JobConfig +from torchtitan.experiments.transformers_modeling_backend.job_config import JobConfig from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction from torchtitan.tools.logging import logger diff --git a/torchtitan/experiments/transformers_backend/job_config.py b/torchtitan/experiments/transformers_modeling_backend/job_config.py similarity index 100% rename from torchtitan/experiments/transformers_backend/job_config.py rename to torchtitan/experiments/transformers_modeling_backend/job_config.py diff --git a/torchtitan/experiments/transformers_backend/model/args.py b/torchtitan/experiments/transformers_modeling_backend/model/args.py similarity index 100% rename from torchtitan/experiments/transformers_backend/model/args.py rename to torchtitan/experiments/transformers_modeling_backend/model/args.py diff --git a/torchtitan/experiments/transformers_backend/model/model.py b/torchtitan/experiments/transformers_modeling_backend/model/model.py similarity index 100% rename from torchtitan/experiments/transformers_backend/model/model.py rename to torchtitan/experiments/transformers_modeling_backend/model/model.py diff --git a/torchtitan/experiments/transformers_backend/tests/integration_tests.py b/torchtitan/experiments/transformers_modeling_backend/tests/integration_tests.py similarity index 83% rename from torchtitan/experiments/transformers_backend/tests/integration_tests.py rename to torchtitan/experiments/transformers_modeling_backend/tests/integration_tests.py index 35d09d6a94..35df7bb86a 100644 --- a/torchtitan/experiments/transformers_backend/tests/integration_tests.py +++ b/torchtitan/experiments/transformers_modeling_backend/tests/integration_tests.py @@ -11,7 +11,7 @@ from tests.integration_tests.run_tests import run_tests -def build_transformers_backend_test_list() -> list[OverrideDefinitions]: +def build_transformers_modeling_backend_test_list() -> list[OverrideDefinitions]: """ key is the config file name and value is a list of OverrideDefinitions that is used to generate variations of integration tests based on the @@ -21,8 +21,8 @@ def build_transformers_backend_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name transformers_backend", - "--job.custom_config_module=torchtitan.experiments.transformers_backend.job_config", + "--model.name transformers_modeling_backend", + "--job.custom_config_module=torchtitan.experiments.transformers_modeling_backend.job_config", "--hf_transformers.model Qwen/Qwen2.5-7B", "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", @@ -31,7 +31,7 @@ def build_transformers_backend_test_list() -> list[OverrideDefinitions]: ], ], "Transformers Backend FSDP+TP+PP", - "transformers_backend_fsdp+tp+pp", + "transformers_modeling_backend_fsdp+tp+pp", ngpu=8, ), ] @@ -39,7 +39,7 @@ def build_transformers_backend_test_list() -> list[OverrideDefinitions]: _TEST_SUITES_FUNCTION = { - "transformers_backend": build_transformers_backend_test_list, + "transformers_modeling_backend": build_transformers_modeling_backend_test_list, } @@ -64,7 +64,7 @@ def main(): if os.listdir(args.output_dir): raise RuntimeError("Please provide an empty output directory.") - test_list = _TEST_SUITES_FUNCTION["transformers_backend"]() + test_list = _TEST_SUITES_FUNCTION["transformers_modeling_backend"]() run_tests(args, test_list) From c70310c5a60ddfcde602889e8d66967ea55ed213 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 24 Nov 2025 14:12:13 -0800 Subject: [PATCH 033/127] Enhance loss_compare.py: Add Import/Export Options and Enable CI Comparison with Existing Losses (#2063) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * __->__ #2063 This PR allows us to check if the loss is consistent across commits/PRs. 1. This PR contains a pre-tested losses result file. 2. This PR improve the loss_compare.py to add --import and --export options. 3. In CI, uses --import to get the previous losses and compare them with the current PR. If anything mismatch (10 steps), the CI will fail. --- .../integration_test_8gpu_features.yaml | 13 +- scripts/loss_compare.py | 149 +++++++++++++++++- tests/assets/losses/llama3.txt | 10 ++ 3 files changed, 161 insertions(+), 11 deletions(-) create mode 100644 tests/assets/losses/llama3.txt diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index de0672eeef..a20cd22545 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -90,13 +90,14 @@ jobs: sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded" sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded" - python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 - - # Verify the accuracy. - echo "Checking FSDP4 v.s. HSDP2FSDP2TP2 accuracy parity" + # Verify the accuracy first. + echo "Checking FSDP8 v.s. HSDP (4, 2) accuracy parity" export baseline_options="--parallelism.data_parallel_replicate_degree=1" - export test_options="--parallelism.data_parallel_replicate_degree=2 --parallelism.tensor_parallel_degree=2" - python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --baseline-ngpus=4 --test-ngpus=8 --steps=1 + export test_options="--parallelism.data_parallel_replicate_degree=4" + python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --steps=10 --import-result tests/assets/losses/llama3.txt + rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/* + + python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 # Cleanup the checkpoints so that we don't waste network bandwidth and time. rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint diff --git a/scripts/loss_compare.py b/scripts/loss_compare.py index 42ad3a81be..3479875036 100644 --- a/scripts/loss_compare.py +++ b/scripts/loss_compare.py @@ -168,6 +168,9 @@ def validate_arguments( test_train_file: str, test_options: str, steps: int, + assert_equal: bool, + export_result: str | None, + import_result: str | None, ) -> None: """Validate command line arguments.""" # Validate commit arguments - if one is ".", both must be "." @@ -201,6 +204,34 @@ def validate_arguments( log_print(f"Error: --steps must be a positive integer, got: {steps}") sys.exit(1) + # Validate export-result requires assert-equal + if export_result and not assert_equal: + log_print("Error: --export-result requires --assert-equal") + log_print(" Export only happens when losses are verified to match") + sys.exit(1) + + # Validate import-result requires assert-equal + if import_result and not assert_equal: + log_print("Error: --import-result requires --assert-equal") + log_print(" Import is used to verify all losses match") + sys.exit(1) + + # Validate export-result and import-result are mutually exclusive + if export_result and import_result: + log_print( + "Error: --export-result and --import-result cannot be " "used together" + ) + log_print( + " Use export to save results or import to compare " + "against saved results" + ) + sys.exit(1) + + # Validate import file exists + if import_result and not os.path.exists(import_result): + log_print(f"Error: Import file does not exist: {import_result}") + sys.exit(1) + # ============================================================================= # SETUP FUNCTIONS @@ -433,6 +464,34 @@ def read_losses_from_file(loss_file: str) -> dict[int, float]: return losses +def export_losses_to_file(losses: dict[int, float], export_path: str) -> None: + """Export losses to file and stdout. + + Args: + losses: Dictionary mapping step numbers to loss values + export_path: Path to export file + """ + log_print(f"Exporting losses to {export_path}") + + # Write to file and collect output for stdout + with open(export_path, "w") as f: + for step in sorted(losses.keys()): + loss = losses[step] + line = f"{step} {loss}" + f.write(line + "\n") + + log_print(f"Exported {len(losses)} loss values:") + log_print() + + # Output to stdout in same format + for step in sorted(losses.keys()): + loss = losses[step] + print(f"{step} {loss}") + + log_print() + log_print(f"Losses saved to: {export_path}") + + def extract_loss_data(output_folder: str | None) -> None: """Extract loss data from logs.""" if not output_folder: @@ -556,13 +615,18 @@ def perform_loss_analysis( generate_summary_statistics(baseline_losses, test_losses, stats_file) -def assert_losses_equal(baseline_log: str, test_log: str) -> None: - """Assert that losses are equal between baseline and test using - unittest. +def assert_losses_equal( + baseline_log: str, test_log: str, import_result: str | None = None +) -> None: + """Assert that losses are equal between baseline and test using unittest. + + If import_result is provided, also compares baseline with imported losses. """ log_print("Asserting losses are equal...") log_print(f"Baseline log: {baseline_log}") log_print(f"Test log: {test_log}") + if import_result: + log_print(f"Import file: {import_result}") # Extract losses from both logs baseline_losses = extract_losses_from_log(baseline_log) @@ -579,6 +643,15 @@ def assert_losses_equal(baseline_log: str, test_log: str) -> None: log_print("Error: No losses found in test log") sys.exit(1) + # Load imported losses if provided + imported_losses = None + if import_result: + imported_losses = read_losses_from_file(import_result) + log_print(f"Loaded {len(imported_losses)} steps from import file") + if not imported_losses: + log_print("Error: No losses found in import file") + sys.exit(1) + # Create a test case class LossEqualityTest(unittest.TestCase): def test_losses_equal(self): @@ -593,10 +666,22 @@ def test_losses_equal(self): f"test has {len(test_steps)} steps", ) + # If imported losses exist, check steps match + if imported_losses: + imported_steps = set(imported_losses.keys()) + self.assertEqual( + baseline_steps, + imported_steps, + f"Steps mismatch: baseline has {len(baseline_steps)} steps, " + f"imported has {len(imported_steps)} steps", + ) + # Check that losses are equal for each step for step in sorted(baseline_steps): baseline_loss = baseline_losses[step] test_loss = test_losses[step] + + # Compare baseline vs test self.assertEqual( baseline_loss, test_loss, @@ -604,6 +689,18 @@ def test_losses_equal(self): f"baseline={baseline_loss}, test={test_loss}", ) + # Compare baseline vs imported (if provided) + # No need to compare test vs imported since: + # baseline==test and baseline==imported implies test==imported + if imported_losses: + imported_loss = imported_losses[step] + self.assertEqual( + baseline_loss, + imported_loss, + f"Loss mismatch at step {step}: " + f"baseline={baseline_loss}, imported={imported_loss}", + ) + # Run the test suite = unittest.TestLoader().loadTestsFromTestCase(LossEqualityTest) runner = unittest.TextTestRunner(verbosity=2) @@ -613,7 +710,13 @@ def test_losses_equal(self): log_print("Loss assertion failed!") sys.exit(1) else: - log_print("All losses are equal. Assertion passed!") + if import_result: + log_print( + "All losses are equal (baseline, test, and imported). " + "Assertion passed!" + ) + else: + log_print("All losses are equal. Assertion passed!") def cleanup_temp_files(output_folder: str | None) -> None: @@ -756,6 +859,24 @@ def parse_arguments() -> argparse.Namespace: "Script exits with error if losses differ." ), ) + parser.add_argument( + "--export-result", + default="", + help=( + "Export losses to specified file path (requires --assert-equal). " + "Exports only when losses match. Format: '{step} {loss}' per line." + ), + ) + parser.add_argument( + "--import-result", + default="", + help=( + "Import losses from specified file path for comparison " + "(requires --assert-equal). " + "Compares imported losses with both baseline and test " + "(all 3 must match)." + ), + ) parser.add_argument( "--job-dump-folder", default="outputs", @@ -787,6 +908,14 @@ def parse_arguments() -> argparse.Namespace: if not args.output_folder: args.output_folder = None + # Convert empty export_result to None + if not args.export_result: + args.export_result = None + + # Convert empty import_result to None + if not args.import_result: + args.import_result = None + return args @@ -850,6 +979,9 @@ def main() -> None: args.test_train_file, args.test_options, args.steps, + args.assert_equal, + args.export_result, + args.import_result, ) # Setup environment @@ -912,7 +1044,14 @@ def main() -> None: # Assert losses are equal if requested if args.assert_equal: - assert_losses_equal(baseline_log, test_log) + # Pass import_result if provided for 3-way comparison + assert_losses_equal(baseline_log, test_log, args.import_result) + + # Export losses if requested (only after assertion passes) + if args.export_result: + # Extract baseline losses (they equal test losses since assertion passed) + baseline_losses = extract_losses_from_log(baseline_log) + export_losses_to_file(baseline_losses, args.export_result) # Analysis and reporting perform_loss_analysis(baseline_log, test_log, stats_file) diff --git a/tests/assets/losses/llama3.txt b/tests/assets/losses/llama3.txt new file mode 100644 index 0000000000..5ccea64b17 --- /dev/null +++ b/tests/assets/losses/llama3.txt @@ -0,0 +1,10 @@ +1 8.1376 +2 7.841 +3 7.1815 +4 6.3509 +5 5.5272 +6 4.9244 +7 4.5606 +8 4.3724 +9 4.347 +10 4.2004 From 7e1edb608b99643d7890ff7fe913f58f97b82120 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 24 Nov 2025 14:22:00 -0800 Subject: [PATCH 034/127] Print out the version number (#2083) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * __->__ #2083 This PR and https://github.com/pytorch/torchtitan/pull/2070 can resolve https://github.com/pytorch/torchtitan/issues/2043. This should not affect `.github/scripts/update_version.sh` as `.github/scripts/update_version.sh` will append the version at the end of the file, which will overwrite the value. --- torchtitan/__init__.py | 7 +++++++ torchtitan/train.py | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/torchtitan/__init__.py b/torchtitan/__init__.py index 176bce9b60..52c3ff3e22 100644 --- a/torchtitan/__init__.py +++ b/torchtitan/__init__.py @@ -4,5 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from importlib.metadata import version + # Import to register quantization modules. import torchtitan.components.quantization # noqa: F401 + +try: + __version__ = version("torchtitan") +except Exception as e: + __version__ = "0.0.0+unknown" diff --git a/torchtitan/train.py b/torchtitan/train.py index 6f039b3c04..1b4096abb6 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -718,6 +718,14 @@ def main(trainer_class: type[Trainer]) -> None: trainer_class: The trainer class to instantiate (e.g., Trainer, FluxTrainer, TorchCommsTrainer) """ init_logger() + + import torchtitan + + logger.info( + "torchtitan version: %s (0.0.0 means __version__ is not defined correctly).", + torchtitan.__version__, + ) + config_manager = ConfigManager() config = config_manager.parse_args() trainer: Trainer | None = None From 7e10d6052a8029592a37d1c843dc7949a6b30043 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Mon, 24 Nov 2025 19:09:58 -0800 Subject: [PATCH 035/127] Autoparallel as an experiment in main (#2054) Experiments like SimpleFSDP/Compiler Toolkit/Autoparallel are all being developed at the same time, and SimpleFSDP/Compiler Toolkit both run into issues with PP that requires the PP utilities from Autoparallel. We want to land the Autoparallel experiment into main to facilitate that sharing. --------- Signed-off-by: Edward Z. Yang Co-authored-by: Will Constable Co-authored-by: Edward Z. Yang Co-authored-by: Francisco Massa Co-authored-by: ruisizhang123 Co-authored-by: Brian Hirsh Co-authored-by: Will Constable --- torchtitan/components/optimizer.py | 19 +- torchtitan/experiments/README.md | 1 + torchtitan/experiments/__init__.py | 2 + .../experiments/auto_parallel/README.md | 19 + .../auto_parallel/deepseek_v3/__init__.py | 50 ++ .../deepseek_v3/parallelize_deepseekv3.py | 441 ++++++++++++++++++ .../experiments/auto_parallel/job_config.py | 25 + .../auto_parallel/llama3/__init__.py | 37 ++ .../auto_parallel/llama3/parallelize_llama.py | 151 ++++++ 9 files changed, 740 insertions(+), 5 deletions(-) create mode 100644 torchtitan/experiments/auto_parallel/README.md create mode 100644 torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py create mode 100644 torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py create mode 100644 torchtitan/experiments/auto_parallel/job_config.py create mode 100644 torchtitan/experiments/auto_parallel/llama3/__init__.py create mode 100644 torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 7fc5098800..80557366da 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -16,6 +16,7 @@ StateDictOptions, ) from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.tensor import Replicate from torch.optim import Optimizer from torchtitan.components.ft import FTManager, has_torchft @@ -380,11 +381,19 @@ def _update_expert_bias( tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list) if dp_cp_mesh is not None: - # Perform single all-reduce to get global statistics across all processes - pg = dp_cp_mesh.get_group() - torch.distributed.all_reduce( - tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM - ) + if isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor): + tokens_per_expert_by_layer = tokens_per_expert_by_layer.redistribute( + placements=[Replicate()] + * tokens_per_expert_by_layer.device_mesh.ndim + ) + else: + # Perform single all-reduce to get global statistics across all processes + pg = dp_cp_mesh.get_group() + torch.distributed.all_reduce( + tokens_per_expert_by_layer, + group=pg, + op=torch.distributed.ReduceOp.SUM, + ) moe_layer_idx = 0 with torch.no_grad(): diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 52d68ba784..aa93628656 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -32,3 +32,4 @@ We provide this `experiments/` folder to host experiments that add significant v | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) | | [transformers_modeling_backend](./transformers_modeling_backend/) | [![Transformers modeling backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | +| [auto_parallel](./auto_parallel/) | TBA | [@wconstab](https://github.com/wconstab) | [@xmfan](https://github.com/xmfan) | diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index ec42919ab0..7e2c442103 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -13,5 +13,7 @@ "compiler_toolkit.deepseek_v3", "compiler_toolkit.llama3", "transformers_modeling_backend", + "auto_parallel.llama3", + "auto_parallel.deepseek_v3", ] ) diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md new file mode 100644 index 0000000000..55dcc3c5e5 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/README.md @@ -0,0 +1,19 @@ +## Auto Parallel + +### Overview + +The Auto Parallel experiment integrates PyTorch's AutoParallel framework with TorchTitan to automatically optimize distributed training parallelism strategies given a device mesh. Instead of manually configuring parallelism layouts, AutoParallel uses cost-based analysis to determine optimal sharding placements for model parameters, activations, and gradients. + +### Requirements + +Requires installing [git@github.com:meta-pytorch/autoparallel.git](https://github.com/meta-pytorch/autoparallel) + +### Single Node + +**Llama3** + +`CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.llama3 --parallelism.tensor_parallel_degree 4 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config` + +**DeepSeekv3** + +`CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config` diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py b/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py new file mode 100644 index 0000000000..b90583c86b --- /dev/null +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +import copy + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.hf_datasets.text_datasets import build_text_dataloader + +from torchtitan.models.deepseek_v3 import deepseekv3_args, DeepSeekV3Model +from torchtitan.models.deepseek_v3.model.args import DeepSeekV3ModelArgs +from torchtitan.models.deepseek_v3.model.state_dict_adapter import ( + DeepSeekV3StateDictAdapter, +) +from torchtitan.protocols.train_spec import TrainSpec + +from .parallelize_deepseekv3 import parallelize_deepseekv3 + + +def get_train_spec() -> TrainSpec: + model_args = copy.deepcopy(deepseekv3_args) + + default_args = DeepSeekV3ModelArgs() + for config, args in model_args.items(): + if "flex_attn" in config: + continue + + args.attn_type = default_args.attn_type + args.attn_mask_type = default_args.attn_mask_type + + return TrainSpec( + model_cls=DeepSeekV3Model, + model_args=model_args, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=DeepSeekV3StateDictAdapter, + ) diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py new file mode 100644 index 0000000000..fc278cfabe --- /dev/null +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py @@ -0,0 +1,441 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time +import types +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from autoparallel.api import AutoParallel +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing + +from torch.distributed.tensor.placement_types import Replicate, Shard +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims +from torchtitan.models.moe.moe import _run_experts_grouped_mm + +from torchtitan.tools.logging import logger + + +def create_functional_router_forward( + self: nn.Module, +) -> Callable: # TokenChoiceTopKRouter + def functional_router_forward( + x: torch.Tensor, gate_weight: torch.nn.Parameter, expert_bias: torch.Tensor + ): + # scores shape (bs*slen, num_experts) + scores = F.linear(x, gate_weight) + + # By default, sigmoid or softmax is performed in float32 to avoid loss explosion + if self.score_func == "sigmoid": + scores = torch.sigmoid(scores.to(torch.float32)) + elif self.score_func == "softmax": + scores = F.softmax(scores.to(torch.float32), dim=1) + else: + raise NotImplementedError(f"Unknown score function {self.score_func}") + + # top scores shape (bs*slen, top_k) + # NOTE: The expert_bias is only used for routing. The gating value + # top_scores is still derived from the original scores. + if expert_bias is not None: + _, selected_experts_indices = torch.topk( + scores + expert_bias, k=self.top_k, dim=1 + ) + top_scores = scores.gather(dim=1, index=selected_experts_indices) + else: + top_scores, selected_experts_indices = torch.topk( + scores, k=self.top_k, dim=1 + ) + + # debug override: balanced round-robin routing + if self._debug_force_load_balance: + ( + selected_experts_indices, + top_scores, + ) = self._debug_force_load_balance_routing(scores) + + if self.route_norm: + denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20 + top_scores = top_scores / denominator + top_scores = top_scores * self.route_scale + + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward + num_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + + return top_scores, selected_experts_indices, num_tokens_per_expert + + return functional_router_forward + + +def _moe_forward( + x: torch.Tensor, + router_gate_weight: torch.nn.Parameter, + expert_bias: Optional[torch.Tensor], + experts_w1: torch.Tensor, + experts_w3: torch.Tensor, + experts_w2: torch.Tensor, + shared_w1_weight: torch.Tensor, + shared_w3_weight: torch.Tensor, + shared_w2_weight: torch.Tensor, + functional_router_forward: Callable, + reorderer: nn.Module, # TokenReorderer + top_k: int, +): + bs, slen, dim = x.shape + x = x.view(-1, dim) + + # top_scores and selected_experts_indices shape (bs*slen, top_k) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + selected_experts_indices, + num_tokens_per_expert, + ) = functional_router_forward(x, router_gate_weight, expert_bias) + num_tokens_per_expert_update = num_tokens_per_expert + + # top_scores_experts_sorted and token_indices_experts_sorted shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + # NOTE: the reason we need to compute num_tokens_per_expert again is: + # 1st computation in router is to update self.tokens_per_expert + # which would be the same across all TP ranks. + # 2nd computation in reorderer is for the actual routing and experts computation + # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. + # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. + ( + top_scores_experts_sorted, + token_indices_experts_sorted, + num_tokens_per_expert, + ) = reorderer(top_scores, selected_experts_indices) + + # shape (bs*slen*top_k, dim) + routed_input = x[token_indices_experts_sorted // top_k] + + # DSv3 score_before_experts is always False + # if score_before_experts: + # routed_input = ( + # routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) + # ).to(x.dtype) + + # shape (bs*slen*top_k, dim) + # routed_output = experts(routed_input, num_tokens_per_expert) + routed_output = _run_experts_grouped_mm( + experts_w1, experts_w2, experts_w3, routed_input, num_tokens_per_expert + ) + + # always has shared expert + # Note: we execute the shared expert before scoring the output of the routed expert + # to "implicitly" overlap the shared expert compute with token combine communication + _h1 = F.linear(x, shared_w1_weight) + _h3 = F.linear(x, shared_w3_weight) + out = F.linear(F.silu(_h1) * _h3, shared_w2_weight) + + # Unsort routed outputs + routed_output_unsorted = torch.zeros( + (bs * slen * top_k, dim), + dtype=routed_output.dtype, + device=routed_output.device, + ) + routed_output_unsorted[token_indices_experts_sorted] = routed_output + routed_output_unsorted = routed_output_unsorted.reshape(-1, top_k, dim) + # DSv3 score_before_experts is False + # if not self.score_before_experts: + out_experts = ( + torch.bmm( + top_scores.reshape(-1, 1, top_k), + routed_output_unsorted.float(), + ) + .to(x.dtype) + .squeeze(1) + ) + # else: + # out_experts = routed_output_unsorted.sum(dim=1) + + # always has shared experts + # if out is None: + return (out + out_experts).reshape(bs, slen, dim), num_tokens_per_expert_update + + +def moe_forward(self, x: torch.Tensor) -> torch.Tensor: + functional_router_forward = create_functional_router_forward(self.router) + out, num_tokens_per_expert = _moe_forward( + x, + self.router.gate.weight, + self.expert_bias, + self.experts.w1, + self.experts.w3, + self.experts.w2, + self.shared_experts.w1.weight, + self.shared_experts.w3.weight, + self.shared_experts.w2.weight, + functional_router_forward, + self.reorderer, + self.router.top_k, + ) + # HOPs don't support buffer mutations, keep this outside + # tokens_per_expert will be used to update the expert bias for load balancing. + # and also to count the expert usage + # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- + # first in the forward pass, and then in the backward pass. However, this has no + # effect on the expert bias update thanks to the torch.sign() operator. + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) + return out + + +def monkey_patch_checks(moe): + # causes data-dependent issue, hardcoded into monkey patch + assert not moe.score_before_experts + assert moe.router.gate.bias is None + assert moe.experts.use_grouped_mm + assert moe.shared_experts is not None + assert moe.shared_experts.w1.bias is None + assert moe.shared_experts.w2.bias is None + assert moe.shared_experts.w3.bias is None + assert not list(moe.reorderer.parameters()) + assert not list(moe.reorderer.buffers()) + + +def monkey_patch_local_map_moe(model, world_mesh): + """ + TODO: fix HOPs not restoring the original signature. + TODO: fix tracing with local shapes so that we can use Shard placements + + Current HOP signature we get: + """ + from torch.distributed._tensor.experimental import local_map + + # from torchtitan.models.moe import moe + global _moe_forward + _moe_forward = local_map( + _moe_forward, + out_placements=( + (Replicate(),), # out: torch.Tensor + (Replicate(),), # num_tokens_per_expert_update: torch.Tensor + ), + in_placements=( + (Replicate(),), # x: torch.Tensor, + (Replicate(),), # router_gate_weight: torch.nn.Parameter, + (Replicate(),), # expert_bias: Optional[torch.Tensor], + (Replicate(),), # experts_w1: torch.Tensor, + (Replicate(),), # experts_w3: torch.Tensor, + (Replicate(),), # experts_w2: torch.Tensor, + (Replicate(),), # shared_w1: torch.Tensor, + (Replicate(),), # shared_w3: torch.Tensor, + (Replicate(),), # shared_w2: torch.Tensor, + None, # functional_router_forward: Callable, + None, # reorderer: TokenReorderer, + None, # top_k + ), + redistribute_inputs=True, + in_grad_placements=None, + device_mesh=world_mesh, + ) + + for block in model.layers.children(): + if not block.moe_enabled: + continue + block.moe.forward = types.MethodType(moe_forward, block.moe) + monkey_patch_checks(block.moe) + + +# TODO: Autoparallel should transparently wrap the original nn.Module +# but I don't know how to do that. +def set_torchtitan_fields(orig, new): + assert isinstance(new.layers, torch.nn.ModuleDict) + for block in new.layers.values(): + block.moe_enabled = hasattr(block, "moe") + + +# Run workflow with: +# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel +def parallelize_deepseekv3( + model, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply Autoparallel to the model + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + # TODO(whc) + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. + torch._inductor.config.allow_buffer_reuse = False + + # allow configuring inductor comms optimizations from torchtitan commandline + configure_inductor_for_autobucketing( + job_config.experimental.comms_bucket_reorder_strategy + ) + + world_mesh = parallel_dims.world_mesh + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard + global_batch_size = job_config.training.local_batch_size * dp_degree + return ( + torch.randint( + 0, + model.model_args.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + # apply local_map to MoE + monkey_patch_local_map_moe(model, world_mesh) + + # torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( + # lambda bucket_idx: 500 / parallel_dims.tp + # ) + # torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = ( + # lambda bucket_idx: 1000 / parallel_dims.tp + # ) + + # if job_config.experimental.autop_force_bf16: + # logger.info("Forcing bf16 on model") + # model = model.bfloat16() + + # param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] + # reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] + # mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + mp_policy = None + with AutoParallel( + model, + input_fn, + world_mesh, + mp_policy=mp_policy, + compile=job_config.compile, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "dp_shard": Shard(0), + "tp": Replicate(), + } + # only used if loss parallel is enabled + possible_output_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_shard": Shard(0), + "tp": Shard(2), + } + assert all( + name in possible_input_shardings for name in world_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in world_mesh.mesh_dim_names + ) + out_sharding = x_sharding + loss_parallel_enabled = ( + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel + ) + if loss_parallel_enabled: + out_sharding = tuple( + possible_output_shardings[name] + for name in world_mesh.mesh_dim_names + if name != "dp_replicate" + ) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + set_torchtitan_fields(model, parallel_mod) + + if loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. + # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, world_mesh["tp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + + _preserve_moe_attributes(model, parallel_mod) + + return parallel_mod + + +def _preserve_moe_attributes(original_model, parallel_model): + """ + Preserve MoE custom attributes from the original model to the parallel model. + This is only needed for attributes that aren't used in the graph, so they aren't + lifted as graph inputs and fetched by the pre-graph runtime wrapper. + + `moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify + this block as a moe block. This should be safe as they are read-only. + """ + + def get_moe_modules(model): + """Extract all MoE modules from the model.""" + moe_modules = [] + if hasattr(model, "layers"): + if isinstance(model.layers, torch.nn.ModuleDict): + # regular torchtitan structure + blocks = model.layers.values() + else: + # autoparallel might change structure + blocks = ( + model.layers.children() if hasattr(model.layers, "children") else [] + ) + + for block in blocks: + if ( + hasattr(block, "moe_enabled") + and block.moe_enabled + and hasattr(block, "moe") + ): + moe_modules.append(block.moe) + elif hasattr(block, "moe"): # fallback for autoparallel + moe_modules.append(block.moe) + return moe_modules + + original_moe_modules = get_moe_modules(original_model) + parallel_moe_modules = get_moe_modules(parallel_model) + + # Copy custom attributes from original to parallel MoE modules + # This is fine to do since these attributes are read only + for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules): + if hasattr(orig_moe, "moe_enabled"): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff + + # Copy load_balance_coeff + if hasattr(orig_moe, "load_balance_coeff"): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff diff --git a/torchtitan/experiments/auto_parallel/job_config.py b/torchtitan/experiments/auto_parallel/job_config.py new file mode 100644 index 0000000000..c880cadb31 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/job_config.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +""" +Use --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config +""" + + +@dataclass +class Experimental: + # "aten" (default), "inductor", "none" + comms_bucket_reorder_strategy: str = "aten" + + autop_force_bf16: bool = False + + +@dataclass +class JobConfig: + experimental: Experimental = field(default_factory=Experimental) diff --git a/torchtitan/experiments/auto_parallel/llama3/__init__.py b/torchtitan/experiments/auto_parallel/llama3/__init__.py new file mode 100644 index 0000000000..ea38ac631a --- /dev/null +++ b/torchtitan/experiments/auto_parallel/llama3/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.hf_datasets.text_datasets import build_text_dataloader + +from torchtitan.models.llama3 import llama3_args, Transformer +from torchtitan.models.llama3.model.state_dict_adapter import Llama3StateDictAdapter +from torchtitan.protocols.train_spec import TrainSpec + +from .parallelize_llama import parallelize_llama + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=Transformer, + model_args=llama3_args, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, + state_dict_adapter=Llama3StateDictAdapter, + ) diff --git a/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py b/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py new file mode 100644 index 0000000000..1d2bee4351 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import torch + +from autoparallel.api import AutoParallel +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing + +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +def parallelize_llama( + model, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + # TODO(whc) + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. + torch._inductor.config.allow_buffer_reuse = False + + # allow configuring inductor comms optimizations from torchtitan commandline + configure_inductor_for_autobucketing( + job_config.experimental.comms_bucket_reorder_strategy + ) + + world_mesh = parallel_dims.world_mesh + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard + global_batch_size = job_config.training.local_batch_size * dp_degree + return ( + torch.randint( + 0, + # job_config.training.vocab_size, + model.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( + lambda bucket_idx: 500 / parallel_dims.tp + ) + torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = ( + lambda bucket_idx: 1000 / parallel_dims.tp + ) + + # bail out + # model = model_fn() + # return model + if job_config.experimental.autop_force_bf16: + logger.info("Forcing bf16 on model") + model = model.bfloat16() + + param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] + reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + with AutoParallel( + model, + input_fn, + world_mesh, + mp_policy=mp_policy, + compile=job_config.compile, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "dp_shard": Shard(0), + "tp": Replicate(), + } + # only used if loss parallel is enabled + possible_output_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_shard": Shard(0), + "tp": Shard(2), + } + assert all( + name in possible_input_shardings for name in world_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in world_mesh.mesh_dim_names + ) + out_sharding = x_sharding + loss_parallel_enabled = ( + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel + ) + if loss_parallel_enabled: + out_sharding = tuple( + possible_output_shardings[name] + for name in world_mesh.mesh_dim_names + if name != "dp_replicate" + ) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + if loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. + # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, world_mesh["tp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + + return parallel_mod From 607c70d96e89cfdfe2d8f00ac11ca477f860bfb5 Mon Sep 17 00:00:00 2001 From: liangel-02 Date: Tue, 25 Nov 2025 11:06:50 -0500 Subject: [PATCH 036/127] skip varlen integration test on rocm (#2085) as title since varlen attention is not supported on rocm --- tests/integration_tests/features.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration_tests/features.py b/tests/integration_tests/features.py index 6fafa29871..617a3eaa9c 100755 --- a/tests/integration_tests/features.py +++ b/tests/integration_tests/features.py @@ -357,6 +357,7 @@ def build_features_test_list() -> list[OverrideDefinitions]: "FSDP+VARLEN_ATTN", "fsdp+varlen_attn", ngpu=4, + skip_rocm_test=True, ), OverrideDefinitions( [ From d0393b37a6de672f7474034b12c038a1bef0e68c Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 25 Nov 2025 10:18:39 -0800 Subject: [PATCH 037/127] [Local Tensor] Replace dry_run.py with fake mode implementation (#2057) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * __->__ #2057 Replaces `dry_run.py` implementation with fake PG mode for DRY_RUN configuration validation. This PR also adds support of Local tensor mode to provide deeper validation coverage. **Note:** Currently returns early before `init_weights()` if using local tensor mode due to some limitation of local tensor, which will be fixed by https://github.com/pytorch/pytorch/pull/166540 . --- run_train.sh | 14 +-- scripts/dry_run.py | 159 -------------------------------- torchtitan/config/job_config.py | 16 ++++ torchtitan/distributed/utils.py | 43 ++++++++- torchtitan/train.py | 11 ++- 5 files changed, 74 insertions(+), 169 deletions(-) delete mode 100644 scripts/dry_run.py diff --git a/run_train.sh b/run_train.sh index 83319816fe..87558a782d 100755 --- a/run_train.sh +++ b/run_train.sh @@ -10,19 +10,21 @@ set -ex # use envs as local overwrites for convenience # e.g. # LOG_RANK=0,1 NGPU=4 ./run_train.sh -# DRY_RUN=1 ./run_train.sh # for config validation without GPU +# COMM_MODE="fake_backend" ./run_train.sh # for config validation without GPU +# COMM_MODE="local_tensor" ./run_train.sh # for local tensor debugging mode NGPU=${NGPU:-"8"} export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"} TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"} -DRY_RUN=${DRY_RUN:-0} +# COMM_MODE options: "fake_backend" (dry run), "local_tensor" (debug mode), or empty for normal training +COMM_MODE=${COMM_MODE:-""} TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} -if [ "$DRY_RUN" = "1" ]; then - # Dry run mode: validate configuration without GPU/distributed setup - echo "Running in DRY RUN mode - configuration validation only" - python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@" +if [ -n "$COMM_MODE" ]; then + # Communication mode specified: validate configuration or run in debug mode + echo "Running with comm_mode=${COMM_MODE}" + NGPU="${NGPU}" LOCAL_RANK=0 python3 -m "${TRAIN_FILE}" --job.config_file "${CONFIG_FILE}" "$@" --comm.mode=${COMM_MODE} --training.steps=1 else # Normal training with torchrun PYTORCH_ALLOC_CONF="expandable_segments:True" \ diff --git a/scripts/dry_run.py b/scripts/dry_run.py deleted file mode 100644 index fa8e1b4c17..0000000000 --- a/scripts/dry_run.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Dry run trainer for fast configuration validation without GPU/distributed setup. - -This module provides a lightweight trainer that validates job configurations, -model architecture, and dataloader setup without requiring GPU resources or -distributed initialization. Useful for rapid iteration on configuration files -and CI/CD validation pipelines. -""" - -import os -import sys - -# Add parent directory to path to import torchtitan -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -import torch - -import torchtitan.protocols.train_spec as train_spec_module -from torchtitan.config import JobConfig, TORCH_DTYPE_MAP -from torchtitan.tools import utils -from torchtitan.tools.logging import logger -from torchtitan.train import main, Trainer - - -class DryRunTrainer(Trainer): - """ - A lightweight trainer that validates configurations without GPU allocation. - - This trainer performs comprehensive validation of the training configuration - without allocating GPU resources or initializing distributed setup. It validates: - - - Configuration file parsing and structure - - Model architecture (constructed on meta device) - - Tokenizer initialization - - Dataloader configuration - - Parallelism settings - - Model converters (if specified) - - Unlike the regular Trainer, this does not: - - Allocate GPU memory - - Initialize distributed process groups - - Create optimizers or learning rate schedulers - - Set up checkpointing or metrics - - Run any actual training - - Args: - job_config: JobConfig containing all training configuration parameters - - Note: - Validation completes immediately after initialization. No training loop is executed. - All operations use CPU and meta devices for zero-cost validation. - """ - - def __init__(self, job_config: JobConfig): - torch._C._log_api_usage_once("torchtitan.dry_run") - - self.job_config = job_config - - logger.info(f"Starting job: {job_config.job.description}") - logger.info("DRY RUN MODE - Configuration validation only") - - # Use CPU device (no GPU required) - self.device = torch.device("cpu") - - # Log and validate config - job_config.maybe_log() - logger.info("Configuration parsed successfully") - - # Get train spec - self.train_spec = train_spec_module.get_train_spec(job_config.model.name) - logger.info(f"Train spec loaded for model: {job_config.model.name}") - - # Build tokenizer - self.tokenizer = ( - self.train_spec.build_tokenizer_fn(job_config) - if self.train_spec.build_tokenizer_fn is not None - else None - ) - if self.tokenizer: - logger.info("Tokenizer built successfully") - - # Validate model configuration - model_args = self.train_spec.model_args[job_config.model.flavor] - model_args.update_from_config(job_config) - self.model_args = model_args - - logger.info( - f"Model args validated: {job_config.model.name} {job_config.model.flavor}" - ) - - # Build model on meta device (validates architecture without memory allocation) - logger.info("Validating model architecture...") - with ( - torch.device("meta"), - utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), - ): - model = self.train_spec.model_cls(model_args) - - # Calculate and log model size - model_param_count, _ = model_args.get_nparams_and_flops( - model, job_config.training.seq_len - ) - logger.info( - f"Model architecture validated: {job_config.model.name} " - f"with {model_param_count:,} parameters" - ) - - # Validate dataloader configuration (build with minimal params) - logger.info("Validating dataloader configuration...") - try: - # Use dp_world_size=1 and dp_rank=0 for dry run - dataloader = self.train_spec.build_dataloader_fn( - dp_world_size=1, - dp_rank=0, - tokenizer=self.tokenizer, - job_config=job_config, - ) - logger.info("Dataloader configuration validated successfully") - except Exception as e: - logger.warning(f"Dataloader validation encountered issue: {e}") - logger.info( - "Note: Some dataloader issues may only appear with actual data paths" - ) - - # Validate model converters if specified - if job_config.model.converters: - logger.info(f"Model converters specified: {job_config.model.converters}") - - # Validate parallelism configuration - parallelism_config = job_config.parallelism - logger.info( - f"Parallelism config: " - f"DP-shard={parallelism_config.data_parallel_shard_degree}, " - f"DP-replicate={parallelism_config.data_parallel_replicate_degree}, " - f"TP={parallelism_config.tensor_parallel_degree}, " - f"PP={parallelism_config.pipeline_parallel_degree}, " - f"CP={parallelism_config.context_parallel_degree}" - ) - - # Summary - logger.info("=" * 80) - logger.info("DRY RUN VALIDATION COMPLETE") - logger.info("=" * 80) - logger.info("All configurations validated successfully!") - logger.info("Configuration is ready for training execution.") - logger.info("=" * 80) - - def train(self): - return - - -if __name__ == "__main__": - main(DryRunTrainer) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 95588d2c3b..218f484593 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -791,6 +791,22 @@ class Comm: save_traces_file_prefix: str = "rank_" """Flight recorder trace files prefix""" + mode: Literal["default", "fake_backend", "local_tensor"] = "default" + """ + Communication mode for distributed training. + + Options: + - "default": Normal distributed training with real communication + - "fake_backend": Fake comm backend for dry run mode only (configuration validation without GPU) + - "local_tensor": Local tensor mode for debugging purposes. There will be only one process + regardless of the number of GPUs. LocalTensor will simulate the computation by running one + rank after another. While the performance will be slow, the numerics should be the same. + This enables us to verify numerics with fewer GPUs. For example, we can directly run 5D + parallelisms within a single node to reduce the combinations we need to use in integration tests. + + NOTE: local_tensor is an experimental feature and automatically uses fake_backend internally. + """ + @dataclass class MemoryEstimation: diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index b209ddfd68..6a73ffd083 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -258,12 +258,51 @@ def maybe_enable_amp( ) +def init_fake_mode(world_size: int, comm_mode: str = "fake_backend"): + """Initialize fake backend + + Args: + world_size: The number of GPUs to simulate + comm_mode: Communication mode ("fake_backend" or "local_tensor") + + Returns: + The world size + """ + torch.distributed.init_process_group( + "fake", + rank=0, + world_size=world_size, + ) + + # If local_tensor mode is enabled, initialize LocalTensorMode context + if comm_mode == "local_tensor": + from torch.distributed import _local_tensor + + lm = _local_tensor.LocalTensorMode(world_size) + lm.__enter__() + + def init_distributed( comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = "", ranks: list[int] | None = None, -): +) -> int: + if comm_config.mode in ("fake_backend", "local_tensor"): + ngpu_str = os.environ.get("NGPU") + if ngpu_str is None: + raise ValueError( + f"NGPU environment variable must be set when using comm_mode={comm_config.mode}" + ) + try: + world_size = int(ngpu_str) + except ValueError as e: + raise ValueError( + f"NGPU environment variable must be a valid integer, got: {ngpu_str}" + ) from e + init_fake_mode(world_size, comm_config.mode) + return world_size + def _warn_overwrite_env(env, val): if env in os.environ: logger.warning( @@ -309,6 +348,8 @@ def _get_distributed_backend(enable_cpu_backend): _ranks=ranks if ranks is not None else [], ) + return torch.distributed.get_world_size() + def set_pg_timeouts(timeout, world_mesh): """ diff --git a/torchtitan/train.py b/torchtitan/train.py index 1b4096abb6..03809dd0e2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -360,15 +360,13 @@ def __init__(self, job_config: JobConfig): def init_distributed(self) -> ParallelDims: job_config = self.job_config - dist_utils.init_distributed( + world_size = dist_utils.init_distributed( job_config.comm, enable_cpu_backend=job_config.training.enable_cpu_offload, base_folder=job_config.job.dump_folder, ) - world_size = int(os.environ["WORLD_SIZE"]) parallelism_config = job_config.parallelism - return ParallelDims( dp_shard=parallelism_config.data_parallel_shard_degree, dp_replicate=parallelism_config.data_parallel_replicate_degree, @@ -733,6 +731,13 @@ def main(trainer_class: type[Trainer]) -> None: try: trainer = trainer_class(config) + # TODO(local_tensor): Remove this special case once LocalTensor supports + # init_weights() and foreach_allgather. In local tensor mode, skip + # training/checkpointing as the # model is not fully initialized + if config.comm.mode == "local_tensor": + logger.info("Local tensor mode enabled - skipping training execution") + return + if config.checkpoint.create_seed_checkpoint: assert ( int(os.environ["WORLD_SIZE"]) == 1 From 1b9cfda0c981b36661bf7ba70b4581676c8f6a76 Mon Sep 17 00:00:00 2001 From: liangel-02 Date: Tue, 25 Nov 2025 14:01:06 -0500 Subject: [PATCH 038/127] add varlen attention for qwen 3 (#2084) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As title **Testing** Screenshot 2025-11-24 at 4 30 53 PM performance and loss on par --- torchtitan/models/qwen3/infra/parallelize.py | 1 + torchtitan/models/qwen3/model/model.py | 49 ++++++++++++++++++-- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 74254081b6..268e7d31b4 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -43,6 +43,7 @@ # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch.ops.torch_attn._varlen_attn, } diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 89296ed98d..fa8fd454b1 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -15,10 +15,13 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.attention import ( create_attention_mask, + create_varlen_metadata_for_document, FlexAttentionWrapper, get_causal_mask_mod, get_document_mask_mod, ScaledDotProductAttentionWrapper, + VarlenAttentionWrapper, + VarlenMetadata, ) from torchtitan.models.moe import MoE from torchtitan.protocols.model import AttentionMasksType @@ -170,8 +173,12 @@ def __init__(self, model_args: Qwen3ModelArgs): match self.attn_type: case "flex": self.inner_attention = FlexAttentionWrapper() - case _: + case "varlen": + self.inner_attention = VarlenAttentionWrapper() + case "sdpa": self.inner_attention = ScaledDotProductAttentionWrapper() + case _: + raise ValueError(f"Unknown attention type: {self.attn_type}") def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -231,9 +238,20 @@ def forward( case "flex": assert isinstance(attention_masks, BlockMask), attention_masks output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) - case _: + case "varlen": + assert isinstance(attention_masks, VarlenMetadata), attention_masks + output = self.inner_attention( + xq, + xk, + xv, + self.head_dim, + attention_masks, + ) + case "sdpa": assert attention_masks is None output = self.inner_attention(xq, xk, xv) + case _: + raise ValueError(f"Unknown attention type: {self.attn_type}") output = output.transpose( 1, 2 @@ -447,7 +465,7 @@ def _precompute_rope_cache(self) -> torch.Tensor: self.model_args.rope_theta, ) - def get_attention_masks( + def _get_flex_attention_masks( self, input_batch: torch.Tensor, tokenizer: BaseTokenizer, @@ -468,6 +486,31 @@ def get_attention_masks( and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + match self.model_args.attn_type: + case "flex": + return self._get_flex_attention_masks( + input_batch, tokenizer, extra_inputs + ) + case "varlen": + if self.model_args.attn_mask_type != "block_causal": + raise ValueError( + f"varlen attention is only supported with block_causal \ + attention mask type, got {self.model_args.attn_mask_type}" + ) + return create_varlen_metadata_for_document( + input_batch, tokenizer.eos_id + ) + case _: + raise NotImplementedError( + "Only varlen and flex attn masks are supported" + ) + def forward( self, tokens: torch.Tensor, From cbdb311319f58de44c31161b4532bdc472faebc3 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 25 Nov 2025 15:26:27 -0800 Subject: [PATCH 039/127] [FLUX] Add FLUX inference test in CI (#1969) --- .../integration_test_8gpu_models.yaml | 1 + tests/integration_tests/flux.py | 25 ++++++++----------- torchtitan/models/flux/inference/infer.py | 7 +++++- .../flux/train_configs/debug_model.toml | 2 ++ 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/.github/workflows/integration_test_8gpu_models.yaml b/.github/workflows/integration_test_8gpu_models.yaml index 129049b8f6..b673da5adf 100644 --- a/.github/workflows/integration_test_8gpu_models.yaml +++ b/.github/workflows/integration_test_8gpu_models.yaml @@ -54,3 +54,4 @@ jobs: python -m tests.integration_tests.run_tests --test_suite models artifacts-to-be-uploaded --ngpu 8 python -m tests.integration_tests.flux artifacts-to-be-uploaded/flux --ngpu 8 rm -rf artifacts-to-be-uploaded/*/checkpoint + rm -rf artifacts-to-be-uploaded/flux/*/inference_results/ diff --git a/tests/integration_tests/flux.py b/tests/integration_tests/flux.py index 321ac1280c..a7ed51832f 100755 --- a/tests/integration_tests/flux.py +++ b/tests/integration_tests/flux.py @@ -26,20 +26,15 @@ def build_flux_test_list() -> list[OverrideDefinitions]: "--parallelism.data_parallel_shard_degree 2", "--parallelism.data_parallel_replicate_degree 2", "--parallelism.context_parallel_degree 2", - ] - ], - "HSDP+CP", - "hsdp+cp", - ngpu=8, - ), - OverrideDefinitions( - [ - [ "--validation.enable", - ] + "--validation.steps 5", + "--checkpoint.enable", + ], + [], ], - "Flux Validation Test", - "validation", + "HSDP+CP+Validation+Inference", + "hsdp+cp+validation+inference", + ngpu=8, ), ] return integration_tests_flavors @@ -63,7 +58,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir t5_encoder_version_arg = ( "--encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/" ) - tokenzier_path_arg = "--model.tokenizer_path tests/assets/tokenizer" + hf_assets_path_arg = "--model.hf_assets_path tests/assets/tokenizer" all_ranks = ",".join(map(str, range(test_flavor.ngpu))) @@ -73,7 +68,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd # save checkpoint (idx == 0) and load it for generation (idx == 1) - if test_name == "test_generate" and idx == 1: + if test_name == "hsdp+cp+validation+inference" and idx == 1: # For flux generation, test using inference script cmd = ( f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} " @@ -84,7 +79,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir cmd += " " + random_init_encoder_arg cmd += " " + clip_encoder_version_arg cmd += " " + t5_encoder_version_arg - cmd += " " + tokenzier_path_arg + cmd += " " + hf_assets_path_arg if override_arg: cmd += " " + " ".join(override_arg) diff --git a/torchtitan/models/flux/inference/infer.py b/torchtitan/models/flux/inference/infer.py index 0c06a385ef..b89887ad51 100644 --- a/torchtitan/models/flux/inference/infer.py +++ b/torchtitan/models/flux/inference/infer.py @@ -28,6 +28,12 @@ def inference(config: JobConfig): original_prompts = open(config.inference.prompts_path).readlines() total_prompts = len(original_prompts) + if total_prompts < world_size: + raise ValueError( + f"Number of prompts ({total_prompts}) must be >= number of ranks ({world_size}). " + f"FSDP all-gather will hang if some ranks have no prompts to process." + ) + # Distribute prompts across processes using round-robin assignment prompts = original_prompts[global_rank::world_size] @@ -45,7 +51,6 @@ def inference(config: JobConfig): config.job.dump_folder, config.inference.save_img_folder, ) - # Create mapping from local indices to global prompt indices global_ids = list(range(global_rank, total_prompts, world_size)) diff --git a/torchtitan/models/flux/train_configs/debug_model.toml b/torchtitan/models/flux/train_configs/debug_model.toml index 47a033c546..b943925c1c 100644 --- a/torchtitan/models/flux/train_configs/debug_model.toml +++ b/torchtitan/models/flux/train_configs/debug_model.toml @@ -21,6 +21,7 @@ enable_wandb = false [model] name = "flux" flavor = "flux-debug" +hf_assets_path = "tests/assets/tokenizer" [optimizer] name = "AdamW" @@ -48,6 +49,7 @@ autoencoder_path = "assets/hf/FLUX.1-dev/ae.safetensors" # Autoencoder to use f [parallelism] data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 +context_parallel_degree = 1 [activation_checkpoint] mode = "full" From befb7aed3e9604dcb0d05404a1b22952557ccc9e Mon Sep 17 00:00:00 2001 From: rakkit <26144573+rakkit@users.noreply.github.com> Date: Mon, 1 Dec 2025 22:22:13 +0100 Subject: [PATCH 040/127] Improve logging by formatting the dict as JSON. (#2094) We use Slurm to run jobs, and i just noticed that job configs and model args were being logged on a single line by default, which makes the logs hard to read. This PR improves readability by formatting these dictionaries with `json.dumps` before logging, so the configs are formatted nicely and easier for humans to read. before: image after: image --- torchtitan/config/job_config.py | 4 +++- torchtitan/train.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 218f484593..612dc28101 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -961,7 +961,9 @@ def to_dict(self) -> dict[str, Any]: def maybe_log(self) -> None: if self.job.print_config: - logger.info(f"Running with configs: {self.to_dict()}") + logger.info( + f"Running with configs: {json.dumps(self.to_dict(), indent=2, ensure_ascii=False)}" + ) if self.job.save_config_file is not None: config_file = os.path.join(self.job.dump_folder, self.job.save_config_file) diff --git a/torchtitan/train.py b/torchtitan/train.py index 03809dd0e2..52863b0ba0 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -4,7 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import dataclasses import importlib +import json import os import time from datetime import timedelta @@ -135,7 +137,8 @@ def __init__(self, job_config: JobConfig): self.model_args = model_args logger.info( - f"Building {job_config.model.name} {job_config.model.flavor} with {model_args}" + f"Building {job_config.model.name} {job_config.model.flavor}" + f"with {json.dumps(dataclasses.asdict(model_args), indent=2, ensure_ascii=False)}" ) with ( torch.device("meta"), From b39377f9fe33865fefb9bf64a33f6d74a598be87 Mon Sep 17 00:00:00 2001 From: rakkit <26144573+rakkit@users.noreply.github.com> Date: Mon, 1 Dec 2025 22:24:09 +0100 Subject: [PATCH 041/127] add all SDPA backends to op_sac_save_list (#2095) As we discussed in https://github.com/pytorch/torchtitan/issues/2091, we should add all `scaled_dot_product_attention` backends to `op_sac_save_list`to avoid recomputing attention during backward. --- tests/unit_tests/test_activation_checkpoint.py | 3 +++ torchtitan/experiments/gpt_oss/infra/parallelize.py | 3 +++ torchtitan/experiments/simple_fsdp/llama3/parallelize.py | 3 +++ torchtitan/models/deepseek_v3/infra/parallelize.py | 3 +++ torchtitan/models/llama3/infra/parallelize.py | 3 +++ torchtitan/models/llama4/infra/parallelize.py | 3 +++ torchtitan/models/qwen3/infra/parallelize.py | 3 +++ 7 files changed, 21 insertions(+) diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py index 202f7b1e48..e7596a66a2 100644 --- a/tests/unit_tests/test_activation_checkpoint.py +++ b/tests/unit_tests/test_activation_checkpoint.py @@ -19,6 +19,9 @@ torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, # for low precision training, it's useful to always save # the result of max, since the absolute maximum is diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 1070f58aad..232cba9ff7 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -37,6 +37,9 @@ torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, torch.ops._c10d_functional.all_to_all_single.default, # for low precision training, it's useful to always save diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index fb07ef617a..93f6370504 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -24,6 +24,9 @@ torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, # for low precision training, it's useful to always save # the result of max, since the absolute maximum is diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index a7e1ee0dc5..d6e7397645 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -34,6 +34,9 @@ torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, torch.ops._c10d_functional.all_to_all_single.default, # for low precision training, it's useful to always save diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index b517e5c15f..d191939e75 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -35,6 +35,9 @@ torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, # for low precision training, it's useful to always save # the result of max, since the absolute maximum is diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 01ce9d543b..c14741069f 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -42,6 +42,9 @@ torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, torch.ops._c10d_functional.all_to_all_single.default, # for low precision training, it's useful to always save diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 268e7d31b4..b6e9e644ee 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -37,6 +37,9 @@ torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, # for low precision training, it's useful to always save # the result of max, since the absolute maximum is From 53e949cdb01e5184917adbabcdeeb16a8125b05a Mon Sep 17 00:00:00 2001 From: liangel-02 Date: Tue, 2 Dec 2025 16:06:16 -0500 Subject: [PATCH 042/127] modify save list for varlen attn (#2082) adding varlen attention ops to ac save list **testing** used DebugMode() to print out op list. verified that forward is not being recomputed in the backward step. ``` [rank0]:forward ops [rank0]:varlen_attn in forward: True ... [rank0]:varlen_attn recomputed in backward: False [rank0]:saved correctly ``` --- tests/integration_tests/features.py | 7 ++++--- tests/unit_tests/test_activation_checkpoint.py | 1 + torchtitan/experiments/simple_fsdp/llama3/parallelize.py | 1 + torchtitan/models/llama3/infra/parallelize.py | 1 + torchtitan/models/qwen3/infra/parallelize.py | 2 +- 5 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/integration_tests/features.py b/tests/integration_tests/features.py index 617a3eaa9c..fe51ab7cf7 100755 --- a/tests/integration_tests/features.py +++ b/tests/integration_tests/features.py @@ -350,12 +350,13 @@ def build_features_test_list() -> list[OverrideDefinitions]: [ [ "--parallelism.data_parallel_shard_degree=4", - "--activation_checkpoint.mode='full'", + "--activation_checkpoint.mode=selective", + "--activation_checkpoint.selective_ac_option=op", "--model.flavor=debugmodel_varlen_attn", ] ], - "FSDP+VARLEN_ATTN", - "fsdp+varlen_attn", + "FSDP+VARLEN_ATTN + per op SAC", + "fsdp+varlen_attn+per_op_sac", ngpu=4, skip_rocm_test=True, ), diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py index e7596a66a2..f309172173 100644 --- a/tests/unit_tests/test_activation_checkpoint.py +++ b/tests/unit_tests/test_activation_checkpoint.py @@ -28,6 +28,7 @@ # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch.ops.torch_attn._varlen_attn, } diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index 93f6370504..bd9c936b78 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -33,6 +33,7 @@ # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch.ops.torch_attn._varlen_attn, } diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index d191939e75..1c381883b1 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -44,6 +44,7 @@ # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch.ops.torch_attn._varlen_attn.default, } diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index b6e9e644ee..c50783b582 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -46,7 +46,7 @@ # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, - torch.ops.torch_attn._varlen_attn, + torch.ops.torch_attn._varlen_attn.default, } From 571ce7cb96cb368762ebb5c248698bce475f84ec Mon Sep 17 00:00:00 2001 From: Zhiqiang Zang <19247626+CptGit@users.noreply.github.com> Date: Wed, 3 Dec 2025 10:57:02 -0800 Subject: [PATCH 043/127] Make sure log after distributed initialized. (#2102) There is a condition check in config logging for distributed initialization, so the config logging has to be happen after distributed has been initialized. Co-authored-by: Zhiqiang Zang --- torchtitan/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 52863b0ba0..c897ee3c8a 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -86,11 +86,12 @@ def __init__(self, job_config: JobConfig): # Device has to be set before creating TorchFT manager. device_module.set_device(self.device) - job_config.maybe_log() - # init distributed and build meshes self.parallel_dims = parallel_dims = self.init_distributed() + # Logging needs to happen after distributed initialized + job_config.maybe_log() + world_mesh = parallel_dims.world_mesh if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] From b3da1a2d84436016be650c3271d1a5f6380090a5 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 3 Dec 2025 13:54:50 -0800 Subject: [PATCH 044/127] [mxfp8] [docs] [BE] add MXFP8 usage documentation and benchmarks (#2096) Fixes #1998 --- README.md | 1 + assets/images/mxfp8_with_loss.png | Bin 0 -> 47012 bytes docs/mxfp8.md | 190 ++++++++++++++++++++++++++++++ 3 files changed, 191 insertions(+) create mode 100644 assets/images/mxfp8_with_loss.png create mode 100644 docs/mxfp8.md diff --git a/README.md b/README.md index 2763fea3fc..b68314f297 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,7 @@ To accelerate contributions to and innovations around torchtitan, we host an [`e - [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine-tuning 5. `torch.compile` support 6. [Float8](https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323) support ([how-to](docs/float8.md)) +7. [MXFP8 training for dense and MoE models](docs/mxfp8.md) on Blackwell GPUs. 7. DDP and HSDP 8. [TorchFT](https://github.com/pytorch/torchft) integration 9. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md) diff --git a/assets/images/mxfp8_with_loss.png b/assets/images/mxfp8_with_loss.png new file mode 100644 index 0000000000000000000000000000000000000000..47e2967aed57357bfa6ef2ab8d830f8cba7973f7 GIT binary patch literal 47012 zcmdRVWm{WK)GkhO3KW;(?(Rj3dueg^;_gr!N^#c|FU8$mQY^Rz2=4CAN&CF-b)66A z51f3-&d%O5Yu4OLXRRnT6q06^9f~dahB6ep&zSTS_3s7rmx>>b_WxXl zE5o7vceB45gGig9NKK)mq2cYnS0sPPAxC4C=V2<)B^|ARl|&M{Iq-Bf+Qa^h=ml0s z)z`v3do2CM(6@V;hm3+b)2)-2tkGOKaZl7nhfE3w({o=&g`|pfKVyga08`K(73K zog~VC^#0$J%Nzf9FO-lEh2?)2Oc(SYq1dJiG}ZqxK@vR+0snudGD+(G@9GVw-T#vv z|9=fCP-S5B+)3y-qSRvYEL#Qwqx!{dtRBP!1F-UfRHpcT@=s4GmfQ z$1&aV4{fwi1kyq`WTWun_|zYu=h#K2fiZjK33PTMwdz7zH$Jtg!15Y3)}J5qD~f}6 zYB;$SQns>Jo#*ag#feH(qIzZ$2csqLI^9S_4KhQ*}<=T>PB9ycE$}iF~8wOhcChl`T*0%k?uQe*Lgo% z#pQFg*Lg_l1UhQP@L*M7Ozmn7X8c5Fb&#^k{=9wZO>JpbLujzF;(z-Y%im}dC2Ov+ zF<^HnK>TqDd9IY1BZD81n2|ps9Jig#ey0j$hWY=~YV3p!X zDAV=9WNlY?W|u7#n;v)Q9|~$o&iBJFDh=8e)Vq?}mq{%_*-?pdK$zot*vmCs#ajW-K zt8<3G_r4#^6-6t8Dn3%6$>yLk&oLOdC68NLSy^0A5VAdzhO0=UUFX4CY{TU`=$4sz zGda65P#A@)bN|3JJgiVkXwb9#X$KoMf_r{tWkq=`JZ7O$+5K8z@XP1CNuCq?LJwCm zK|w*KEWVWXWqtI%_xN#j{#K&%0b7})?sf~!2?lM%P?A6h3MHVdXp>CsV)=)jPMh?Z z)8(O9vTHnid@kQPrBO|(MGm|?wQnaUq^}N6~MzUnWN1SH0e0egBT?ebchrvP*EAXcRwQ8II&|w?WW+7JTWwF+_h4fIOw< zZ%`2vd$?}$cY=cJK*J)=DvDhmq!Ey(8`vAZlcMVC_%uU96?JvWQ}6kW1hkpZB|fOt`GiPJlI8UpQE4(TGD67j!6NK2bNTA<#7AOl%gD;K zGxKTpJmfSmEi*I5@3c6p#xH=XL$v+yV&wkLodVH#b1(*v08d><2c^3&FfiHZ>B8{$ z@81iJ{>TZdZ_*4eFjaE+`44?MY+rY|$1k3<0wa$Fpc}FJ+R1#kZ1pONn~n%a>F~17 zTHnxk{%|nsx0P9T-@7X&YobmGxtsG`(ASDM6gpZ2|MBd+#$ao^!om9<5+kGtx%3|! z^SckPnj+aRs_1~qvFGa+P>E1Hx9#LbFy(WK-{qmv*c0Zm|2X?Sz~~7*W<*)MJ!LGb zNlh(Yb!BG^k09*=V)XI^3m@Hy>eYb*ZaiH>ZdxAx&A;01i1>?KYd{(Diiu0F^s}uk zBFnKQtd&&{)6;)NcvQ3blE1T)NXUD)k&n|!Jz=uPK*AWF&>5j9X7Bv@!8a;39F1rJd|BC3m zd&}VofZO5%FTj0#yk)adFN155go$UyiD=&a{P7toOHsF-k1HZ~-Tn))eNtHBqXPu& z+i_T8A2(K?#NErfwv(*` zEW?i`>ag?vGlyexSlkzn&piB1XZP$Y-Mw`LXOEAkn}d#iIgpo=+YEh-dG{z_L3Wop zXA{H&jX&_U?P){nwQu-y8BFumM$yyOjo$s{DaG1KT1OJ{>m5BQu!oCvtvli41NR#w zIDOw^+2-AoHm^O^Kx0iEqQHksBNPNU76Tv3aj{nsR@Mp_k(aGcN?Ae~P@R_R??9v~ zcFp6t?7{2_xi=e(qZm&dh~^5~4BA&(a~+8b=nq%U`nPM(x6j{KwL?!=+_+(~4*4$A zx}IX#{Z3q$eaFY|=2ea2Y%0mx|*}%j67MxM|o&N1)}7j-qJrf zI0d-&$$0Sdxr{|r6}S4360jMLZhYLW_8TZ$XdFY_WnUl5B<_5PaSSSBq~JyE!6VJipRIrX{>{XlESnbz-1OT=dmWl# z{PgAyzPU1ssEQ(b!27ARq~*e2i^Fl1QS-|eF7GLtWR1}b(dgoi$Ge+bhy2dYT$If+ z?wHGKUf$7kPO>7r%1!Jy4?|-^J-e6dt;F-r3=}K3c;f!c;V~r6WNA>u8vG#98`y^K z+d++)_xW}j_v2Dy-+*G8#%G-&?`V8Z?+qF#*+|`CyX+=r$Cf9}n{0FQcE*=GV{W%s zq!-u>6}6ZvywgFucfToTqdi!LE5jUW+*owIVsCXA zKXMR4Z3l7}MAeZ;LBfPG=)>C2p&4CU^l2KLsM;0>w6sn$#_1Aev*9iI_Q#yG^QT&Z zD5!BL{bhVG!?=7u4Y-(qkXz1;X2591mYVyy^ENGekf@5?^VC_5Ny_I!9TFO}0xvuTY$lyqX?bJ}6cejESP)Y96D!7qLIF-h3{0MFxe zwNJ+?x91x0gaCNh0=8Ujcw%ElaH$?=WoB|=*iCbAtIkZASLdw7(Km9u0N6-laC0#oE@hBF%up$3`6b!Sw? z9|?Zk0`~t{03EL%)Q9-P{rS%4dcqh&KV$^C=3jf={AX*%-d*9E9>m$?>7;> zPkRykmsyYGYuB@7g7&NWx!x~`Y`(iiFy(=D;hY7`f5yhf%7=_I=$OBBstAzwI_nR{ zR&0!NgN!%&QIygJ5?F6vGP}gC0^4szPj~c;9#K7)-TB!&+q0Ki$Fl2;VdwRqDPeqC z+c=e@0k0c|c3{iJX!i5Pz(<}s%d!3%G6e;NySXMu)05>zJ8)x6>O3ZJuZ6t@|`{4o>?=7s4q6Bjmy>!CfoG}F5nsb=zWl3 zb>H{g1+wE$mMG*O7SHud93JLue>!lnRLKJ6hoRN@;aAQ*UqEp`P|q?f8$%_|A1D5@ zubRzeg=y5W4t1E9=I_^dDQMTQrk3UgbDALEA*WW0XzXVlSq|&%1JMihCb&uTO6GGv zGTN^$MwVYT0k!5sE#*3Gv4Y%IoHo7v1^0*0vh9kwe5k-U1cFSdBdL2;r z^9^Oc+}!e1Z{4^ka4RZF;JE&ivq9RtR>jU3Kcuj^IWo+Oy|CR6uW_^5AMEj^;~c5N z-d+hH6gIG%6Je8Sx2!or;#i(nR~D^PXX$j(8%Buim3iFCRtC|Qq9WfduLJe-D5tUq zzO!HeWs6hvad@o{q1|;Q?H?U&&AzMp8{8`>Dx{0Qq7Q0mNijvNLvOx)gV4Ou zahg|Cli+SU@48=R+0k|6;qAV>7cQMXUpCbZM)h2$L2~!s)V&2y+CP2S2XTo<${)m6suhS z*lE=)Ak>MKroq_`BQVPU^(+1oWDsf(F88vh&Em7qFE9U5N09wdM~4VTuhF5Xv^47J z3GPdUp|Xt)V>C!Ct2e(_QzG^f6D44wt3lOj68g_7Q8J z)TVJ0D)x{EkY_2lH3m8Gkwgvy#rHKrd=#o~*?IW*YJE@1pTq)pxK)fK%h! zoe#H4P)`79$_$#jbadVCx`bL=%M~GAJum|slj#YgD_2=E0DWqHabc!E9D@X^0_1V& zIZY(5&nq*b`8=kr)_)_GVbrfwdx( zp>_EG<~B7w_LdPR=ARN+!d$ULo)!HI`)hyAIflSm-pX z;mhWv^i#9*x)eV&e*Q2c*A%E6^gtf6hw(|otWa;O5`qs#Eo%QJ8H@@IMPB}aDtlT6^duQEG;)xA5h{#f^ zu=fA?e(O!;dosoLrI``jHldnNu;Ez;Vb3?V9w9u` zg2upqKl&wc?i<|zKt z_l>u1KRyt)m4!)fdN>(e>|{r0Xer%hv8($G(R27oEEztasi@L_32=to6;S@R5%&4B zXKK6_zr&69Xv!z90}0*ijVQsxy&)3Og+%2UOLhVl2Pi?LDN?@65+y#uxeJRu<-PB< zR9Puah3{xc^xja$GXG7qZ#19hE*whZc|N=D$(0a%Fpt|;N1`z!z0AcqlD-nn zt~^z`?`+`p;hKSLt=ul7@i2|*B5n|UaY}V3?{G(!?%k<(K`H1Gm*d^$1>Dx1QZWIN zx4jX2s>wu>TckjyI%SBFoTP)j~#&u zGlr|mM{+G?Cdy1Vn>axw);Am_400bA^j3IH+nO^+6{-%wruFosD$HZsP^eA4rN{@$ z{v}${=ikoAH_SX!j^3wHED|N76l2v7o<*pqA}$S72ODCB%&pof6k{EG)iK-9@6&!Q zd3cw+t!m#S5fm#MWvBCTpB8}d^ZNdkwzVvEV!_6cs(I-(BN@VRH(Ea#;W|qc9$Y;x zTvB16&l4UTS{5a2NEe=mM=KGn-xX!Sz#^rdb7&TzgIV`o>P|TXDtuu<6K_{!q3WTJ z{;z(??#{NxT2lF@Auvh4Y>u8BgnKMQL?@c{WopK_(M=O~WL=Er6VT-laHk0pwE4ljZg?7j{wv%vsdQ2rLRx z_+#oMMhlX#B;W#(_JXA273aZFR<<%Wa|ii!$ZUhl$(}w=Y%vSXt0JzWk#6xvTwwYf zaZz$pzyx-e;4Tr_#s6{V( z%zfZ_jx!DRc6U~nVvwvb-;z7*iYMFg%>Yg&rB#R!@v>W!&uOM-;7WczIR z^jX?Z3DQ{j=T1_UHevr@lV=t!SbFxe=$ zAec<-lHhyJ$GZ^ulUiO}AQ=ML#Mdrlw_LmB*(3ygp!z%J`In&pFQFM1&nktN`?bEz;izmkuOEri^*FUD21HB?LkNkf=6H+y8I)R)S9%GVS zvm>ul#;wf(Qep)Ofw8#9+(MHFzV(&&(q&bCGJV1kD!iRYzT@sI&nN8PY%)XqF#Th? zU#{g;2g;{X4zeul*B-casByCF7|BnQ#TLRx*U{Sl9379Y9Zf`N){XKr&lMVY|?+3 zHOfSbH28ys3-?^AbMPFdiv8S(Di2SZ4)~~T7&J3$pSGIpdL5Vtm`Hr}51#b%%2{hLq9iALh%gS!a1)oA)8 zvPc{@^PUg$XRhg0m~7uC;$% z1q$oM$u|6bh%)w%n*1WN_zB@>xyYskXEc_ZQ#d7m)egR-w;638{sAIX3hg}-nurEu zGVu$miWfCoj0CJT0rMO!H^JZkH$vUcvcJ#VYJSKo1WRTk$SkWd)DNQIZ!jfETkREB z{%TGzY-J=dqTu81auWYq+1}|XgAZie(%7lv3L9m)-(EVzz{+0-kYBA5Broi!iM@7 zv1QIkGr&ooE}5%52VqU0$l6;lyFqN(2Nb6gv z=8mhI{)(ezl^fB`d$8&kYKK-GeFA#zZ!&og7-5Ni|B42(n~^Ysrl_l%de!_@eW&fY zwR|!5uK17~6v?tWj^!9;yRv&!iaxz^BG=zcKU$@o9>C5g&1RBnYq>_IfEvP6)0Hb# z-`0)>z53dM{ErRj(-W{#MjXyO z!D#;JM#xjdu3kP?^jX5D8I2iv#`dX%=m5Q44H^jUVk!?^5HLs%5iv zc~o7bz9HxFK$jC^AYZ-3&)HB=_2g*FibRj&UG^884cn5>URa4(9A`8g9JR+^9rVC| z=1yqd0=IELMYoi0jDx}V`}xlG()JGe4H4xTs>!twvGbHqR#*YrRwu|3fDo7qh z)6_%(q*jB@o?}*%ov~~ ziYp;*Z7y%+z*UQPn6^z%rdazeKd^O=&gT1-BnKs; zYu}ECXVT zSkmW5syhtoX}yF_D|^nYDDMCj5XYZh-7s2~#7Zzyt&pG?qQrw9Y8Cs-b;Q7Z4i>HK}xV&Am zKBqsW;a{1NGn{~S{Lk7OVO>Fdc?^4$h?`JlutcX24&N9G@-2-0@ASE9(QAd2EA20o@>;s~+4|mX&6v9Q(~+S>gcX@|3VHGb zu1pRkwY4&!J>2aN;uBNLi)kBtE6Xcp9kwUAah^=~yYNTK65i}5tLk{TI$?vl#Y`@K z&p1<;&vcg~E)1yYEaXWet&I%cVoC9;>Ud|tsKD8NG6H^HT+M#xMKVRT6q2=>M2;@- zvW!))1ldomeRGa-auyTb#mDPFQR4_>wq{HjuJC`5i!!Z_C@JMI zjenkbD=MPglm#MA>4V?>LA%~c_zNB@OftzVLunKX#u%Mz2{P*h?yY&MY)t90KV1tR z7NLqADbIgDufn`eJ#N2SM@xB9Wou!o9#5iDu{K=Ncpha_B(gH0oNIA0HqGn>nFcrK zFCl+}@GxLuJ+lb=v%)H#543D&s}t{l9gnsjGAz;+CUczL=phy7ao|B!U1z5YKOkEp zalhkF9s1zTgP^}A=B#&s&?ctkZ~86fk}qfM>@&EvcqHI_O#~Kl$31f?t?by)5oCIT z0%z|6-r>)`^p-`Jk<@JT{)VCP!Y@jDZlK@nR4rz0WnpH z;y>6)DJCcO#l^dmqIa(vFhR!O|Ip)1>T$F*b`mZ+nj@MN-0KHJoRz85aig3lpXKdp zsUccwtoG~+6)k<MA#(g^lDBkwLYu7e1QCKL*EC< z8{L3{6)f2=-CWnv3z^V}-NZB(@R^mfle;T8H&jt;sz=%12@1~P0^v66D6r<%ofmKc zj*UdGd;SYg{2@OhI^mD8ODoGlA)O4!!(5ZY=;|6G2Ajc(C?3H(%SMdll9H0Nr7^Jf z&9JRf7FSSjFWg2H0<%^{QGI=S@RRQDz&aLHF|JE4xqm(#Gh^QHQB`j4w(%QFG(vinOD(|!J7^g2z!{NZd`wnXEJKErQCn$y! zJE-R5%+w<^?#R(ZGdk1ze$(aIt)KM8{tHnODxBsv5(7%Q-9hAEIZYzW~U)R&A76DXHqEejVB%lye|ObxNi`&T}~DSVy7J=XVhASeD@Bb z)QA)ky{;WOu1VpNR8#&d=cHchZjTB$U5|+v29-|eS`}xw8>Gd8*r3VtSjCmPq7pAr z(hv+5!NeXqS&f^Y*F=Sj0!}HH)T}_!Uc-o{zO^iBPz?uq-$>1b|0O5M-(%4~2xK`1 zqNBO%7-7VO*@f__=2M`jO$~+hW=sY-glLW<1UG(MRiX4H2{K9Jf1#jrIhnOb$;;59 zITvZ|{O_}y^Trl7>*IF!6bea?5U?^UT#1Hw3CX(up~97@7iblUsX^Bh!l`>WGpZ`Z zUsg8)S`{aIf5bl()+tB%ECzod?4)PJ-oIAnYc((4mhI+W>+4Zlf~5iwH=r+iFhbS= z<{X8Kt{in!22~!K5+aq&Rnp_Da<~pAOZbXTRr&f-FElotMQX$=`K){u?RbltOqHNC zHH}McE&wf@G>MbI`x^)fEMc-&%k8RPL4dy4Ws|yj4X(>zfkk`X&G+J1ji+irEK(o$ z>e{+nRmasTGWl2FYs4o&r&)NVl!su5s^k`iJ^1tq{=n@b0_T~`**M{(bDCH`C$^V2 z+(fO(9hvz{<@Rr8wmxTlbSaF7JQ%#<#Lnr~R-e6tlcK`HMvVdJzISvaw#Bt^epgf4jxGDaR{1@)z(sd&)|W3PnMirYZ~D$sKQ607 zpv6X2TT-WO*;7+_R*Z7Drsi?7qXEaw#QaTfsPQF)F$sS=S~&cRx;FQWMe(l`mv6#j zPARm#s6~0!NR5xHdZY^`lYAu*oFyl%p8{Ujp^7w)YjVC>wRE6Z;zfjHwAy17gN;wcnJ9 zsjdCGelcrP;VA`sE(Cq_QE!J@188-DxnFA7u>iDm@%`HF!(pXNfjpfZ4>7=E{n8t) zkkqer@x>xz33fY8LhbL&gq*SU5`dNF(JRYBV!uKKEWanzX zK6IONsehaUC<^g5ACC`f8ug#IynC}%#b(4T8{jL zS|2^V8d>{9D2Qm-7r5r0NjF^W)TYuwz8_9zdcC<$kEUm-5UfTJq>iTRge0xvQUEVD6U!Jc^` zhy?pcLYWIoR`MsSk?UA*v)k}ouMnJSeU#DR?Xt9w6=_JmcGuJl)4$bqLL4GH-1om% z)gE}aCKQU_PB5bN(QT!UhJ^2@AFK19QTFGm`5hjm z^56TRDlHStSZc7Rc`dbm*5I4qQe(m#rJqTh$>M8>{56K`>E~=e?v?^gQ&j0Vw$YQh zDC0ve?k7BO$l`)E`sc)?TvW-d)4FF65iD~;RQgpH8o)9Rwg!yZ*;JaI%rfl+l7hfC znZIDhmQAs|6D3@V4YH?b=lsok_q19K2$G1rkR&bhgXIJPojOcE8 zY_Y=*!hwPw72cy1N%~D1PkRJBj}ev8V0CSaS!pe?@4P-#oxh2o`{eAq^ado7c9%Ne zL5{_v;>j&zNTec5J+!Lg#uYx@id_6 z%vSDzV@StWD}+VeHBgz%Lp<$yv$avr^V9ET=7%2UO|J5l$gN)a_aP|Iu-m-pg;u59 z#y)AMOmnzI%guDee&&3~U4>WD1V1|mQwbGlCG%@>j^6vwU87ho|JD4kSgCSLBAN~D z0{XY!_krhe!ieY}uV!ES+P&~(IMV(w2g3Hr<~s-m5$)#ph8cGM`aqA{KPMW(r8GMA z)wGOUiU>c*=8|({5VuMve3*HF-=R@!2ZdPkgAw{&DmXB+5tcnUNL?-b`?axj8}g$3 zzr>Ec0(>!6Q(Mj!=vqJ->_mKs`7~P}WBx3(tbMws`~lIy6RqZY+_dRL@;)R!GPo?N z;vTxLXx^`IDq3j69Fg-Panp-GK64_Llo2UKFASzRXCBrJDk-KEaG)vFFwsf9&pxOj z6d)aDUOu*=?z9o~nC7c~^3P#f#5i{2Y9yn-xB#5RSc z4)DMij28R$maUtGMAmB^dBQrou|xY^(i>3Ax7?C|Si!GwFDREBicFk+VRwULgG?k5 zB!{#dm-;TC4%D)8V&B~?hCpm;?B0{FbcS{4)Fo%309beF8*MKVn7|1|?2S8i*MI#& z&OF|NUBuhU<$>DCa7PGpziJRvyI6QPXL{-cled9?nB~Xr>I1_xIRL?BUGnDluWups z*=N1C1xat+sATT4%l786$q14paHaP1JJsKUb)>NTvIc?5e7$k`5x`zuQLVI*bHt~{>|H^y#$2KnCC3;Bk zwC|^ZpIPw+fCGz}Jtxx$Xak-Xw z8|$seCRcePm*#sV^Lr!8x93ZFUV6;kasG;h4`WzVSX9`v7iAu)Wu7AA4KJqrP14 zpr_eyD5WVZejX#poD;gF>bq*&idzduc_3|>Gr+HaCER9l?@C624hx|=9H`V?cx85Q zG65)26=gjf=bx47YbjxgKYAAg(5rd;l3;iB)0K+Kjuv=Q8pehme)Gp(TyU`(_I`$L1jqFdL3Eb zPA{g)g$X*9v##eRKme=F0vRf5AE*?zu7Qy=<=qY|%3e!wKe8puCzHAx!hr8nUYO z`aQ#FV_QvHXwsvXWaIP_w@Obwe+bRV#nR~81)_7oO8-2&6n()jWAZzsjMN(hE*D=G zm~A2JBBwY=BP;J6>FZbHJ+Qp4-=lTE~*?-A7sgKx84#jGRFbhNMy zuCXRE&w`oNQsd5u(r^r(X@&;JLdOFrhAG~=4S$rL7m+sw68BnlA{YGR;Vnywj&--% z{$h?1G4-=v#_U@bP3s9j8?gW-H-MbDMHHw$@) z<0xKs@v*d)sEX6gM{ToWFR65)%t%MCgzkDos5%Qac`LN(VoH)~k`hbL@oblDWBOTz zzVSf)$CEl#ZfF`nVHm?yJdbeYQ#IDY64s!P+C`f%Z*pq!%6^eHHHEA#qn)XyN5{&x z=Q7aC8=q*DCn}r@p|k60%L=0a?hXCS=<-q|yk53x8Y>W|-}~L95a#(gE`f2Yo8UeO z#!j=da3xda%1u@KvM14fIi0U_DlZ&?mXZ~hl|>BKydxxw=?8RR9sck%iOt%6mLyIb zWQ96Ek(*YpB9lEOS+h%7*YZV;ZIxL%Kq^smg|=1mx>&R6X_BJf7ZY<_dj|(X4~Hh{ zViTCHO}Nxm2r=R2a4ZxwN!Ipe2(A<70(X9EG?N}U)X9hTbu8=?@`D_`-Wgj?7WeG{ zR>+!Eo<`m0Jl@ull@VSJB{#0z%njasf6tswUTPlWVseqchVR?v1|1b)Ub|*MSQMAC z8&}W;G8fte<1|!z-OM#in9o%@734s5;IJHP{BFg8JwmTF{W^vx>b8%lml0+JzEUN$l*}LndR%&QwY&y8 zaq+-Vtgwp+b&uYx8)&NrlqsabrSIxvmh)#!olAMSjtbL0edxBuCXunUu-nGEu~)>X z(g;%PA}Vp?f-M&GQj_?NOq+2wmo>tA5pGiiD$1cDT*P>ud13O^pqM4-zYAHFbW`w_ z(3m=p9j}bb`GDy1P|;su_SO;6Zk~*TJ83Y+!>SnIPQ57ZS#+aldGW!GCGoo~o_&Rbwb)R>+NYuH^RmFc4u$d`&_QYw!tx|xs*PvI}kRLAM%YsWc6IZ8K0 zc7Ho98=J|NK6afz&0;FPevu{lW zV)cxjt$B()oa*OO%(K@vL4RXS%Ggxs=tY`4?hot7#-hZ?S@BkDf{@B7lpDjPzMj** zZr@iA-r5Msb!7=l`0&((e-Rb7&12k)z}X?S;4U~tf-s%}d@2@y zAfAk)FpU}Z8!|6i(SKT{DJpdi+Ss}<1nr7|_Nq-RPbsJfC^x(3xYbCr!kOG%Zc59klOim6TNXiruij9K9|yk9(m`xSo}{-WueucX29)A1L4Hi3|-L{m1QS5l?v!cG$` zXc`IHIM$y#C?m&utCkJ`)-&Rs-?Q*JE)H%D;Z`dp75k6bt>}*LsTT=Uj@&#LwC~nN zqaZzLeo3|>J`SIJMG4gMo$&LQ^{)qXp@Vn(g4G-g5tvp1+V z-z#Rz*_g|D90h?Um!_n-_9yu{c#KbLkAR=#xJ?~=eEhp}tS&{iwA~TKdFygj+~|*9 zryJfdxuPUL9gV+r_p}j!OwpIy{o|~8hS$#CKw7+V{mJY_{O5q%%={T5A|F3}qDF3O zBLt+UA7rckZ!N$UXe9!sAD5FYS6qsw_QuJC#EOMQL4tg60Y;k(Qdp(QmDxfp~u7|8_xwH&7osGrBNK>m14$?Kdp2B*bh?*LE zC4+CtIe0`AY!z*r=Ot}i$PM`N(CG(aVleatToX?_0uyk7Sh&#x&CZvY)a-+IYaN$( zsED2ISI3EtBP00uE`IL&sL-eV&Q=l2IY%vaTl}_>kg+_RAGVyuo~EtvqQ$T%uDQ7R zM?AYOF+0W|i3mvt5o6nKDNHKXK^@3tsq{K#0v<;kvOcZeuPbOXMbo3Bqvu1rkTcNz zv^f`s{>zOJ8k*qC)leGg0%Yi(oBr+rZ2RL!EG+h2v>8_Lw})^aP_7s#tgb{L;L3?s z%K73>9lCFSwtHI+?)vdODMWTwJVBFcMnEg006lLAKBVt#`)O;$QH}=3@Af&me}faI zv(+u^{n_97MfUwcd(qBN9(m2p1MjElGjcNu9F5a8nT8UikOf>11a?8$y}Bc3f6&;B zI(ljfB1)@NFM)C5C29;Vk0(-@S;T_wPYha>5cE2(M`1zF8=BK>@v#>WCd>m8Qp@WZ zium2CuUmc7`eT(AXT9^lXkQx(=4Bk znZ(ehC`>IUQ>G}vos|bhQb%BF7#RWgae-c|{F*{c*-B|$6%-@@niK`+gDp;tA)4Ls zTyrn;WOJ&XNh#6#2EL>r9u7V+3A*WS=tBFx?H@7bUSl}z8S}fB^4wC0ngnX?ez*4w znL_By6H_j(NigRMx@Ixk+&nt8E~?23LbFm$7$Hbc&o;O8h%JclwYIgVBvR$z6l~cz zJn<&q*Ze*=N0&dzU}!@U+UoXQnybc3)=I$tg3x^*FTxiyK2mmw05t_>LNKBs{Qe$M zWPI~(QmRlx9+WK<>l*7PHa^)G@vBPF_2K_8b(K+3{Nb7skj|w;P+B^rL=cef?j@wV zOF+6rT1r4bng!`tx?8$Cq`TqH^1t_->jym^W`6VL^X3c_9F)SlZ}2Yoy>DN;x-wq6 z*KqL&t9r%^4MitXTbMX~zga-Q$5-8O^1M8-&QoG>`q>yME{p>7jj0V$2Z~G}3`42N zKR8K*Mmhv~3fFr)uGR$nu6klqQ_^Wbku6ge&I1UM)rlW)%gQSZ>SVLS!oqqfEuxJ5 zw?$Zey(zvxPWmdkm$1V=(vXWfg@uG5pXOnP_!6z8ueYt!OEjTIV_vxYsdm^|fycVo z$3~0_@dDt2n4A>e$jK>UcYH1+(dwX-2T-57)l*s;DG3DyKnH*k^77G}W2^#YjDK6t z&oP36I{3F!47Gj&q7C+mISK#dvuxqx6w%#{ytP#+C-=4b5izKPd!it`D+lnL`*nMf zwc~8NHOMB}3k^||mXvb2t-r3WfbZb*OBp`lWLo@|1+E1r_6%>MtG7jTl;tb^(1=`I(dbd zZ((j}*RNKlYtLMp#setSsU86nbJsz6XalEf%)$d@#2A0~lzQ4%2!Ck!zKUD+G%@ELEtSB|63vn+PMDPs8iE-pvbtep6mSq*#+<8uSDzk z>B7F`>FWdbWe(d*v@OmjHl@j+idLM}f})aRU0F-zFOja3%V z85dWTkmqR>eY5ZR1@`IU72EiXk|84pbJY}pEURbbY|eK};*yfx{BGQXb9&8=p=7m( z-{oGzLLYz?N}9Df(gQy=4J3IH_syNDtQ0#81HmiPp-`OL4423{vd_nX5Zaf`{xepjsL3DIkLZ}>zl6?<;;d7#{CvnTb(4`VN{mN7TtuK&TM)ntk7ViWNtyx-* zz@_@r^*WFpO@s)|mw|BBMB8S5IBGIq4f2#2@S<7E)=|?t5!EMfpFh5CO4(8+XKe^J=IX?2=cCZiQXkyu6nlUsS)JS3KRvw&D3Cc!YD( z#FN1VOf)0AlJsh)rk(PJZX&v9ur##j&;D(%-HG_&;6%0>dLl1E5|5^@8O9>*orhlYNBNH}tTi1`Dd6p_DR3Bw;H=k@<0T9UJ zjFnDRR)qsdElz)gxT9}fDoKwh+@tAkc-HJ{p0yt=Ftb=aDL%_N?{`uY`NpA7GKIqb zhTOmI70c+34Y53bri{NcI0Nge|PKEjt=IYBPx~j%$+=ACZ zY|y6u1>&HqkjRVyTbt=3udL7u6+8FF4BG-%nvxS=t+&io4_u%zepcmJ{gjuwLlM9+ z`;Ohj>CS+&gZ=#A!5FhQH*ZJ+gM-P>(9tolF#n43>!q@5_zM~cJ~8>@>x+!6jG`JE zw)DprjI__oUY+@7k<}iRro3UK668zHySjnvGJ!bFTeUC#KBzn?y_mmG)Omcj?J<6r zeiELS*7SmBTkn$G;Do{YXA=dg@%zTi--HaMF?Ez@es)!iDoMIxe+`8SA^Ld5VVD)OzRpq>GP$hYV$Nm3fsSEwj<>x?!H6NLS22}jKWK?E?T4zXMG zeGkr$b1qN>5Yv&A^dAt+w~7a*;p%+1;T1DQ=(E_q_i-O*xR1nJe%|OlXjJ;a5EB#a ze9=w8;2Vj3Tw?!)t)*dbbQBfNwsc63-}{Ft6D}nRfp1n??NOXUL6V$E)$9u(3GBPj zT?KpK%F>X~gQ7X16|L(|Q@QWJDooAw&PT7FcA@wFKBqU+7|~H0bQR4bx7S;zJo)(+ z8oqDo3i635ec|)uWS&}aT8_^C)W+7^B!!F{5Kob!1=GOnC0f`0mw~DPF6n(Q?ETjRQcPW3qV!ltj(5=QmRfqM)7&)& zaT3bONJ-2-@4VP`sdSoO8D;OFWxKxoC-0SY@z#5&9~v92$jQ&$J|uWm=l!I-c?mqt zzNNW4TK7t3#L}v7kB{?R_X3*OD*#z>o^CXuo7hyo{bwy=fz++o2!{*atp!L=jc;h+ z4wndw?}o1D*nk(M+% zAP5rvN#yn#{hvx9VODH)k{YL&w=P%d+o8le{d%`Zz#SpwxzUDCTjD^DkxZ_cNu)3` zcruw+tb`_UiyXq<{mrtZ;L1~UmjbM~~NaDU-DTp2l%$Pv|s=F|DB zi$d;c`^C+!8GgW_vwS_nlac@+@&u}x#4*TA0P7ug`&}tT4>ms?ops5nN2@W3$*P9L z)PP9I_8?Q7TwIY~;c@Kk>(*LkI z@aRsxyUGaZ_f`q;;);$Ir;W^s#tI@KZX- zAt9-$Y3on-{(fE12Tka&=4L{x7Qm)a=5=)W-8R^zuqdcA6gWhotuvhjN9lqc0f!1P zkRY|oeXEZnOL>zPChoRviE_UwzLYq1| zGb^um>HH(ZoIHZE%dI#7a;NjVnfA>T8J8il@i=`}gnA$__eRem+oDnw@tMhhkG883F(5-@PLyCU3t&q@%Y$MSyiN(SB0rZMqCqmdkZCAd5t#LARH$>2D=7uYG z6$Szz&5{s`+4cl#17#LuZazN8!}FJ{nSw*BnGbqG{y#o*%VaBQRkr>lm;Uva@q4($ zn=g#9uj0a1Jt^U)R8m`$tq)79H=I3TwI zM8BjzFEkCjjL~>JKOuUp`bp;pSK{4kyqEv>+Npn%>s51G0KEVl(d$F+U*A8S*3+Dw zp_2AZgFG`|fzW$Mt3gCx+mv48WmgMyWCtzY@ArrTwlvPi-uXD} zOKhE*p7{n(+g1)s^f>(Edfo+YVPv7<;maQ1kqQH-Z}U{I#VhCzJIm=}&h|7I!2)Q9dz`rV|?#mi9 zrL0_^bR>eUDBg|A+Z&4Rs&U@CEnd~TKXLHcEdWV>%bT=}0>;16nwls;8zB2io9Zvu zD&G;%sfdbI934-1rOAL3E~)8zU}3Q>lM&y}aZbLo_83I3*X$Q*U%p6#iV*P;Xrfh2 z{MI^79bsh=R1_dPZSwzC&StpYU#%;&iz!>QRe44rBc{Vz`yQgC=O*;tF8e(zQKq0F z13MpIPfLv34=(wI`lpn~0AnDSezy5tpmqLI?>#4FF%mm-bl;iK8iYT~*EM_u+A!FD zkKo5a4FqSQ+rY6nqPp=GI#?WrN#Et>-J*Vm1Px{P)x(?Btp8U)l*|x$3QuJ6p|FOV>ZEdNZz8K-f@Y`ZV3*2IO2oY;0S_ExFkF zw@;lK1~ffnRrBFT&OFY4C3gs;BpvuGXX3{sgGA}q9KF06ULy3gjj>JVPMZc#j{f^YgAijL55 z>1u=_C#S$6z$4EYxAEScDAS^rwLLZy>&-HFKhHPRa;mx+hs_z#5PA>;5KNQ}K1y;}4 z>$Z&P?^bH-O+sGx7S37(s3Sxq4r^lt7F|g)ph!>NF-#!qZSGk?^9#WCd+ zwCZ~z8f5D->-q;AR)!s2@+Y9HB!=Z9^|pDLa{B`h4MD%^yXJiX4a2L-?YcS0>RDyt z;HQI|8kPzz=SFR6yyzgJ>=fPl?J@vtdKE6|d@VKR2yCkdtRJwvH>0PLvW%zk5`ls{FJd83Vb^|M&5*{GH*!UgOkBi5oQ9gJSxYf8HZV2xPcnDom%W$j%j@ zJfIEBoZnzr&uiZY+7$9(V#ukdO`bAFF)F$e z*z^w90##Ms`Ohr7_K!5|Je4Ow_^{_YZ2mSW?s-)yQsdq@%anAH+c3vGx?c=9UoZMBUc=6OaPu?3D(v{@N+vwzBqXscjSx-{!g z4L{xN7eND!M+b@cW*pHSS(fL^S(&U1-7?)W}daz1V@CW1D z3I_WPP;8HON|uF(={XO*K;qkz31b;1awE*n65}@Psm-cY2ahQ{{@P!Qsm%H_9C=E- zpq^S}XbOFUfH7fWT!|10nC;1a}#msgX1EicNkoy;yFIsaAcs z)(1mjk^Aa3N@tWYwc*ASv0vMhPer69G>q{G)c{Ls`Wp!|%4;mC&a<5ZieJs8snXmk zH#>>$yvy(;k*5?a#6f2dm=wt{Zf}{1B$|D4?4w@cMdO7dA&Fg=RxrPy6|;;E(Z$oN zlVLYDm@`wo{LWqhcmG9sTe_3RG%`fG9+PAXGl`HN)&UJAf&hE2ZwlB`YWL89TBGHN zEi%Sev`IwEI8Mb3XyWo&W7NT|F_v6}4Mw#LeS=l^yhB!o>8ytq#_Z<1kDyiZzJ$Pt zJZmndvx0w~@!EcqYdpmd@y)F`NgvcvaLTsmW$m;d&Q$s$NvedX85gfLNYJg5d)^ES z2#v7TH;p6wnjY4I>+%q;VaMts_ki$Gz$~aYpOct;m8fhI{)6qy#4bGIP-hWh_F)#r zTpiwmqKxI#f9+T5v{PKHxCh`k`GL2eN9G8=OV@*wj{?|*W>rB%Hu4k05LWB@5(No3Zy&N z89A~=y#X^BS`cXOf1C z=2aM&(Z^kUn*4qnN^dHz&(G{VICzIqjto{#C--2rQHH=L93)vza(g@*31C`8z`lxL z*F^~PK=~%_87&Q&O084VC_v*YEHOrXzlPK^EM_+fioTfUaj>M4Tj6$R8Z%j}OT3L5c^*q9Fq^S4&J43{-lW$y;nKPZX3QUJV4#;$K#Lkyd;w3t z8HZ19q*N0P{nsVE$*^el&gdG~&I)Nfx?Io&F#I^zvmym0YXS4TedH-4gsRtmzwiH9 zm9v)%Ssbhn9BvnJfu@Y2yF5U5pd%;;-F}|h?H~!b1JtvL$9x8|X0cIcD%0cQ7v_?s z$)v`OT~E(#X~BY9zG5!!EzWfTV^mdX`#`mBI(oOyw5(Io(+w&EN5Zs~vpMuxLkmy0dgWcKvCNnl&R2BsfapMEHIZO1tDAq3%U7o5TBk$J z=Vh6{P%RAFT$PB< zx#Iv|(blsjLKKp4t&(UBPEc?l3PE_6iMq0@N7E->95^~C&Z9CZn~gX#OgS%-Q1D{9 z0wYg#BF0^jn~H=PghmqStnE0yF_fn=OtxzduX?Uhg3KY+HY<9Y7B`T>m%N~D6*QJh zJQR;^qtZ|g%1jH0*!#|^M7ULT?_PELCawB>ZfMUJ9%M}{rhYOV=*bsrnZ$JGiOFfm zH{I4N#6W854cuGv*X>;^k#98A(L+~k5ztjX1re{E@OKmE#G&rG zd~bP#pPSbKW>M0**7yl39+TRUbMI+U$!o0g8jetmCDblF_@$JVLl{A=Y#{+%#T{>9 zhJ|@DK3-r3Wm}^E-OxUHj{%j(vWn5xv)Stx3*yS7?qI>>EOOWI>(Vm2avC;DFJ$3k z-yG4n2{JK&Y?%j_GPS!^ZTi>X`8GLr=QJ6t0d4owi#%qz5%Zhnr*i+5_F4Sr+K=+~g zON9Qk_0;hr^~Lm}QcGoj#tNQ1jjBnC{C)$$3I+>`CI2gOpve9BnzW7WHnD4AXo%$T z0eUwLjIPH=l??0s4?L`&K2QDda2TRE-UkiFi2FdjhGEpisIh)B4&LI8gsR>)3x(^x zJJO!QjmYyhYdSwsKY>^av^UeqX`!fF@>t7XkVa0=-)8F-nidx&dCa%wI2pbE@EVVj zGMYtLo|`ngknya)?QL^YO;7Y>?Y9v#M`zX1N<+tBn0pm!R%hrJ=Gmy{{?=4D?T452 zkn@B_zKN|^1G9<*q6uB*{4+x*4uY(L7B=&c%NNMYi}E$dCJ#U7U9*gAeVN}l%lgNX*J-g_O20b$!ann5dnr%dQnsjVm4%b zSQ|9i3jC>3){d5^ygBZy``6ZxYqLp z3Za3ggql4UgX=2q5)mWlF5?pZZx`ScIu}!Ho`|Jf{%%_)vSz@|pu63xE?;4A#<^EP z-1l!tlp+v{H(X1T1U%HW37G%n<`wzc2=pq!y|OkdLO_$rQ?|#Gn@9(~aAHfS)nwV- zt|eQNof_OC=uvOOSfRuS3vT-?#VD_2tBA}tS&(_EGN)hia_i2KwIiEEmKrm@AlwvI zWy*776ZCtrnbBy?MYRcbn_4??I*;o&LOV+_X>nbL!yl(bH8aHcg(!t<>I4Lqe&A0v1U)0;g8b+J_33i|!8u(BnrCSalih>oDoH#K zsQP|9n#(A*=95<3#jJ!-w%?yyr@x1Bf7Wb;cU5U8EPF|dqm3wD7(X*Gj$ZCP%#&`h zZN?(Qqe_?A%`#}YZVSU(9)s$!#bP@5e!E_)DK2)NO<@hI;rgnE4-jIlx)8S#W+My9 zY}Q-^aoELEQ#woMylp|1q#x89wNdy{Wh1u$J1 z)YM0FFb6i2OhqsJeap~f^LC_It-O@ssaSY(A>GybzDw!l6M$GJn;QEu?nsfp2*A_wM%b|_QWsWQ9q#r8dES#84#HWxpE3nZr0 zUEkgnKZZBSxaQ~FAK8|i3oLj>)*3@7nMixOzHe+-b&~R#raUV!GZ0-njw8|1*%hau z?(A~U->$usLqo)FfD;6Z$p#j2MIGHHb!~Thg8YD>fB}`MX7Y#TlQrAl5}pS+P9A{RsrbY%|K>BAb*N~yqCdyTb=F_x+v_`zqg);^lLih;%09A;*hJVuBnLm4cH;%ZMy zR&y?;<|*M|EehnBNC!<1q)~0WHS0qd>++b1Bi9=1|55a|h{*P+LFMVX+1Z4W-p02dcYoP?D!`r?r~m%IeA zswu3zqVZOB)%^X>t!fgze)@0O4w|}OSD*{>$PE$~*jPHd24=oAG-)R82_F&>P<5B1 zYM(HIdo*97)ywdT!-zF4*DhSNxfr>qL6^z?z29n^<0#G3+KxehZqN7rzh{2{3{MK@ zt7uw7*hbYx53+1MTB10tN!PnnLvmnFyXV)g3XDZ<*{Og*6p<53zvu{fT$!YR@-r z17wxBEg77B>L^%ur3jTk3=U0h-{X7r;XhF-=nauKdQIEDT{57JFX`~0j3XMEw9$aw zdnp7c;FcbfHbLwA&kLMkHYMT%-6n(QW3g2lXKRh-v5X=Gh!5V1w%>|;fyuoW3j%QB zuBiFDQstb`a#4|E5$Wvj3PN=TAxs9dtw~jl#RNq=A3*rJ%zUCDf7|Is3@ zhIWx;5<0v_bVfi_x##cnEH~T1i9tvZ^^9kpQVSIh68xhqClXmE1n%>0H-A>|4w!-E zOSO*me@gsd+o)}QFl^1Y=LxcYo2mbBo)!mbl{{`xP$A#V_H7|86eItt#$9 zBpZ>{uzt9=ml(F=aI{u`;TOikWq{i62PtyBS1p+feRTa}{3(eu;}C0_F=fE_eL_K}UB=GVH{NtPaG8T{*GU z_9Qm^C{t_hKpSSqCiX^qE6{-64t zzsWRBp8bb)hbm*JQS$G1BX5&kZS=ACoGzU%vv-IbhjAmj;P@Auzx4Gv=m&9_V^-@$ z*0SYWE;$*y<78INxRPovhdgck=zj)$(`lv=m*RpZGUZz}Ykf{oH*p~KPhfvdl|JUw z(Cbz3S+Coc)m6GvIvz>&ie z&(Ce5T16JAfN8vUYQFd5yxE??XS1m6y!!CWpgU%i5)A%Cd_SMZ1(^RO4ZjblCT1KSh6K`h@~ywt%1eq(ROuml%W5s9qp;UQ*?1 zO-r79G_0GfzxeZA7iKALr9i-GEF-Y8<^<)rf*`LiuiYIVu8d9Cn?60K@nQQalc^D> z(2JTLPS7eM3H7+w7wXKfExI^JXR#ivUINr7MGVAB{`8zwSUoN@h z|Jk~rntL#(dZV4agi_lQkX<-yu?1Q##d3tS_x@Fp&MtYz&V4MqV1XK$?5VVuSv9r? z-8B>;Qpp?v&~U`(usJKV=~mTP3DC&}LJow(mo<48ix6`YXOSiV_l!>Rxo%ZteZ;M+ zu34S0|2bf{v|%HZ{_jSLg=THqqEq+RF*)u4+e4an!V-2npHhZKLp-0!k5sc@_;Yp{ zfZ5xbn&H3tFf?EeCz8>=c;4_2U& zd(sEn$f%F0&K`^O$Nzqu9Qi5unUj~OJ3ND-pau*KOXr5{_`$7_YZ>?OkW}LcY*yuS zLbd1f4=ak@mzIE57dd~Zf>pp!7@D630K8 zr{WMi^(m&MeCFM?8g~{=4NbjE%K@G)vy=RZ^`sr)@|LRCqM+8p2*x9geL(|Kkr&G0 z#V?iLZ-P*XZ03j89ags`&Z-Eg^k3S~pY}?@%wc25Mv_s@baA9ka`Ut0$*J2*Rt)}Q ztn@Z{j8)+V36Z4NUTcVQIA0fYU{_1j)F%RLZAf|ona3~gq+ZBX_#d}bO5Y&Sd$qTp zVY!}Zk5c38!RyH{o=3?mo=Dg&ua6`D5j>lX;8d9nUn4Dl=qDaLfhXs{T@Ez9MWR zYK8`XCO`E1;Aa;#6PyJWQfzycyNz{q+gDDIZk`Z3TZH?GJu0c-vHfElLdZ@qw0)+R z*B=2U*vjdNnzTG~->$Mgt$nQk#q(o9BbLX~*k0ps!)xZKg39OAa#jRLvEFU7PucAZ z1Noho{uCk+j<5?Qm%!JjYff}zJR(0(S^a;OZVuR@n;FJRSo#A;eoBYGQ|hiM?FGFk zL-nlFKhkH>%}ZXa9aQS8qE&;rKoxaA1xb4vuhAW1Iv*5MyKw}a0!+C>+7bil%{ZRr zXl-knpCcz|4=q%g3X~?eJ7f#vKB9NsyCP`0C3Nm+CSWH9JF<>ycs()|Rn^-SJ8pHds{CM~=;dY;4Y17~IuuXl7UXs1rgh1!neCdi0(ad-~u+b+7t%MWC zWiV|$Fm11g^gGKQD82W@??W`Y6`#IJY`;>`@qBiOwBFLw%?HUu@j?Qt<~q1Qxbjkz zyN;*Ko!~6oks+Ms6YuW3%x~&S&ll_dsu?Pg$*$o7w}p6Xr>G%b_BPH?-Rr7Mjk{fF zMlH5T7}k7IWJH}kiGx+CVfzGWV|E%Okr(E2n|&#AD3?J`$ZUVbrq@S)Km8Nscy_jE<%~jxCtZ7;rsix=717cO_A!~xClP4S&E-?xj3dS+MVp`ph=$|s0SLH0 z4s)H?n9u5J3k_5BCw>xXO3`e1>z<(oI%XEk2Go|T3USYfpr zSW}Dv3`s3{K|Xc%*>XPA`k<9qe07#$;X+iY={#0~^`1O=;Db^pZ%t@V%3EEu=Rn!` zzMjB}mah*G!Qa$=|BmTaTtSZ?+9+FFRP0aZ(6thjmRP|va1)m(8j(90K4rDL!LvG(K%;V8e=j3Ac)WQ6Fak50NGWhc$e0 z8DVTv!8Q^VReT(`57)cFd>llVQzgoWV8;KRA5cz#DW>U-D!x$(QQ`+%qVDewV%Lq% zDO%rwNwYxil}PU9QgVt1+TVp>-|cdalXEnBrA8f?a}hG5!^kLC?lvYYwbp zy~VD$(iR2Y44@;e26ifdom2`k&THY^^R1q!4s!o2n{^So^$P+MMqoA`GR!eYc#>E=*Lt5xgd=ym`nHdpYF%LVIK@idlZ26y52EQt8ai&6#U8RT(GtIP!d1 za;B7^M{tYsBv^DpIyoV;Np`p1B0q@_Y!>c$ zl~Dz(xmoocRBtK+*yTJWP{JW06nmK<6?w$9;{Q+9FRwEKN`tHg+>)87XP=wrD@^Z^p^V#GeupyHZ))|(yv zm+}<+U#Q`hHM!|w&l}*Uw!^&9%Rc>h8e`4ZSn;jnMU1XHl@ss+h{)3|r!S)T{XM;C z=asLJJ><!mHi`R`#QMDc z)Y1~)U{M<;+^F!lgUx;r%AOheX-XUdN4D?^*4~leP0l4(Wf{(JmuuACrlyRAd>c7o zds?#q=IUeP39EnETJm8_!IDY8^!i2r?rwZPLka*JzmlT+TVWtRugbaDb3!5Kp4Cl7 z)z16JXpeIP`A~J)Kk~Mb$~bFs-QwMe!o!;hp5&MbWN(<)H-1Y0XNH1hPyZ*4r+eiN zK686)YJ4yQB^G%518m1Co2?ZQWuV+ch=gTLR7-UA0B6m|nlo3L?)-;XO4j$s?z!?f zve0^putQ)#h+9Yj+c^8_m!wBBcEMnQ6JMKY#}}}<5l!RQVw;%Sl+u}~`tt@emVU9; z<@ut;W}$4DjNDpQ>EV5H>J=@~9-^3!SNcN*<#$T3Ir(eaYpLPULI>ZusCA_Y$;_ zP{Z|+C~$0ZtmQv7ZCIe|h#QWvsEQC)R# zV3wtmD~JDfW}hInMK8g^CS2tiVVrZFS}G zOidi;uDoeSR;h97(}{GAL5=ydJvQfxb2ZF{%1MF5Cqm=CJ6)Fw89$6MhXkov9^o}Z zV1($F3thPlt3Q_VnXzn47A>VS#C+BzGJWLf(xqBQrcxn3>Y)@0wOzsH!Tgpy^XOWT@ACvOpJLWdr zF)30)c#m1~OHtyjCJcyS6qA)SFeenwxq8#~{{))jO@w>!SZs-Hh@U^xp-H^M5?jG< z+R62}Qq56p3I2>G{>4)GOY`w`idx5-XF?gw3boz;D|SrDxYW`A=o1@twpSC1UA@e? z!{W%eVpND-!502if5p@`xopCE&qKF}I@2{0fS!~jy^9Frxsr9lPQxvvKTY#-w>5z0 z{cA}ul3r0LQS=>gPp;iD>k0M~-|1Zp{zUaReiy8GtImle$roal5I&uF*zKnvpzBb= z395u2tmJQnv+d(o-S%%+8xT_gkAx{=8s$qe2(^3_xKnM)R|u1U7$D^+%W>DMD?Ycd zH46=z^lN}}vGR)l=NQ)n1%LZ~0IkQYWneH%+#WR7BIYumjk90Tk-j{1E zB&pvE>+Llx=6MR{)3s``F$I5@2g|etZ<9wOeY5(ahIkkRrF|zL&5@SJE~p>^{gOf6 z8FPW7Ssc=1gUN~>{q$JtN-L%M#G|cGBP%**$zQbyM zNT(6?aVGJ zYidql-p%1EUpMQx27UniuhV}{W_3j%FyFT`ZLTd+JxlJn(CK}r3dj^A@q!qhKXm)! z@jp(b#`=39A*j61IZ5!me~1gRC@JlBH&S9wMwC-k2hjEn!8X z`*-cJmKcbpIT!>XX3N1YfJs8yijVnJBgy}p99VFy+bt(=z;8vX?t>f0$Pi=QYmV}*N-O0IoQRp1K_;#7fX7#99IoN#xsZ7vTNtq z41X?+Ub25lCCPPol8v|$=>m37DBw_F$;&e70`^O?h zdSOo8tOjh80lMgIa4_c9y`8T_wh>1!%YGq~EvAp0PxBOVAHA%4bxam0@gs*Q*aL5V zJo?iVI1tYHm|P;K9E*WK7yz|i&@8NL1(+qJ`dnP&E!Vc*HAx8?FQ@Gt!{a2hfPEn{ zgS0yf_)-FviiaN=5fVSO*EUxZc)5+dftWHSBXEJ-$YG2>zl8)}HYBzHy8+>%0IEn(g#+7zDyH1BuKlXb~mB)^L1mLCY{LIDdo5!A+VrUSTC^C6IM3^Ox=>=ii z!s}tJCMON?XFq5P2F4D<0#Vhnur3r})=OVGQo1x$jj1a(u(_Mwno%c8e-rfz$gv1p zct9$J0;deed(cA_{f^m9UH9s|5q z*nhZmm;vpcf+-wQp?Ft^ywY`oeZXxcw~V!z|tA81+hz8tIp9h#Ss1R=nKnSE)e~UiEOe$Umzke&QhkJ4J89nupdNyi&!(855|5K5IN4Cn9W0 zylgEk-2<^iD@Gq_Mdng`-jk?x#oMDqSYKxWjm9TeB#X2S>PbZ#wz>nKe3ubqm=?dn z5-$v6B&>*Z{(W%gcqvjF&!m4(K<#0Zip{9Whu!LkRPP=OULhRZn$u&JhLoC4UyVd^!&aZGi3)mEsL^UjT{d8 z31qrOK|HPjQlCOWFnM8yd`%-nT4mKPrJdb^{t0xELh+N5BG5;pUV^u!7=_{&Q@;w`-UaJij_5=d=v?y9Js2Ro_$7~8$o*?P!Mipag(`>B?}fi3 zRw+*lBxQ|z_(R`G#R2q{fH%F~)~LMlu^E@E|xGj~Sv z)cHz*$As%ii4onumgjb|+ z=&wI4C-p8r9ml@spqU z)bCiGLj72kIlpfz#5okn?dJ@$@&42zesRMRds4)%{sk(4jX>&>aML9ak}-IE&RAC@pZsHgcu1m?qUh@S za;}gkqF`NYjoc5`un+57-V8k^6LwKJIj8{@Ds;SppKiOgCwi6Tj_lMLFgM?@gJv9Y z=#!#y6nV&@@^$C4Kav`Fz$X#$b+D({TsTw#f7^c7+Sz%tvlET{BVT7cFxYgNwcoe# zM^9SB$o<0W_+ZO5bo@W^1|MSe+jN4+SWF`SIvl6*Ly*)4;E72W6G-f=|sb5s!S6;kl0KOCL< z(R_|^E9ZMipeQ1#kyvr1;25#<%OITF6f#;P%pz|Trb#RV3@*MHu45(P`%)O$KCZZ` z3}QIEJVka%f2^*`(inKEbSFAi+?ZO`XJYJZI$VmH$xviHwkRC*tonG103g$h0n@Ps zRi~UHXa!M1!<>W?#%F%YI6V5cXG>LdNJ;y|s876wci4-oY*Er$7}`KNH9J%Vn3*B4 zu~KEQe%q~+KJSqg*$ptLQ&4#U|1(u^Vb1I>eNi#{4Dq)dC0`~T7Qwt1Y6KTCefQw# ztF=U$r3K7%t8lGSiM-O}g$w*nsijv<4pAmGY5dCHMN2*=Zja4l{S2Y(Mta3nrTV%N zS(+YCS;KMCL}s!iSDyo0^DV+B*{!o^k)PG?+WGdF>OTUat-&i@H4hjfv(sPauaU z9{TL;pC|GZ9Ti$aT=6&bv2~eOYq3|b4uji#8Teu%=yTsqR7LP@Vz!37nj14>;RdNh zwI!j&`NR!Q)%p)pPfM$Lh^7xk!La1R~wrhQ1qnlZ6o zugNFc*E=-abE&8&43OJbw$Lv4^rTPTjpE$l?t6ZO!a=Rn`~0ZYPQQ%VU2UZc;i~bx z{-OQKK~-~eO#0VJc<)U0P3Vd>-<#gWR{ZKfGH(vG(?fT9IA6;-YhxHMlQ3)SN{~9gyNdl% z$SDi|9Io)9bp?J)*6cp~th3DL!~Q=L}YkMv4dl=5gF@ zxj^raLD#8wld!VBy#8|((Oc8|JAeu8ptiXSlJ%oPqcI#x5#bvW9#PY~8u$|J=Cv#f zv^Sh#0Q!B^4T(wge5(UtvBdZ>Rz#g>+cQuD_PTR)-!JewBhO29bg~eo9(0S7HB+vK zpV;SjTO+TQh#3u3UlG2=^BWN-Xkw|BND8Pb#|5Ody}GTz=gc&Ot9MN-=!$7&7I>Jk z^Qh1N;C-}WmwYEd{1ZxTnyTx+H>;=Eq_A5A-X`D8}rwD-uP*F2> zVt+W3_OoqdD$HXZrUDEwnVWC(`<-GeKMhZN9t`F3dtLPN#qf1cK!a3j7{y4y%KhB2 zM7F>;1@=17q^t>Vs68g=C+k_)QnB>@r>U=iiYr*Q4HjGi1ef6Mgux}aYk=VH?oROF zAp{E^f(Cb&0S0#sLy*DU-Tuja_x=A_vsh;hoSvRK-CbR~cJDGrv3n28A-!YBtETUp zS7_Ezh24LgorKfs%>5pkO#tyCGxH0I8^1V#O{nI@%n%LG?Jh~M%|+^`(~1TnX|fe) z2xxIfwFrcg05NoK-}Zz0Um+*VG<`xdRy0gT^V5cH$cGHOF>7o6AL*JCHcnu=J4q&R z-HEBpN<;$OY_D-F#K8RBLcHiX%_wqLcUBO>)9naQV=ZtbYf%TqnkAeL3PUje1e|fB z*47uUC6$mhbAB%ec<#k;UdW90SYHu$MD08b8LsQ8H&@_sk>-~MlU>8-GRpEYC);H? z(@r_v1v^Hh2EES*c16WIKR}pY<@#I<(>$k_|2c*a<_U-%h$5AmK_7tkuYf0LF8C?uU_A#W_0D$}0v zW1&SjrkA4vzDa`>XF;vz{m+-Sp#_P*pY0=) z7CeTFsYOMtC^Qpi(LI&)Dl)70XE(JkS@(3zZY9prTxboaYSfb`oYpnpUrL|bpXOg! z&Sf+L)(UTV^B(%rEXyC;DX~wa8s<|dhI6aXv762dHQ_x;OVmo&C0~G^FVHbzU+qp^+oW%$ST4 zSaxTm(D2o+biTX*9yh1XqL}xQw`TGy{{sPFvZV(M!5nHb-?Y>70i0W1R>6hXpi6!J zr?whKg5qML1t#WbK0${RU$I|n$QHs;M|b!`+i%%$q|q@PzayT>ln#&8sxD8_W`6{z z3%z=>m6l7hJ}q|XxqYL{-Q)4*Va($NtC3Px@wEp#_oC~7`@0sU5fBF^r5bM?b3-$& zuDyCD)3SqLHg&cNoErPb+tWqY$k(q=g>T%CnzipyVkYw5q3DOvC68~LMK`d)_Ya3u z=D19dEPW4EV!wwrI{#5uDZ=xMs+p4>edI0E1``XuN|2mZ{3>$GT1bK@k&-dVDS-ys zUYlTF6Hx!W(pQ?rtpiin0UagYwH0Ah|uJb+qCfWimS~Z#O2apGNl0~!Vv?*@>+(> zK6P)wkH!aIQrmPkAR;JXO>w=a}Bj@p_KNkp6T3Gs7He7W368~Ro7zAObqvu+K9HDxX%x=m%`Q?;3c zOoFuuwMi3^+UmHSkygEWF2Y)6yBc{D(d-U~z1$nLHZB7)#+wzL zr_Cdv_ik}L3E5$dMp~^MbXiW{3H#y&O7?@+Zmu>qP_o>0@9 z9Dho7iUL?6F=-HmJqFj&Oj^v;Hd_Q7)&RT7L}wz{OlCz=>{&#Mn(nt5U?^R^>?TqN zIn9#S6+t*8HqOra7mHh~Czi^Iog{SX)8kFERe-jRMx zKFK8)V2Zl~lP#%^u{)6dgc38eSREN}eU@|orm>=pB1vIg)hH}EIN_(4p>X8V1DW0K z`k+7?XT-}oM^DxIVc=_#0Dgr8lyuoni9gsjU9K^Y-4taVLg^eSEsbiXwa2mR&Fn3I zP`T-8cbDY@&dIjVi}@dNtvhJqc*jWRub2TVN%MVd{xhem$U>FTgd~B}$y(7#+;rH2 z+|Pn5ZY+zOF^LnIYG<>U6Z6KNUwDJ}iSE^)7tBIvJ-@i7(D!hezsz4~SW)UB3fr8X z-4tDOu?AhYa4I{D5fsU!vhm~90*Ty7_Z7?H0Yf$6N|^39PN?*k*IjX#fgb*1wA}o4 zAQYn)S>}hXh?$P~>e?^f=O4|2GKNGrzNWWJUH9d#JtC)Hi;gZLEla(AI>gR1uinb| zL=UgKcunhHcCl#QMSV4E{B?gN@_^XtH{6*|+-QpGpkV-|m{?!GZ0A5D57r-+T(f(x z{EH_=DnA|n()WmoGF}(0;yw&k#uWq0@v>1q!6>2KR_LHM68n$Zyad1!or!?CNoJ_} zsk6vcBU>){r-oyR7IeC?6*cYn50Xfm$%JN;0fS&Yg(DLNCl3gPa7TTJPw52f!)Lb) zPK-f%T>eVgiuzX!&vN(4#<)IATC5S)2$%UM--wqdQbNLLowD-nx8YkdbMrNQV-f>M z#X#cRC};aW^2%yhwaYA)fOE)QA5PbuPgh?A?kBx>j}MDU5t5U`bwnZjuHPPVi#f7Q z0-$ejFFvhP2Tu`jT`5C19$H+Md5oMBwZPIhbhr3T0Q>Zm_L zyCOP(ItqCTyf*j zpCIL=m{hU6I0|_5)1Ld@eWG{d+eyxOW~f??=-=xul|EN%Qe1X#)J?$s&#jcCg&REo_bTX7s zwVe=Df`9NsXmv~T1JymBdWIenre1)oUgF%iMozTV{W0T)z!;lVf zdpmnfQxfpR9Tw{`Z5WBjn;pVq5mdZOpHvSOFFYm3jsW*^+pxKk!Vuq=haQg7f$b^P$@~yO ze^tOF%msv#g#zD=A$0@`da(KFq~l-n^8;qA5>#SR*4~uzvLN!mnKy8V=zeKS%lsBx^g;UAky3^C~>g1XJHe`YV+*@&B}5!;dGiUCmKH%4YfZ>(w_dC2Isp-3s8^@t?!LVN$^T98={{#mCfESh%fhpMSljFgVpo&`xUe}u>bg&mCgdG}M6B*(hP zNgrU!4H#nw?HA&(kIci<(=d3mIk)wqKce%q&>d~DWGE6VPWYG<4q{wZTX)|qhR9#T z!$m%E!IYnR6{Zz*1hMP3k%tiVtSzgizR zX7oPhm~OQ_s!Jt)OUF}^|I{SLmm_uLW3lJYu6N}vI-&eLL@7~$QMZBEz${pjCln>O zEtp2xw1dKjQNy#~`iO>aB{z{^HZYE3z#k(IK)CSvqztK}CQX#$0~)(9TJjm48dFT3 zbHSjs4cjQEfr2s-G)d$k6+^JL;&(4$d8bE={p6SKODyvp`Vt6FpWr=8BPBI~YJR3D z#c!LueckTSgpldkC4affIZ=9pvv*f+V{^6WC8>c{cW0q4xwa^1RUOZK4C7hmCpm`g z!FPB?Ma6fJg`BOvQl8JSQkq0lAexDv`N`zSI(oqIHO_q|D$@{sN*CCFp%|l<8~Ueo zTp=}6tFOE11*n%t=;9W8GnA%SvT=}Y@s}KvPc4MVk>aQr^e~NWjH*hot|)QuHs0_~ z@Lu5I%(@U*b>B$XYq3hVbILzLh+Rasqq_QUyJ<%2F{eeJiEH+_>@q008{pTiWlL2I zB%3@XgIKGf1Vg{j$~kwY-I8D!dsg)5Pg`#&$oo8qg*= zRz0khw|{i^56N$-x)MG^dUMPX-*N&r%TQHGIQ0&`AjH7Dgi|xnwQvadhAk>8v3ymc z^;`I2*3@9F16$10%L!nM-_!&`U(M;okLXk;e;E*X@0J>%PpV2T5PncDQn7mZ4?+j< zR+fpKj%NLvDmks@YXKJR9i1y>o42~Dk7`+d^B1qj_8Q+S`^q%UMZ}HgzqLH>#Sn?p zj!|k%RwCjSs*f#AfODIJ|6?GeqT8F%lHh;MxP2&uXoQwIWX_Dv78=n{JH6(o-Z6}n zHBOaf5wmz#z-lBC5>NLosTBP+nkGAL5wyTjJpwgZ0GuH4?XTqb?V@V;jpiVvddc)* zT*Zd0UGvzlP zL^~Bh$B_}E2Q zT9@a+dmU7eVl3)hhm2A_FiQ|kHQfkR#^dD-#yGp##*cZ=ZF!Z#_jhkVVHYJ=4jhJsq80&ZyRQKL8*p6T{_!OErE9H~eZe`-)ozna zRmAm+0bg$tP1ItIg)@j)#FLntJl*;kkc(D5TvqqFcq#mLM6YC_&ez8h z2VyuKCV5mMged2QGOxgBnNdovj*HrFkj^tOhV`_FK5NfzTcg-phJtBBjA>@t@QU55 znI0MOq~_Sh^UUM{m79T0x}R^iI1w?l1xi}+D&bHh)$&N;O#=J>Fu$eyZcw`3T3DgI zY2U>5;;Fc})9}-k4bSK}Jd$4dr+D%2WGS^6C4vS!a>kDH;5 zDF%0u{BqP8ad-{Ai>4ojzg;> zn>eEz(AxPErhO7(KcdO3{F?3}Oxf$6E!*w?#5lHbyH(<3n!=gda~y=mu^{hrL1q>7 z4M5HR(4=Ce*qgc(PKNJ98@j@Y?sls;{P7)KSBCg^Dke5xK409CUnOq zkGmlFjP9Pfe5xQ3t10+-e<^aNf$d&{P4@SVk=dQE>8lV*<EO=fRZ#I$Dg5Fx+lQ70tvYY%R zx*o)qzdkC{&;0~Ghi4rx0F&4Sh1&VPam4h31}Rl$>aunoz9H%kr_vf}E>#-Ifi~IO zV0d}qe{-A?_-)rd>Q>h_@0J$mi2_%#FA+`J`JrN9RW`VAoMoNS5ofm@|B`rmDQ3I3;`ah959TL@?vDum8Ie2U4%7dQLxYLEcmM3L#KxaIF9qI6U zT7_r)fls6q?lzCQHm}%<9-yG(jlBB+SHUZGH#O4!vR*toMmYE@IqxSK!H`sFfPQL( zM)XdO;AnCbHrt(aKIekUx`Bv<+Eq<<)LKU{&GHjQtGBhs!+fQ0n=w8RVH^w1DPYFF z=`lmsteBllkfe5{%p!|G{!pmV<7$8#eM?Tp<{$$+%*u_VV?e9u)hj}3c_|4kC8rQR zBQ|}1no9>7vQxGodbZz7etWtpqm5){Y1)0CL~}hfTHf|IY)-v$G4NQp-kLjyeRVPp zGHT7%)#vbY!kDYWdu#9OeeE!Sb(E-G&=97;`ey-}FY5wi`yLP3%JGmfRCf5XI#lyp zBdgCaF5HpG7HIUdySsd{Tn<_1qfKi{_Tl+b29zy0hxovn&L%UyafAc&!)l)#M0ucJ z__bIHG>-b*vQ6N@JkGLUoxhtykePG`2EJ7Acmi(fPIV1NpQie3l^*dy^I9J@oVgP| zsiG;EB+l66EXg_l5ls*UC7C!&i!}Pt%~7j>TtP==!|?u#!U<<@*@#tVetl{t>ApgsJ z33dN1&^f#p1;>W5~pnSKPyJ~2c;u|jBu zk0CdHQ02h|*;G(DFqyr?)hPO9ex0Aq*=(sS)9T&h>pso=h`mlx7y~DE7~bj=p;@ zO}23p0;?}gH*1@%G^Vj~qy$<7sN=z3lr7|s1>)3IHq(tKepX0Q>KlFeX0a5>5uSi6 ztEVQV@_@AlfO8&^3bsiW!g}4m$5ivjI}$Z>2JYW_OsM#rNAWW!w2GqjzTn;OPM}PX zOL}5RS*aV-?HqA@Q&(GDWc4PvM{6(d1ec`B7mtcGF^(+hlK8UV-uWw#R&6HaL}&yY z`<<47ePcBPBnxUa{?tMqZHAmh?|(H!u4Fff7?Fmd*@?ps{I)N;kJRA+B@YNZ4m
  • _Oc}O2=c&ccEeaTRz_>rxd$ktm_}(ten54x4oW8qb?}- zqor+ldb1oj0h?z|pQ$oely;~D;g;3<0#GhzV9IG-0c};Fp7Hv;=LHD;%z`(XBTfby zyYV@i{w_c9078}1m)#lH;Y;EUOY9@#z^xAbJUt1Ai!CDUc?myWnCSAv&Bc*`X?Lq5 zn%{Wz?rg=75##xmZaH=UWto3_(mVQQ?qYP7S&dCIvo$6=vzv4g{YZq?y6VguYTqpT zLa9`SAePaa*%1;`thUs{I+}+SKyJoVqV1}TvL|e3!V%XVhPMAqa4=Z`%7TD@$iXVK z%DBHea*;SvL;~8aWphhmmEWpW!c9M;V$PM(EDf?L5RsH)mT22&_B|Ul>w_m>O+zwC z+LqE%OVl7g<38DKK%VtgD|46ji``@iJD%>UshMYwV(fjRF(?eqOKn!?FVm_oHUJ8i z6dp;~^pC|giVx=9%DwAdx<(-T{C!C$oSY~9(Om(QM@D>?GIGum2Zlu){jB)#_z2>1 zV}LnQtRcShpk@fIsWZO?(2^~9=IjBR_?$L**gk{HUSPWxp#-;VgCmZgml?tKDMXaU z>k83tldZgT9a4MSVzfNXI$qq>pX)=!kVD&kO>E?vrOT?kOQ*J=isK`{yictH-0coZ z$1(h^HsU;hAttfRQLyr0JM+STf;rP*_jG6P@Yh$rBDrBQkdT!ml7idBww z<2%b+$Uw2hNB>w@>r`TgF006H8Onha4eLsP>AHlDv@U}I2$uvsq957tt!efbh@J70 z5Bi2LR%chn%D`__Ha<6MZ=p1Q&@!N0D$6eLo@0trq;9P~q?j7{`2*b&C|Kmaio9BM zm$x}X4liZ5|25n)5z^7mR?6sl68b(nWR5vm>Z#~Pss<%cP%Yj8fB3Mpx5s+_1tT6T zWoJ?%;$2Q*>cIfh1P+IicD6jQg;wU85i=6zd?Fg0q(b9s6~!K5F;86juy%W zaU4>WMiI?SM5h-Cv?IAu>)VSj)XTy54L$qMw#NKroBOVSke3x!ie>K}m^NO?mfgG* zwTiX1Xv`MNZyNoWLQZ>ivXfb6ELre&(G(=b4e7HKHx)@N+c3-wwc_oNg@$%r6C*mh zI%%=SG(t8<+^PqvW)C|x4PN7!K%x)nj?h8 z{6R!Ef=I>8pyefd6D zXxXcc`?zS?Umc>&exow{MHJfV(N&Lf`ff>XDM)0*a0PO3m$}4$^Ek(xw$1!kgS1X! zpMERoj$04v(Gb)=b>d6&-5+FGuG8Y}Hc>v62pX$oKBQ(Vow6zk;-RRchts%4F zj$hrgqA~yEKx6xG#Rjr7uRHE4jrW)^E_6$QpHX|Mnp$X@7IB_;>yjgmKbikw>Dd11 z`2X*YSA0;rdd`*=%T3R`<52sM-n*P%KSg?oa#^|-Avn?!{ z52#!vvb;G>c(8JBDs_Aa7cyhI@UjTzKs|3{|K{s|Pbm*ec&b+U2C*^6#%5!yU=~joLw)q#f~{axKR7ofWHH zjHD2Jh5tlV0FW3T-GQK4g8bhw>tCexT$3sTGqPb>{1V#V(aL4ILCy9!*R7{WM6x$y z%=N*X6SCZs1hxX+vH3+AJjXem*n}T{bEdbnUHH=LZQ6n4CzGR_FlzkDqT_o+nh7G~DNJ70`9O6lLFHU!Av^fToyMI9{J= zn}VlED$=_2_(nvW``u28;$$+TtZd1&71npRo_q>cFJu^RoQ8F)+(=K>Q_L?MA6%O~ zvWyBSom2RhP#n7`JngSH?!h^RXQECH_8sF#l`qnZYN7-;7kU6g1Uo(m$WCm|;1&Cx zpHdC&-@k?3mQJ{kMo5+w@jn6SqzdD3P=p>X(2GuL%}+R`X!j}%~LWj>H3 z$B`D8Kc0hv``krBIX5SJ5QJ|oHw@Z+Qgp!^-M!eR@kp|et9YdX!;utb=dK4hLUyD4 z;PCNNX+wW$^cu=!nWy zvW?Z?3uN3H{uH~Es^BAA3X$r>5gq5XeYt8JgzYhaoVqm!N^m;@AcuitMK9k;Yg<`@yhw!E(gk0wa|#__a$0ofCBm1u;OZ79gk|E9u2_-AXKtFR&a8$ zaS|MLcL(2fm#A}R{4%oY;_eZ8lNdN`{IqpLM4q{(^KwQ4m@=4}qgq*fgZyKqmM&{b zdQOJp@r)iv=9k-`S7^dIk=gMo4casqmUI|ix*l2Onq81J2DEnH?hQpVEkWHL=Ms(ctE=Vc z$*<uib8xH=VjVV<*5beSZ7BY(?@2&kgiwvn1VtV z9^PIqyGc-cCD%>tSOiD zrW;e~4)t5Lz_2kDT!K2cX@B}XF9*S!wG+)y`eDF_c@?cp6mC=DbFll+es}#ab_9kL zOl|iHHT~=Tw-GBJSPg?0%ATCBQ97C7tgnYyBJ{vq=yMcb(EiXffV~|V-u9Dbb$muu zRE0e=zp5$Ig8Glc$t3odeqjhb{oR7`V=4>`Yn_!qL|Ij3j+i1ZUvp~G7yBDZ#Mi?PlXe(XM$Y!-?^kA2*`^1e;W zeTm%36$#add57|)e}k1Z_4a~P)^~rJ)T_+|j|dEzsZV;=IXtt$!z1hG{H!H6W1_L!K6v)#Z^A z{8${kQObbze#|C3Ua{9fnMs;CW3#(6s@Y?UWL~d+V{O6i@zS~r7*r1wBt}iw_&pz* zcbXG=#f>>iwZ0ueU+*65n+0{sac5XG>e^6%gg~m+);Yi?Fu!_o*@zK;(2_=8&&Vhf zc@BmnBXYPqJBqQV^(`dFNL0YJXz!85A%Fxcs6svayT%ybxFyf!l}C~So$qr7Lk0^$ z{;|31G1xN4|J+sd`I1)xN3_}>=QvIq`PlgJT2=UAa{~kGQo@kO^X%O7^=rQ&i*jt3z2P+2jP&T4VM*=%eS?;h>- zc^W=;TfeNBc0L>t0>-^tTP&WA0>N32=j>cs?=S+a>lQzRLVQ3XR-s{7+!(a#vre4j zzm58Qxg<2U+S>uSiweGtda3ga(}>~Z#g9Sv*o0<65!Y4D0&^EugAx!) ztO8SJaCBK1`uZWWi_mn#q(a($>~LAwx&2A>{p!x)U2OaF^#h=Te?Iu!(SEgCN5wX6V>hyII*#@tgPSv@U$lbMsZ3|UPFcp?D*>~!p>5do}JB*HUbkIXPZAH zWtPaPZa@jeS?IV>vtj|;qlZ?bHbaBhQNR!y3`ta~_6U^FSKj0w+MHVveZliRda(zB z%5%1o{D!3phAkcIKC}+(aSiGlbGD1HCMzSOJAE8Jb{6F{lA}LMk?|ApL7lZ6piWZr zDZQnDg1R3qdiHIrm4OlQBW%wOp1GxS_ca^*_6ZBqi6HE$#}Px%XHT*F8WaL>t4mvX zLes;W5k))qf)u@vf9t2ubQa0#1#PRAgFv~ZJA!ZcM4Rm0MT`G6b_DiYtrS@xsVby& zrZ~2EZpeAk&H6|F`j_yKXlC=?Z|5NSM-2MU%V{LrPCyDc&G4V}9SXQi*PaOMO|8on#Rx^s@vZ*N5Vg@9; zljLY$0g~c<8)p+DF9bc{#u8E5{XCL+d1~8p|C4(Z=Z(@_iES(ORUG;n5#qH(Y|JO<%!nw8^ALcz)uzFoY_G$jt zcJ#m2V>?Lf(Fp9b8UOoU=%lzq_WxKKP_>`_V)DODfYW6DpLsiZssBSh_-BaI!n+r^ YX^ecKoFA3KuYezUX%(r;4<@1i2dkb=i~s-t literal 0 HcmV?d00001 diff --git a/docs/mxfp8.md b/docs/mxfp8.md new file mode 100644 index 0000000000..aae8f4a7c6 --- /dev/null +++ b/docs/mxfp8.md @@ -0,0 +1,190 @@ +## MXFP8 Training on B200 GPUs + +MXFP8 training can provide substantial training speedups for models where the majority of GEMMs are sufficiently large. MXFP8 is a microscaling format from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) that uses block-based scaling to maintain numerical accuracy while leveraging low-precision tensor cores. On NVIDIA B200 GPUs, MXFP8 training achieves up to **28% speedup** over bfloat16 baseline with minimal accuracy degradation. + +> **📖 For a comprehensive case study of using TorchTitan MXFP8 to train dense models at scale**, see our blog post: [Accelerating 2K+ Scale Pre-training up to 1.28x with TorchAO MXFP8 and TorchTitan on Crusoe B200 Cluster](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/) + +### Table of Contents + +- [Requirements](#requirements) +- [How MXFP8 Works](#how-mxfp8-works) +- [MXFP8 for Linear Modules](#mxfp8-for-linear-modules) + - [Usage](#usage) +- [MXFP8 for Grouped GEMMs (MoE)](#mxfp8-for-grouped-gemms-moe) + - [Usage](#usage-1) +- [Example TOML Configuration](#example-toml-configuration) +- [Performance](#performance) + - [Dense Models](#dense-models) + - [MoE models](#moe-models) +- [Composability](#composability) +- [Known Limitations](#known-limitations) +- [Additional Resources](#additional-resources) + +### Requirements + +- NVIDIA B200 (SM100 or SM100a) +- PyTorch nightly +- TorchAO v0.14.0 or newer ([TorchAO Installation Guide](https://github.com/pytorch/ao#installation)) + +Note: GB200 is also supported but requires building torchao from source (see installation guide above). + +### How MXFP8 Works + +MXFP8 differs from standard Float8 training in its scaling approach: + +- **Granular scaling factor**: Instead of using a single scale factor per tensor (tensorwise) or per row/column (rowwise), MXFP8 uses a more granular, block-based scaling with a default block size of 1x32 elements. Each block of 32 elements shares a common scale factor. The data dtype is `torch.float8_e4m3fn`, and the scale factor dtype is `torch.float8_e8mfnu`. +- **Native hardware support**: On NVIDIA B200 (Blackwell) GPUs, MXFP8 GEMMs and Grouped GEMMs are accelerated using cuBLAS and CUTLASS kernels exposed via `torch._scaled_mm` and `torch._scaled_grouped_mm`, achieving up to 2x speedup over bfloat16 on common shapes. +- **Dynamic quantization**: For every MXFP8 Linear or Grouped GEMM, activations and weights are dynamically quantized to MXFP8, then a MXFP8 GEMM/Grouped GEMM is performed, resulting in a net speedup. + +### MXFP8 for Linear Modules + +#### Usage + +To enable MXFP8 training for linear layers, launch your training job with the following command (or alternatively set configs in toml files): + +```bash +CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \ + --model.converters="quantize.linear.mx" \ + --quantize.linear.mx.recipe_name="mxfp8_cublas" \ + --compile.enable +``` + +**Configuration Options:** + +* `--model.converters="quantize.linear.mx"`: Swap `nn.Linear` with `MXLinear` to perform MXFP8 matmul. +* `--quantize.linear.mx.recipe_name="mxfp8_cublas"`: Use the cuBLAS-based MXFP8 recipe for best performance on B200 GPUs. Alternative: `"mxfp8_cublas_rceil"` uses round-ceiling mode for scale calculation. +* `--quantize.linear.mx.mxfp8_dim1_cast_kernel_choice="triton"`: Choose the kernel for dimension-1 quantization. Options: `"triton"` (default), `"cuda"`, or `"torch"`. +* `--quantize.linear.mx.filter_fqns="..."` (optional): Comma-separated list of fully qualified names of modules not to convert to MXFP8 training. + * Example: `--quantize.linear.mx.filter_fqns="attention.wq,attention.wk,attention.wv,output"` + * This allows you to selectively apply MXFP8 only to layers that will benefit from it. +* `--compile.enable` (required for competitive performance): Use `torch.compile` to fuse the MXFP8 scaling/casting kernels. + +**Hardware Requirements:** + +MXFP8 training requires NVIDIA B200 (SM100) or newer GPUs. + +### MXFP8 for Grouped GEMMs (MoE) + +For Mixture-of-Experts (MoE) models, MXFP8 can accelerate the expert computation through dynamically quantized grouped GEMMs. + +#### Usage + +To enable MXFP8 for MoE expert layers: + +```bash +CONFIG_FILE="./torchtitan/models/llama4/train_configs/llama4_17bx16e.toml" ./run_train.sh \ + --model.converters="quantize.grouped_mm.mx" \ + --quantize.grouped_mm.mx.fqns="experts" \ + --quantize.grouped_mm.mx.recipe_name="mxfp8" \ + --compile.enable \ + --model.print_after_conversion +``` + +**Combined usage**: You can use MXFP8 for both linear modules and grouped GEMMs simultaneously by specifying both converters: + ```bash + --model.converters="quantize.linear.mx,quantize.grouped_mm.mx" + ``` + +**Configuration Options:** + +* `--model.converters="quantize.grouped_mm.mx"`: Enable MXFP8 grouped GEMM conversion for MoE layers. +* `--quantize.grouped_mm.mx.fqns="experts"`: Comma-separated list of fully qualified names of MoE modules to apply MXFP8 dynamic quantization on grouped GEMM operations. Any module that matches the FQN will be converted, if it has (1) experts represented as 3d nn.Parameter instances (which is the case for TorchTitan MoEs), and (2) a `torch._grouped_mm` op performs the actual routed expert computation using those 3d expert weights. + * You can specify multiple FQNs to target different MoE layers in your model. +* `--quantize.grouped_mm.mx.recipe_name="mxfp8"`: Quantization recipe for grouped GEMMs (currently only `"mxfp8"` is supported). +* `--compile.enable`: Use `torch.compile` for best performance. + +**Important Notes:** + +* **Token group alignment**: For MoE training with MXFP8, token group sizes must be multiples of 32 (the MXFP8 block size). This is automatically configured [here](https://github.com/pytorch/torchtitan/blob/b39377f9fe33865fefb9bf64a33f6d74a598be87/torchtitan/components/quantization/mx.py#L131) when you enable MXFP8 grouped GEMMs in TorchTitan. + +* **torch.compile recommendation**: All benchmarks in this document were run with `torch.compile` enabled. We recommend using `torch.compile` for best performance. + +### Example TOML Configuration + +Here's an example configuration for MXFP8 training in a TOML file: + +```toml +[model] +converters = ["quantize.linear.mx", "quantize.grouped_mm.mx"] + +[quantize.linear.mx] +recipe_name = "mxfp8_cublas" +mxfp8_dim1_cast_kernel_choice = "cuda" +filter_fqns = ["output", "router.gate"] + +[quantize.grouped_mm.mx] +recipe_name = "mxfp8" +fqns = ["experts"] + +[compile] +enable = true +components = ["model"] +``` + +### Performance + +#### Dense Models + +Single-node training on 8x power limited B200 GPUs, batch size 1, sequence length 8192, steps 100, torch.compile, FSDP2, per-op SAC: + +| Scaling Method | Peak Memory (GB) | Median tokens/s | Speedup over BF16 | +|------------------------|------------------|-----------------|-------------------| +| None (bfloat16) | 33.71 | 8307.5 | - | +| mxfp8_cublas | 33.88 | 9969.0 | +20.0% | +| mxfp8_cublas_rceil | 33.88 | 9642.0 | +16.1% | +| float8 tensorwise | 33.38 | 10417.0 | +25.4% | + +- pytorch version: `2.9.0.dev20250815+cu128` +- torchao version: `0.13.0+gite4e681be` +- torchtitan commit: `6fc499f6f5b32151a799188be2208cfb09faed30` + +*Source: [TorchAO MX Formats Benchmarks](https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats#training-e2e-benchmarks-on-nvidia-b200)* + +#### MoE models + +512 GPU training on 64 node GB200 cluster: + +| Scaling Method | Median tokens/s | Speedup over BF16 | +|------------------------|-----------------|-------------------| +| None (bfloat16) | 6169 | - | +| mxfp8 | 7401 | +20.3% | + +Training runs on 64 node GB200 cluster with TorchTitan Llama4 Scout show that MXFP8 MoE training has equivalent convergence to bfloat16 training baseline. In fact, after 3,000 steps it finishes with slightly *lower* loss than bfloat16! This is consistent with our scaling experiments with [MXFP8 training for dense models](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/). + +![MXFP8 vs BF16 Training Loss Curves](static/mxfp8_with_loss.png) + +*Training loss curves over 3,000 steps showing MXFP8 achieves equivalent convergence to bfloat16 baseline.* + +Training and model configurations for this run: +- Model: Llama4 Scout +- Dataset: C4 +- Sequence length: 8192 +- Local batch size: 10 +- Learning rate: 1e-4 +- LR scheduler warmup steps: 2000 +- Parallelisms (64 nodes of 4 devices each = 256 chips): + - FSDP=256 (on attention layers, shared experts, dense layer FFNs) and 256/4=64 (on routed experts) + - EP=16 (on routed experts) +- Activation checkpointing mode: `none` (ideally this should use selective per op AC but there was a bug at the time preventing us from using it). +- `torch.compile` enabled +- `mxfp8` applied to routed experts computation (grouped GEMMs) +- `mxfp8` applied to all linear layers except: `output`, `router.gate`, `attention.wk`, `attention.wv` (Wk and Wv too small to benefit from mxfp8) + +### Composability +For distributed training, MXFP8 is compatible with: +- `torch.compile` +- FSDP2/TP/EP/PP +- Full activation checkpointing + +All distributed communication for MXFP8 training is currently done in high precision. + +### Known Limitations +- Currently in prototype stage - no BC guarantees. +- Requires torch nightly - important bug fixes have landed since 2.9.1 +- For GB200s, requires building torchao from source + +### Additional Resources + +- [Accelerating 2K+ Scale Pre-training up to 1.28x with TorchAO MXFP8 and TorchTitan on Crusoe B200 Cluster](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/) - Blog post on accelerating dense model training with MXFP8 +- [TorchAO MX Formats Documentation](https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats) +- [TorchAO MoE Training Documentation](https://github.com/pytorch/ao/tree/main/torchao/prototype/moe_training) From 8d020ccb785aa2d2635e1cc57a25969b98992ca2 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Wed, 3 Dec 2025 14:31:45 -0800 Subject: [PATCH 045/127] Mark input tokens to routed experts as dynamic to avoid a recompile (#2007) Stacked PRs: * __->__#2007 --- --- --- Mark input tokens to routed experts as dynamic to avoid a recompile This saves 1 recompile, and you can see the input tokens are dynamic from the first graph compiled: ```python class GraphModule(torch.nn.Module): def forward(...s77: "Sym(s77)", L_x_: "bf16[s77, 5120][5120, 1]cuda:0"... ``` I verified that this also fixes the AC recompile issue of: https://github.com/pytorch/torchtitan/issues/1971. But I'm keeping `torch._C._dynamo.eval_frame._set_lru_cache(False)`, as there could be other recompile reasons popping up. --- .../models/deepseek_v3/infra/parallelize.py | 2 +- torchtitan/models/llama4/infra/parallelize.py | 20 +++++++++++++++++-- torchtitan/models/qwen3/infra/parallelize.py | 2 +- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index d6e7397645..69273654e3 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -121,7 +121,7 @@ def parallelize_deepseekv3( ) if model_compile_enabled: - apply_compile(model, job_config.compile) + apply_compile(model, job_config.compile, parallel_dims.ep_enabled) dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index c14741069f..28418d842e 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -130,7 +130,7 @@ def parallelize_llama( # turn on per-TransformerBlock compile after AC wrapping and before FSDP if model_compile_enabled: - apply_compile(model, job_config.compile) + apply_compile(model, job_config.compile, parallel_dims.ep_enabled) dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: @@ -507,7 +507,7 @@ def apply_moe_ep_tp( ) -def apply_compile(model: nn.Module, compile_config: CompileConfig): +def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: bool): """ Apply torch.compile to each TransformerBlock, which makes compilation efficient due to repeated structure. Alternatively one can compile the whole model (after applying DP). @@ -578,6 +578,22 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig): fullgraph=True, ) + if ep_enabled: + compiled_fn = moe_module._run_experts_grouped_mm + + def _run_experts_grouped_mm_dynamic( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + ) -> torch.Tensor: + # dynamic number of tokens in expert parallel + torch._dynamo.mark_dynamic(x, 0) + return compiled_fn(w1, w2, w3, x, num_tokens_per_expert) + + moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic + # NOTE: We don't compile for loop code path due to an issue with unbacked symints: # https://github.com/pytorch/pytorch/issues/166460 diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index c50783b582..12aca42777 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -123,7 +123,7 @@ def parallelize_qwen3( # turn on per-TransformerBlock compile after AC wrapping and before FSDP if model_compile_enabled: - apply_compile(model, job_config.compile) + apply_compile(model, job_config.compile, parallel_dims.ep_enabled) if parallel_dims.fsdp_enabled: # apply FSDP or HSDP, potentially with Context Parallel From 341b15517aa092856a26f0e9aacc20f9bdc26e5d Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 3 Dec 2025 14:44:21 -0800 Subject: [PATCH 046/127] fix mxfp8 loss image (#2104) In the original PR i moved the image location without updating the markdown pointing to it by accident. This fixes that. --- docs/mxfp8.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mxfp8.md b/docs/mxfp8.md index aae8f4a7c6..ad9f62ee3c 100644 --- a/docs/mxfp8.md +++ b/docs/mxfp8.md @@ -151,7 +151,7 @@ Single-node training on 8x power limited B200 GPUs, batch size 1, sequence lengt Training runs on 64 node GB200 cluster with TorchTitan Llama4 Scout show that MXFP8 MoE training has equivalent convergence to bfloat16 training baseline. In fact, after 3,000 steps it finishes with slightly *lower* loss than bfloat16! This is consistent with our scaling experiments with [MXFP8 training for dense models](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/). -![MXFP8 vs BF16 Training Loss Curves](static/mxfp8_with_loss.png) +![MXFP8 vs BF16 Training Loss Curves](../assets/images/mxfp8_with_loss.png) *Training loss curves over 3,000 steps showing MXFP8 achieves equivalent convergence to bfloat16 baseline.* From 1168f9e4d58bbd91c07b08c382d1ca3ae4b2e02c Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Thu, 4 Dec 2025 14:14:47 -0500 Subject: [PATCH 047/127] Update hf_assets_path for llama4 (#2110) Fix typo in train_config, hf asset should be for maverick, see: https://huggingface.co/meta-llama/models?search=128e --- torchtitan/models/llama4/train_configs/llama4_17bx128e.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/models/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/models/llama4/train_configs/llama4_17bx128e.toml index fa0624bc8e..36d36712a2 100644 --- a/torchtitan/models/llama4/train_configs/llama4_17bx128e.toml +++ b/torchtitan/models/llama4/train_configs/llama4_17bx128e.toml @@ -17,7 +17,7 @@ save_tb_folder = "tb" [model] name = "llama4" flavor = "17bx128e" -hf_assets_path = "./assets/hf/Llama-4-Scout-17B-128E" +hf_assets_path = "./assets/hf/Llama-4-Maverick-17B-128E" # converters = ["float8"] [optimizer] From e98ae99536598f0320ee951863693178a3552257 Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Fri, 5 Dec 2025 11:44:25 -0800 Subject: [PATCH 048/127] Enables parsing of --compile.components through CLI (#2115) Without this PR, I'm not able to pass `--compile.components=model,loss`. Tested using `python -m torchtitan.config.manager --compile.components=model,loss`. --- torchtitan/config/job_config.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 612dc28101..c806041bb6 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -656,9 +656,7 @@ class Compile: enable: bool = False """Whether to apply torch.compile""" - components: list[Literal["model", "loss"]] = field( - default_factory=lambda: ["model", "loss"] - ) + components: list[str] = field(default_factory=lambda: ["model", "loss"]) """Which components to compile""" backend: str = "inductor" From 303f284c4a5aa2c9265543700dd10fad21712c3b Mon Sep 17 00:00:00 2001 From: Jiyue Wang Date: Sun, 7 Dec 2025 01:00:03 -0500 Subject: [PATCH 049/127] fix `ForgeEngine` compatibility issue with (#2121) Summary: Fix backward incompatible changes introduced in https://github.com/pytorch/torchtitan/commit/ff078526d1b9a51a3507cd234715ac3c61291e85 Differential Revision: D88572518 --- torchtitan/experiments/forge/engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index 2f1887b2d7..a4433ecef2 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -105,6 +105,7 @@ def __init__(self, job_config: ForgeJobConfig): world_mesh, self.device, job_config.debug, + distinct_seed_mesh_dims=["pp"], # same as `torchtitan/train.py` ) self.train_spec = get_train_spec(job_config.model.name) From 575674a4e2f3230e655bf384f9cfb0693753c4a0 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 8 Dec 2025 13:04:36 -0800 Subject: [PATCH 050/127] Remove the hack for SAC + FlexAttention (#2118) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * __->__ #2118 PyTorch can now support torch.compile inside the SAC region even if torch.compile is not used to wrap SAC. This PR removes the workaround to ensure torch.compile works with Flex --- .../unit_tests/test_activation_checkpoint.py | 11 -- .../distributed/activation_checkpoint.py | 108 +----------------- .../experiments/gpt_oss/infra/parallelize.py | 4 +- .../simple_fsdp/llama3/parallelize.py | 4 +- .../experiments/vlm/infra/parallelize.py | 2 - torchtitan/models/attention.py | 9 +- .../models/deepseek_v3/infra/parallelize.py | 3 +- torchtitan/models/llama3/infra/parallelize.py | 4 +- torchtitan/models/llama4/infra/parallelize.py | 4 +- torchtitan/models/qwen3/infra/parallelize.py | 2 +- 10 files changed, 18 insertions(+), 133 deletions(-) diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py index f309172173..2b05505e4a 100644 --- a/tests/unit_tests/test_activation_checkpoint.py +++ b/tests/unit_tests/test_activation_checkpoint.py @@ -88,7 +88,6 @@ def get_bw_flops(model_fn): model_selective_ac, ac_config_no_force, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) flops_selective_ac = get_bw_flops(model_selective_ac) @@ -106,7 +105,6 @@ def get_bw_flops(model_fn): model_with_force_first, ac_config_with_force_first, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) flops_with_force_first = get_bw_flops(model_with_force_first) @@ -123,7 +121,6 @@ def get_bw_flops(model_fn): model_with_force_last, ac_config_with_force_last, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) flops_with_force_last = get_bw_flops(model_with_force_last) @@ -138,7 +135,6 @@ def get_bw_flops(model_fn): model_with_full_ac, ac_config_full_ac, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) flops_full_ac = get_bw_flops(model_with_full_ac) @@ -181,7 +177,6 @@ def get_act_mem(model_fn): model_selective_ac, ac_config_no_force, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) mem_selective_ac = get_act_mem(model_selective_ac) @@ -198,7 +193,6 @@ def get_act_mem(model_fn): model_with_force_first, ac_config_with_force_first, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) mem_with_force_first = get_act_mem(model_with_force_first) @@ -214,7 +208,6 @@ def get_act_mem(model_fn): model_with_force_last, ac_config_with_force_last, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) mem_with_force_last = get_act_mem(model_with_force_last) @@ -228,7 +221,6 @@ def get_act_mem(model_fn): model_with_full_ac, ac_config_full_ac, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) mem_full_ac = get_act_mem(model_with_full_ac) @@ -255,7 +247,6 @@ def test_correctness(self): per_op_sac_force_recompute_mm_shapes_by_fqns=[], ), model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) model_force_first = ToyModule() @@ -268,7 +259,6 @@ def test_correctness(self): per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], ), model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) @@ -282,7 +272,6 @@ def test_correctness(self): per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], ), model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 8359f71730..0eecde9052 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -17,7 +17,7 @@ ) from torchtitan.config.job_config import ActivationCheckpoint as ACConfig -from torchtitan.tools.logging import logger, warn_once +from torchtitan.tools.logging import logger _layer_sac_count = 0 @@ -155,88 +155,12 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module: ) -def _apply_op_sac_to_transformer_block_with_flex( - module: nn.Module, - ac_config: ACConfig, - *, - base_fqn: str | None = None, - model_compile_enabled: bool = False, - op_sac_save_list: set[torch._ops.OpOverload], -) -> nn.Module: - """Apply SAC to the transformer block that uses FlexAttention. - - Args: - module (nn.Module): The transformer block to apply SAC to. - ac_config (ACConfig): The Activation Checkpoint config. - base_fqn (str, optional): The base fqn of the module. Defaults to None. - model_compile_enabled (bool): Whether model compilation is enabled. - Defaults to False. - op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead - of recomputing. - - Returns: - nn.Module: The transformer block with SAC applied. - """ - - warn_once( - logger, - ( - "Flex Attention requires compilation for good performance.\n" - "Thus, torch.compile is always used for Flex Attention, " - "regardless of the compile.enable flag.\n" - "However, when selective activation checkpointing (SAC) is enabled, " - "torch.compile may be invalidated:\n" - "1. If compile.enable is False, SAC will ignore any torch.compile " - "inside the SAC region.\n" - "2. If compile.enable is True but the transformer block contains an MoE module.\n\n" - "For both cases, we will not wrap the entire TransformerBlock with SAC:\n" - " - For case 1: SAC will be used for MoE and FeedForward modules, " - "while full AC will be used for the Attention module.\n" - " - For case 2: SAC will be applied to MoE and Attention modules if the block " - "is sparse. But we still apply SAC to an entire dense block.\n" - ), - ) - - def wrap_submodule(name: str, full_ac: bool = False) -> None: - submodule = getattr(module, name) - if full_ac: - submodule = _apply_full_ac(submodule, ac_config) - else: - submodule = _apply_op_sac( - submodule, - ac_config, - base_fqn=f"{base_fqn}.{name}" if base_fqn else name, - op_sac_save_list=op_sac_save_list, - ) - module.register_module(name, submodule) - - if hasattr(module, "moe"): - wrap_submodule("moe", full_ac=False) - if model_compile_enabled: - wrap_submodule("attention", full_ac=False) - else: - wrap_submodule("attention", full_ac=True) - else: - if model_compile_enabled: - module = _apply_op_sac( - module, - ac_config, - base_fqn=base_fqn, - op_sac_save_list=op_sac_save_list, - ) - else: - wrap_submodule("feed_forward", full_ac=False) - wrap_submodule("attention", full_ac=True) - return module - - def _apply_ac_to_transformer_block( module: nn.Module, ac_config: ACConfig, *, base_fqn: str | None = None, model_compile_enabled: bool = False, - use_flex_attn: bool = False, op_sac_save_list: set[torch._ops.OpOverload] | None = None, ) -> nn.Module: valid_ac_modes = ("full", "selective") @@ -259,26 +183,9 @@ def _apply_ac_to_transformer_block( if use_op_sac: op_sac_save_list = op_sac_save_list or set() - if use_flex_attn: - """ - For Flex Attention, we need to apply SAC carefully to avoid invalidating - torch.compile. Any torch.compile inside the SAC region will be ignored, - and any torch.compile outside the SAC region will also be ignored if the - SAC region contains a graph break (e.g., MoE). - - TODO: remove this once SAC issues are resolved. - """ - return _apply_op_sac_to_transformer_block_with_flex( - module, - ac_config, - base_fqn=base_fqn, - model_compile_enabled=model_compile_enabled, - op_sac_save_list=op_sac_save_list, - ) - else: - return _apply_op_sac( - module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list - ) + return _apply_op_sac( + module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list + ) return _apply_layer_sac(module, ac_config) @@ -288,21 +195,15 @@ def apply_ac( ac_config: ACConfig, *, model_compile_enabled: bool = False, - use_flex_attn: bool = False, op_sac_save_list: set[torch._ops.OpOverload] | None = None, base_folder: str = "", ) -> None: """Apply activation checkpointing to the model. - Note that SAC, Flex Attention and model compilation have some conflicts. - We explicitly ask the user to pass these configs to warn as the wrapping - will be different. - Args: model (nn.Module): The model to apply activation checkpointing to. ac_config (ACConfig): The activation checkpointing config. model_compile_enabled (bool): Whether torch.compile is enabled for the model. - use_flex_attn (bool): Whether flex attention is enabled for the model. op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead of recomputing. Returns: @@ -326,7 +227,6 @@ def apply_ac( ac_config, base_fqn=f"layers.{layer_id}", model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, op_sac_save_list=op_sac_save_list, ) model.layers.register_module(layer_id, transformer_block) diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 232cba9ff7..4d1177d1ab 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -47,6 +47,7 @@ # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch._higher_order_ops.inductor_compiled_code, } @@ -110,14 +111,11 @@ def parallelize_gptoss( job_config.compile.enable and "model" in job_config.compile.components ) - attn_type = getattr(model.model_args, "attn_type", "sdpa") - use_flex_attn = attn_type == "flex" if job_config.activation_checkpoint.mode != "none": apply_ac( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, ) diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index bd9c936b78..484d3d4747 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -34,6 +34,7 @@ torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, torch.ops.torch_attn._varlen_attn, + torch._higher_order_ops.inductor_compiled_code, } @@ -106,8 +107,6 @@ def parallelize_llama( maybe_enable_async_tp(job_config, tp_mesh) if job_config.activation_checkpoint.mode != "none": - attn_type = getattr(model.model_args, "attn_type", "sdpa") - use_flex_attn = attn_type == "flex" model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) @@ -115,7 +114,6 @@ def parallelize_llama( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py index d418ad6edd..b6ada94d00 100644 --- a/torchtitan/experiments/vlm/infra/parallelize.py +++ b/torchtitan/experiments/vlm/infra/parallelize.py @@ -58,13 +58,11 @@ def parallelize_vlm( model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) - use_flex_attn = attn_type == "flex" if job_config.activation_checkpoint.mode != "none": apply_ac( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, ) apply_ac(model.encoder, job_config.activation_checkpoint) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index cc7b87cb20..663ce54010 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -97,7 +97,14 @@ class FlexAttentionWrapper(torch.nn.Module): """ _compiled_flex_attn: ClassVar[Callable] = torch.compile( - flex_attention, mode="max-autotune-no-cudagraphs" + flex_attention, + # This options also encapsulate max-autotune-no-cudagraphs. + options={ + "wrap_inductor_compiled_regions": True, + "max_autotune": True, + "coordinate_descent_tuning": True, + "triton.cudagraphs": False, + }, ) def forward( diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 69273654e3..98db56b135 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -44,6 +44,7 @@ # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch._higher_order_ops.inductor_compiled_code, } @@ -65,7 +66,6 @@ def parallelize_deepseekv3( """ attn_type = getattr(model.model_args, "attn_type", "sdpa") - use_flex_attn = attn_type == "flex" if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": raise NotImplementedError("CP support is only supported for SDPA.") @@ -115,7 +115,6 @@ def parallelize_deepseekv3( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 1c381883b1..52a2dfe7e2 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -45,6 +45,7 @@ torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, torch.ops.torch_attn._varlen_attn.default, + torch._higher_order_ops.inductor_compiled_code, } @@ -95,14 +96,11 @@ def parallelize_llama( job_config.compile.enable and "model" in job_config.compile.components ) - attn_type = getattr(model.model_args, "attn_type", "sdpa") - use_flex_attn = attn_type == "flex" if job_config.activation_checkpoint.mode != "none": apply_ac( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 28418d842e..0b15e0c9eb 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -52,6 +52,7 @@ # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch._higher_order_ops.inductor_compiled_code, } @@ -116,14 +117,11 @@ def parallelize_llama( model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) - attn_type = getattr(model.model_args, "attn_type", "sdpa") - use_flex_attn = attn_type == "flex" if job_config.activation_checkpoint.mode != "none": apply_ac( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 12aca42777..6bb9eb5204 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -47,6 +47,7 @@ torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, torch.ops.torch_attn._varlen_attn.default, + torch._higher_order_ops.inductor_compiled_code, } @@ -116,7 +117,6 @@ def parallelize_qwen3( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=attn_type == "flex", op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) From b41832a1d2b6d3a17824b994c18e26a6c5bc85cb Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 8 Dec 2025 19:07:50 -0500 Subject: [PATCH 051/127] Add warning to run_tests (#2123) Small addition since right now running a test that doesn't exist just outputs nothing, e.g. `python -m tests.integration_tests.run_tests ./test-out --test_name does_not_exist` Now the output is: ` WARNING:root:No tests were run for --test_name 'does_not_exist' in test suite 'features'. Available test names in 'features' suite: ['default', '1d_compile', '1d_compile_sac_op', '2d_eager', '2d_compile', 'full_checkpoint', 'model_only_hf_checkpoint', 'last_save_model_only_fp32', 'last_save_model_only_bf16', 'pp_looped_zero_bubble', 'pp_zbv', 'pp_1f1b', 'pp_gpipe', 'pp_dp_1f1b', 'pp_dp_gpipe', 'pp_tp', 'pp_dp_tp', '3d_compile', 'pp_looped_1f1b', 'pp_custom_csv', 'optimizer_foreach', 'ddp', 'hsdp', 'fsdp+flex_attn', 'fsdp+flex_attn+per_op_sac', 'fsdp+varlen_attn+per_op_sac', 'cp_allgather', 'cp_alltoall', 'hsdp+tp', 'fsdp+cp', 'hsdp+cp_without_dp_shard', 'hsdp+cp_with_dp_shard', 'fsdp+tp+cp', 'cpu_offload+opt_in_bwd+TP+DP+CP', 'test_generate', 'fsdp_reshard_always', 'optional_checkpoint', 'float8_emulation', 'gradient_accumulation', 'validation_tp_cp_pp'] ` --- tests/integration_tests/run_tests.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/integration_tests/run_tests.py b/tests/integration_tests/run_tests.py index c233904165..7081215c83 100644 --- a/tests/integration_tests/run_tests.py +++ b/tests/integration_tests/run_tests.py @@ -80,6 +80,7 @@ def run_tests(args, test_list: list[OverrideDefinitions]): args.config_path ), f"Base config path {args.config_path} does not exist" + ran_any_test = False for test_flavor in test_list: # Filter by test_name if specified if args.test_name != "all" and test_flavor.test_name != args.test_name: @@ -103,6 +104,14 @@ def run_tests(args, test_list: list[OverrideDefinitions]): ) else: run_single_test(test_flavor, args.config_path, args.output_dir) + ran_any_test = True + + if not ran_any_test: + available_tests = [t.test_name for t in test_list if not t.disabled] + logger.warning( + f"No tests were run for --test_name '{args.test_name}' in test suite '{args.test_suite}'.\n" + f"Available test names in '{args.test_suite}' suite: {available_tests}" + ) def main(): From d1924118f9bfe3b73f9aacb0cb7c954d88695db4 Mon Sep 17 00:00:00 2001 From: Yiming Zhou <61480007+yiming0416@users.noreply.github.com> Date: Tue, 9 Dec 2025 09:30:56 -0800 Subject: [PATCH 052/127] [compiler toolkit] Disable CUDAGraph integration test (#2127) As titled. We'll enable when it is fixed. --- .../tests/integration_tests.py | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py index f01a1c4380..9527d7dd23 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py +++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py @@ -100,21 +100,22 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor", ngpu=4, ), - OverrideDefinitions( - [ - [ - "--model.name compiler_toolkit.llama3", - "--parallelism.data_parallel_shard_degree 2", - "--parallelism.tensor_parallel_degree 2", - "--model.flavor debugmodel_flex_attn", - "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", - "--compile.passes autobucketing_reordering,regional_inductor,cudagraph", - ], - ], - "llama3 FSDP+TP+FlexAttn autobucketing regional_inductor+cudagraph", - "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor_cudagraph", - ngpu=4, - ), + # TODO: enable this when cudagraph is fixed + # OverrideDefinitions( + # [ + # [ + # "--model.name compiler_toolkit.llama3", + # "--parallelism.data_parallel_shard_degree 2", + # "--parallelism.tensor_parallel_degree 2", + # "--model.flavor debugmodel_flex_attn", + # "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + # "--compile.passes autobucketing_reordering,regional_inductor,cudagraph", + # ], + # ], + # "llama3 FSDP+TP+FlexAttn autobucketing regional_inductor+cudagraph", + # "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor_cudagraph", + # ngpu=4, + # ), OverrideDefinitions( [ [ From 1ebd914b06d9b0e05a5625a89f39cd96a7269c90 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 9 Dec 2025 10:12:43 -0800 Subject: [PATCH 053/127] Add CI for Autoparallel experiment llama3 on 4 GPUs (#2105) --- .../integration_test_8gpu_auto_parallel.yaml | 56 ++++++++++++ torchtitan/experiments/README.md | 2 +- .../auto_parallel/llama3/parallelize_llama.py | 2 +- .../auto_parallel/tests/__init__.py | 5 ++ .../auto_parallel/tests/integration_tests.py | 85 +++++++++++++++++++ 5 files changed, 148 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/integration_test_8gpu_auto_parallel.yaml create mode 100644 torchtitan/experiments/auto_parallel/tests/__init__.py create mode 100644 torchtitan/experiments/auto_parallel/tests/integration_tests.py diff --git a/.github/workflows/integration_test_8gpu_auto_parallel.yaml b/.github/workflows/integration_test_8gpu_auto_parallel.yaml new file mode 100644 index 0000000000..85618aeeef --- /dev/null +++ b/.github/workflows/integration_test_8gpu_auto_parallel.yaml @@ -0,0 +1,56 @@ +name: Auto Parallel 8 GPU Integration Tests + +on: + push: + branches: [ main ] + paths: + - 'torchtitan/experiments/auto_parallel/**' + - '.github/workflows/integration_test_8gpu_auto_parallel.yaml' + pull_request: + paths: + - 'torchtitan/experiments/auto_parallel/**' + - '.github/workflows/integration_test_8gpu_auto_parallel.yaml' + schedule: + # Runs every 12 hours + - cron: '0 */12 * * *' + +concurrency: + group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} + cancel-in-progress: true + +defaults: + run: + shell: bash -l -eo pipefail {0} + +jobs: + build-test: + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + runner: linux.g5.48xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: "12.6" + # This image is faster to clone than the default, but it lacks CC needed by triton + # (1m25s vs 2m37s). + docker-image: torchtitan-ubuntu-20.04-clang12 + repository: pytorch/torchtitan + upload-artifact: outputs + script: | + set -eux + + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + # Log CUDA driver version for debugging. + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1 || true) + echo "CUDA driver version: ${DRIVER_VERSION}" + + pip config --user set global.progress_bar off + + python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + + # Install autoparallel - required dependency for auto_parallel experiment + python -m pip install git+https://github.com/meta-pytorch/autoparallel.git + + mkdir artifacts-to-be-uploaded + python -m torchtitan.experiments.auto_parallel.tests.integration_tests artifacts-to-be-uploaded --ngpu 4 diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index aa93628656..5c1b20898d 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -32,4 +32,4 @@ We provide this `experiments/` folder to host experiments that add significant v | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) | | [transformers_modeling_backend](./transformers_modeling_backend/) | [![Transformers modeling backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | -| [auto_parallel](./auto_parallel/) | TBA | [@wconstab](https://github.com/wconstab) | [@xmfan](https://github.com/xmfan) | +| [auto_parallel](./auto_parallel/) | [![Auto Parallel 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_auto_parallel.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_auto_parallel.yaml?query=branch%3Amain) | [@wconstab](https://github.com/wconstab) [@xmfan](https://github.com/xmfan) | diff --git a/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py b/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py index 1d2bee4351..d7fbae2622 100644 --- a/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py +++ b/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py @@ -126,7 +126,7 @@ def input_fn(): autop.add_input_constraints([x_sharding]) autop.add_output_constraints([out_sharding]) t0 = time.time() - sharding_placement = autop.optimize_placement() + sharding_placement = autop.optimize_placement(verbose=False) t1 = time.time() logger.info(f"AutoParallel took {t1 - t0} seconds") parallel_mod = autop.apply_placement(sharding_placement) diff --git a/torchtitan/experiments/auto_parallel/tests/__init__.py b/torchtitan/experiments/auto_parallel/tests/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/torchtitan/experiments/auto_parallel/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtitan/experiments/auto_parallel/tests/integration_tests.py b/torchtitan/experiments/auto_parallel/tests/integration_tests.py new file mode 100644 index 0000000000..334aed86dd --- /dev/null +++ b/torchtitan/experiments/auto_parallel/tests/integration_tests.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +from tests.integration_tests import OverrideDefinitions +from tests.integration_tests.run_tests import run_tests + + +def build_auto_parallel_test_list() -> list[OverrideDefinitions]: + """ + returns a list of OverrideDefinitions that is used to generate + variations of integration tests based on the same root config file. + """ + integration_tests_flavors = [ + # llama3 tests + OverrideDefinitions( + [ + [ + "--model.name auto_parallel.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.auto_parallel.job_config", + ], + ], + "llama3 AutoParallel FSDP+TP", + "llama3_autoparallel_fsdp_tp", + ngpu=4, + ), + # TODO: Re-enable this once we fix the test + # deepseek_v3 tests + # OverrideDefinitions( + # [ + # [ + # "--model.name auto_parallel.deepseek_v3", + # "--parallelism.data_parallel_shard_degree 2", + # "--parallelism.expert_parallel_degree 2", + # "--job.custom_config_module=torchtitan.experiments.auto_parallel.job_config", + # "--activation_checkpoint.mode none", + # ], + # ], + # "deepseek_v3 AutoParallel FSDP+TP+EP", + # "deepseekv3_autoparallel_fsdp_tp_ep", + # ngpu=4, + # ), + ] + return integration_tests_flavors + + +_TEST_SUITES_FUNCTION = { + "auto_parallel": build_auto_parallel_test_list, +} + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("output_dir") + parser.add_argument( + "--config_path", + default="./tests/integration_tests/base_config.toml", + help="Base config path for integration tests. This is the config that will be used as a base for all tests.", + ) + parser.add_argument( + "--test_name", + default="all", + help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)", + ) + parser.add_argument("--ngpu", default=8, type=int) + args = parser.parse_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + if os.listdir(args.output_dir): + raise RuntimeError("Please provide an empty output directory.") + + test_list = _TEST_SUITES_FUNCTION["auto_parallel"]() + run_tests(args, test_list) + + +if __name__ == "__main__": + main() From f1d41a1179a62e5274c565b080282ba7ccfeab2c Mon Sep 17 00:00:00 2001 From: acisseJZhong <40467976+acisseJZhong@users.noreply.github.com> Date: Tue, 9 Dec 2025 10:25:42 -0800 Subject: [PATCH 054/127] Support rope cache indexing using positions (#2112) Add support to indexing rope cache using `position_ids`, this might be needed during 1. inference, where we passed in `position_ids` into transformer forward 2. CP load balancing where we need to index rope cache given positions ids Test: running dpskv3 16b base image also tested in https://github.com/wwwjn/torchtitan/pull/1/files when passing position_ids image --------- Co-authored-by: JessicaZhong --- .../models/deepseek_v3/infra/parallelize.py | 6 +- torchtitan/models/deepseek_v3/model/model.py | 74 +++++++++++++++++-- torchtitan/models/llama3/infra/parallelize.py | 6 +- torchtitan/models/llama3/model/model.py | 57 +++++++++++--- torchtitan/models/llama4/infra/parallelize.py | 6 +- torchtitan/models/llama4/model/model.py | 55 +++++++++++--- torchtitan/models/qwen3/infra/parallelize.py | 6 +- torchtitan/models/qwen3/model/model.py | 63 ++++++++++++---- 8 files changed, 226 insertions(+), 47 deletions(-) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 98db56b135..d66a30a83d 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -224,9 +224,11 @@ def apply_non_moe_tp( for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), + # NOTE: when the fourth argument (positions) is not None, its input layout + # and desired input layout should be Replicate() "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate(), None), - desired_input_layouts=(Replicate(), Replicate(), None), + input_layouts=(Shard(1), Replicate(), None, None), + desired_input_layouts=(Replicate(), Replicate(), None, None), ), # NOTE: use_local_output=False make the output to be a DTensor instead of a plain Tensor # so that the intermedidate results k is generated as a DTensor and its gradient is diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 7d7635a4ad..5b17ad0acf 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -126,20 +126,71 @@ def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: return freqs_cis -def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: +def reshape_for_broadcast( + freqs_cis: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim // 2), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. + Shape is (1, seqlen) or (bz, seqlen). Defaults to None. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert ndim > 1 + seqlen = x.shape[1] + if positions is None: + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + freqs_cis = freqs_cis[positions.squeeze(0)] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + else: + assert positions.shape == (x.shape[0], seqlen) + freqs_cis_expanded = freqs_cis[None, :, None, :].expand(x.shape[0], -1, -1, -1) + freqs_cis = torch.gather( + freqs_cis_expanded, + dim=1, + index=positions.view(x.shape[0], seqlen, 1, 1).expand( + x.shape[0], seqlen, 1, freqs_cis_expanded.shape[-1] + ), + ) + return freqs_cis + + +def apply_rotary_emb( + x: torch.Tensor, freqs_cis: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: """ Applies rotary positional embeddings to the input tensor. Args: x (torch.Tensor): Input tensor with positional embeddings to be applied. freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Tensor with rotary embeddings applied. """ dtype = x.dtype x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) + freqs_cis = reshape_for_broadcast(freqs_cis, x, positions) y = torch.view_as_real(x * freqs_cis).flatten(3) return y.to(dtype) @@ -196,6 +247,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Forward pass for the Multi-Head Latent Attention (MLA) Layer. @@ -203,6 +255,8 @@ def forward( Args: x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor with the same shape as the input. @@ -222,7 +276,7 @@ def forward( q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) - q_pe = apply_rotary_emb(q_pe, freqs_cis) + q_pe = apply_rotary_emb(q_pe, freqs_cis, positions) q = torch.cat([q_nope, q_pe], dim=-1) # (bsz, seqlen, n_heads, qk_head_dim) # Key-value projection @@ -230,7 +284,7 @@ def forward( kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) k_pe = apply_rotary_emb( - k_pe.unsqueeze(2), freqs_cis + k_pe.unsqueeze(2), freqs_cis, positions ) # (bsz, seqlen, 1, qk_rope_head_dim) kv = self.wkv_b( @@ -312,6 +366,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Forward pass for the Transformer block. @@ -319,11 +374,15 @@ def forward( Args: x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor with the same shape as the input. """ - x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + x = x + self.attention( + self.attention_norm(x), freqs_cis, attention_masks, positions + ) if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) else: @@ -413,6 +472,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, + positions: torch.Tensor | None = None, ): """ Forward pass for the Transformer model. @@ -422,6 +482,8 @@ def forward( If pipeline parallelism is enabled, this will be the input token indices for the ranks on the first pipeline stage. This will be the activation of the previous pipeline stage if the current rank is not on the first stage. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). @@ -430,7 +492,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis, attention_masks) + h = layer(h, self.freqs_cis, attention_masks, positions) h = self.norm(h) if self.norm is not None else h output = self.output(h) if self.output is not None else h return output diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 52a2dfe7e2..13a968be96 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -205,9 +205,11 @@ def apply_tp( for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), + # NOTE: when the fourth argument (positions) is not None, its input layout + # and desired input layout should be Replicate() "attention": prepare_module_input( - input_layouts=(Shard(1), None, None), - desired_input_layouts=(Replicate(), None, None), + input_layouts=(Shard(1), None, None, None), + desired_input_layouts=(Replicate(), None, None, None), ), "attention.wq": colwise_parallel(), "attention.wk": colwise_parallel(), diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 74b862bf76..8982fcca9f 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -88,19 +88,23 @@ def precompute_freqs_cis( return freqs_cis -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: +def reshape_for_broadcast( + freqs_cis: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: """ Reshape frequency tensor for broadcasting it with another tensor. This function reshapes the frequency tensor to have the same shape as the target tensor 'x' for the purpose of broadcasting the frequency tensor during element-wise operations. - The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim // 2), and the first seqlen elements will be sliced, but dim must match x. Args: freqs_cis (torch.Tensor): Frequency tensor to be reshaped. x (torch.Tensor): Target tensor for broadcasting compatibility. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. + Shape is (1, seqlen) or (bz, seqlen). Defaults to None. Returns: torch.Tensor: Reshaped frequency tensor. @@ -108,16 +112,35 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten ndim = x.ndim assert ndim > 1 seqlen = x.shape[1] - freqs_cis = freqs_cis[0:seqlen] - assert freqs_cis.shape == (seqlen, x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) + if positions is None: + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + freqs_cis = freqs_cis[positions.squeeze(0)] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + else: + assert positions.shape == (x.shape[0], seqlen) + freqs_cis_expanded = freqs_cis[None, :, None, :].expand(x.shape[0], -1, -1, -1) + freqs_cis = torch.gather( + freqs_cis_expanded, + dim=1, + index=positions.view(x.shape[0], seqlen, 1, 1).expand( + x.shape[0], seqlen, 1, freqs_cis_expanded.shape[-1] + ), + ) + return freqs_cis def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, + positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. @@ -131,13 +154,14 @@ def apply_rotary_emb( xq (torch.Tensor): Query tensor to apply rotary embeddings. xk (torch.Tensor): Key tensor to apply rotary embeddings. freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, positions) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) @@ -213,6 +237,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Forward pass of the attention module. @@ -220,6 +245,8 @@ def forward( Args: x (torch.Tensor): Input tensor. freqs_cis (torch.Tensor): Precomputed frequency tensor. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after attention. @@ -236,7 +263,7 @@ def forward( xk = xk.view(bs, seqlen, -1, self.head_dim) xv = xv.view(bs, seqlen, -1, self.head_dim) - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, positions=positions) # repeat k/v heads if n_kv_heads < n_heads keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) @@ -360,6 +387,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the TransformerBlock. @@ -367,12 +395,16 @@ def forward( Args: x (torch.Tensor): Input tensor. freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after applying attention and feedforward layers. """ - h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + h = x + self.attention( + self.attention_norm(x), freqs_cis, attention_masks, positions + ) out = h + self.feed_forward(self.ffn_norm(h)) return out @@ -519,6 +551,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the Transformer model. @@ -528,6 +561,8 @@ def forward( If pipeline parallelism is enabled, this will be the input token indices for the ranks on the first pipeline stage. This will be the activation of the previous pipeline stage if the current rank is not on the first stage. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output logits after applying the Transformer model. @@ -537,7 +572,9 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis, attention_masks=attention_masks) + h = layer( + h, self.freqs_cis, attention_masks=attention_masks, positions=positions + ) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 0b15e0c9eb..0fb2b54eac 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -240,9 +240,11 @@ def apply_non_moe_tp( for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), + # NOTE: when the fourth argument (positions) is not None, its input layout + # and desired input layout should be Replicate() "attention": prepare_module_input( - input_layouts=(Shard(1), None, None), - desired_input_layouts=(Replicate(), None, None), + input_layouts=(Shard(1), None, None, None), + desired_input_layouts=(Replicate(), None, None, None), ), "attention.wq": colwise_parallel(), "attention.wk": colwise_parallel(), diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index 6b9d2d2d9e..7c4f073e19 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -86,19 +86,23 @@ def precompute_freqs_cis( return freqs_cis -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: +def reshape_for_broadcast( + freqs_cis: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: """ Reshape frequency tensor for broadcasting it with another tensor. This function reshapes the frequency tensor to have the same shape as the target tensor 'x' for the purpose of broadcasting the frequency tensor during element-wise operations. - The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim // 2), and the first seqlen elements will be sliced, but dim must match x. Args: freqs_cis (torch.Tensor): Frequency tensor to be reshaped. x (torch.Tensor): Target tensor for broadcasting compatibility. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. + Shape is (1, seqlen) or (bz, seqlen). Defaults to None. Returns: torch.Tensor: Reshaped frequency tensor. @@ -106,16 +110,35 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten ndim = x.ndim assert ndim > 1 seqlen = x.shape[1] - freqs_cis = freqs_cis[0:seqlen] - assert freqs_cis.shape == (seqlen, x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) + if positions is None: + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + freqs_cis = freqs_cis[positions.squeeze(0)] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + else: + assert positions.shape == (x.shape[0], seqlen) + freqs_cis_expanded = freqs_cis[None, :, None, :].expand(x.shape[0], -1, -1, -1) + freqs_cis = torch.gather( + freqs_cis_expanded, + dim=1, + index=positions.view(x.shape[0], seqlen, 1, 1).expand( + x.shape[0], seqlen, 1, freqs_cis_expanded.shape[-1] + ), + ) + return freqs_cis def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, + positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. @@ -129,13 +152,14 @@ def apply_rotary_emb( xq (torch.Tensor): Query tensor to apply rotary embeddings. xk (torch.Tensor): Key tensor to apply rotary embeddings. freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, positions) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) @@ -219,6 +243,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType, + positions: torch.Tensor | None = None, ): """ Forward pass of the attention module. @@ -226,6 +251,8 @@ def forward( Args: x (torch.Tensor): Input tensor. freqs_cis (torch.Tensor): Precomputed frequency tensor. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after attention. @@ -243,7 +270,7 @@ def forward( xv = xv.view(bs, seqlen, -1, self.head_dim) if self.use_rope: - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, positions=positions) # repeat k/v heads if n_kv_heads < n_heads keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) @@ -393,6 +420,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the TransformerBlock. @@ -400,12 +428,16 @@ def forward( Args: x (torch.Tensor): Input tensor. freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after applying attention and feedforward layers. """ - h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + h = x + self.attention( + self.attention_norm(x), freqs_cis, attention_masks, positions + ) if self.moe_enabled: out = h + self.moe(self.ffn_norm(h)) else: @@ -540,6 +572,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the Transformer model. @@ -549,6 +582,8 @@ def forward( If pipeline parallelism is enabled, this will be the input token indices for the ranks on the first pipeline stage. This will be the activation of the previous pipeline stage if the current rank is not on the first stage. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output logits after applying the Transformer model. @@ -558,7 +593,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis, attention_masks) + h = layer(h, self.freqs_cis, attention_masks, positions) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 6bb9eb5204..517435714b 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -241,9 +241,11 @@ def apply_non_moe_tp( for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), + # NOTE: when the fourth argument (positions) is not None, its input layout + # and desired input layout should be Replicate() "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate(), None), - desired_input_layouts=(Replicate(), Replicate(), None), + input_layouts=(Shard(1), Replicate(), None, None), + desired_input_layouts=(Replicate(), Replicate(), None, None), ), "attention.wq": colwise_parallel(use_local_output=False), "attention.wk": colwise_parallel(use_local_output=False), diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index fa8fd454b1..62b5d0c381 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -57,7 +57,9 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) -def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor: +def reshape_for_broadcast( + rope_cache: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: """ Reshape frequency tensor (represented by cos, sin) for broadcasting it with another tensor. @@ -70,28 +72,51 @@ def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Te Args: rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped. x (torch.Tensor): Target tensor for broadcasting compatibility. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. + Shape is (1, seqlen) or (bz, seqlen). Defaults to None. Returns: torch.Tensor: Reshaped frequency tensor. """ ndim = x.ndim assert ndim > 1 - _, seqlen, _, head_dim = x.shape - rope_cache = rope_cache[0:seqlen] - # The shape of rope_cache is (seqlen, head_dim * 2) because we concate cos and sin - assert rope_cache.shape == (seqlen, head_dim * 2) - shape = [-1, seqlen, 1, head_dim * 2] - return rope_cache.view(*shape) + bz, seqlen, _, head_dim = x.shape + if positions is None: + rope_cache = rope_cache[0:seqlen] + # The shape of rope_cache is (seqlen, head_dim * 2) because we concate cos and sin + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + rope_cache = rope_cache[positions.squeeze(0)] + # The shape of rope_cache is (seqlen, head_dim * 2) + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + else: + assert positions.shape == (bz, seqlen) + rope_cache_expanded = rope_cache[None, :, None, :].expand(bz, -1, -1, -1) + rope_cache = torch.gather( + rope_cache_expanded, + dim=1, + index=positions.view(bz, seqlen, 1, 1).expand(bz, seqlen, 1, head_dim * 2), + ) + # The shape of rope_cache is (bz, seqlen, 1, head_dim * 2) + assert rope_cache.shape == (bz, seqlen, 1, head_dim * 2) + return rope_cache def apply_rotary_emb( - xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor + xq: torch.Tensor, + xk: torch.Tensor, + rope_cache: torch.Tensor, + positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: # input tensor x has shape [bsz, seq_len, num_heads, head_dim] head_dim = xq.shape[-1] - # reshape for broadcast - rope_cache = reshape_for_broadcast(rope_cache, xq) + rope_cache = reshape_for_broadcast(rope_cache, xq, positions) # [bsz, seq_len, 1, head_dim] cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device) @@ -194,12 +219,16 @@ def forward( x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Forward pass of the attention module. Args: x (torch.Tensor): Input tensor. + rope_cache (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after attention. @@ -224,7 +253,7 @@ def forward( xk = self.k_norm(xk) # Apply rotary embedding - xq, xk = apply_rotary_emb(xq, xk, rope_cache) + xq, xk = apply_rotary_emb(xq, xk, rope_cache, positions) # repeat k/v heads if n_kv_heads < n_heads keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) @@ -350,6 +379,7 @@ def forward( x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the TransformerBlock. @@ -357,12 +387,16 @@ def forward( Args: x (torch.Tensor): Input tensor. rope_cache (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after applying attention and feedforward layers. """ - x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks) + x = x + self.attention( + self.attention_norm(x), rope_cache, attention_masks, positions + ) if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) @@ -515,6 +549,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the Transformer model. @@ -524,6 +559,8 @@ def forward( If pipeline parallelism is enabled, this will be the input token indices for the ranks on the first pipeline stage. This will be the activation of the previous pipeline stage if the current rank is not on the first stage. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output logits after applying the Transformer model. @@ -533,7 +570,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.rope_cache, attention_masks) + h = layer(h, self.rope_cache, attention_masks, positions) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h From f3f2e8fd24193b2310294cc3dde38342c152d44d Mon Sep 17 00:00:00 2001 From: rakkit <26144573+rakkit@users.noreply.github.com> Date: Tue, 9 Dec 2025 21:48:23 +0100 Subject: [PATCH 055/127] [forge] allow torchforges to set checkpoint base folder (#2131) this PR 1) allowing Torchforge to decide where to put the checkpoint and wandb, etc, instead of the "current" folder ~~allowing Torchforge to decide to print / log the configs~~ --- torchtitan/experiments/forge/engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index a4433ecef2..5035129008 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -228,6 +228,7 @@ def __init__(self, job_config: ForgeJobConfig): if self.train_spec.state_dict_adapter else None ), + base_folder=job_config.job.dump_folder, ) loss_parallel_enabled = ( From fbafd44da2baef0afac58989f07d799c4251bdef Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 9 Dec 2025 13:24:12 -0800 Subject: [PATCH 056/127] Rename auto_parallel experiment to autoparallel (#2128) --- ...aml => integration_test_8gpu_autoparallel.yaml} | 12 ++++++------ torchtitan/experiments/README.md | 2 +- torchtitan/experiments/__init__.py | 4 ++-- .../{auto_parallel => autoparallel}/README.md | 4 ++-- .../deepseek_v3/__init__.py | 0 .../deepseek_v3/parallelize_deepseekv3.py | 2 +- .../{auto_parallel => autoparallel}/job_config.py | 2 +- .../llama3/__init__.py | 0 .../llama3/parallelize_llama.py | 0 .../tests/__init__.py | 0 .../tests/integration_tests.py | 14 +++++++------- 11 files changed, 20 insertions(+), 20 deletions(-) rename .github/workflows/{integration_test_8gpu_auto_parallel.yaml => integration_test_8gpu_autoparallel.yaml} (78%) rename torchtitan/experiments/{auto_parallel => autoparallel}/README.md (69%) rename torchtitan/experiments/{auto_parallel => autoparallel}/deepseek_v3/__init__.py (100%) rename torchtitan/experiments/{auto_parallel => autoparallel}/deepseek_v3/parallelize_deepseekv3.py (99%) rename torchtitan/experiments/{auto_parallel => autoparallel}/job_config.py (86%) rename torchtitan/experiments/{auto_parallel => autoparallel}/llama3/__init__.py (100%) rename torchtitan/experiments/{auto_parallel => autoparallel}/llama3/parallelize_llama.py (100%) rename torchtitan/experiments/{auto_parallel => autoparallel}/tests/__init__.py (100%) rename torchtitan/experiments/{auto_parallel => autoparallel}/tests/integration_tests.py (87%) diff --git a/.github/workflows/integration_test_8gpu_auto_parallel.yaml b/.github/workflows/integration_test_8gpu_autoparallel.yaml similarity index 78% rename from .github/workflows/integration_test_8gpu_auto_parallel.yaml rename to .github/workflows/integration_test_8gpu_autoparallel.yaml index 85618aeeef..220a346714 100644 --- a/.github/workflows/integration_test_8gpu_auto_parallel.yaml +++ b/.github/workflows/integration_test_8gpu_autoparallel.yaml @@ -4,12 +4,12 @@ on: push: branches: [ main ] paths: - - 'torchtitan/experiments/auto_parallel/**' - - '.github/workflows/integration_test_8gpu_auto_parallel.yaml' + - 'torchtitan/experiments/autoparallel/**' + - '.github/workflows/integration_test_8gpu_autoparallel.yaml' pull_request: paths: - - 'torchtitan/experiments/auto_parallel/**' - - '.github/workflows/integration_test_8gpu_auto_parallel.yaml' + - 'torchtitan/experiments/autoparallel/**' + - '.github/workflows/integration_test_8gpu_autoparallel.yaml' schedule: # Runs every 12 hours - cron: '0 */12 * * *' @@ -49,8 +49,8 @@ jobs: python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 - # Install autoparallel - required dependency for auto_parallel experiment + # Install autoparallel - required dependency for autoparallel experiment python -m pip install git+https://github.com/meta-pytorch/autoparallel.git mkdir artifacts-to-be-uploaded - python -m torchtitan.experiments.auto_parallel.tests.integration_tests artifacts-to-be-uploaded --ngpu 4 + python -m torchtitan.experiments.autoparallel.tests.integration_tests artifacts-to-be-uploaded --ngpu 4 diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 5c1b20898d..6442ea71f8 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -32,4 +32,4 @@ We provide this `experiments/` folder to host experiments that add significant v | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) | | [transformers_modeling_backend](./transformers_modeling_backend/) | [![Transformers modeling backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | -| [auto_parallel](./auto_parallel/) | [![Auto Parallel 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_auto_parallel.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_auto_parallel.yaml?query=branch%3Amain) | [@wconstab](https://github.com/wconstab) [@xmfan](https://github.com/xmfan) | +| [autoparallel](./autoparallel/) | [![Auto Parallel 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml?query=branch%3Amain) | [@wconstab](https://github.com/wconstab) [@xmfan](https://github.com/xmfan) | diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 7e2c442103..7d7f4da41a 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -13,7 +13,7 @@ "compiler_toolkit.deepseek_v3", "compiler_toolkit.llama3", "transformers_modeling_backend", - "auto_parallel.llama3", - "auto_parallel.deepseek_v3", + "autoparallel.llama3", + "autoparallel.deepseek_v3", ] ) diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/autoparallel/README.md similarity index 69% rename from torchtitan/experiments/auto_parallel/README.md rename to torchtitan/experiments/autoparallel/README.md index 55dcc3c5e5..3be86b9bc3 100644 --- a/torchtitan/experiments/auto_parallel/README.md +++ b/torchtitan/experiments/autoparallel/README.md @@ -12,8 +12,8 @@ Requires installing [git@github.com:meta-pytorch/autoparallel.git](https://githu **Llama3** -`CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.llama3 --parallelism.tensor_parallel_degree 4 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config` +`CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name autoparallel.llama3 --parallelism.tensor_parallel_degree 4 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config` **DeepSeekv3** -`CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name auto_parallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config` +`CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name autoparallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config` diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py b/torchtitan/experiments/autoparallel/deepseek_v3/__init__.py similarity index 100% rename from torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py rename to torchtitan/experiments/autoparallel/deepseek_v3/__init__.py diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py similarity index 99% rename from torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py rename to torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py index fc278cfabe..0f718a389b 100644 --- a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py +++ b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py @@ -258,7 +258,7 @@ def set_torchtitan_fields(orig, new): # Run workflow with: -# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel +# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_autoparallel def parallelize_deepseekv3( model, parallel_dims: ParallelDims, diff --git a/torchtitan/experiments/auto_parallel/job_config.py b/torchtitan/experiments/autoparallel/job_config.py similarity index 86% rename from torchtitan/experiments/auto_parallel/job_config.py rename to torchtitan/experiments/autoparallel/job_config.py index c880cadb31..b481318562 100644 --- a/torchtitan/experiments/auto_parallel/job_config.py +++ b/torchtitan/experiments/autoparallel/job_config.py @@ -8,7 +8,7 @@ """ -Use --job.custom_config_module=torchtitan.experiments.auto_parallel.job_config +Use --job.custom_config_module=torchtitan.experiments.autoparallel.job_config """ diff --git a/torchtitan/experiments/auto_parallel/llama3/__init__.py b/torchtitan/experiments/autoparallel/llama3/__init__.py similarity index 100% rename from torchtitan/experiments/auto_parallel/llama3/__init__.py rename to torchtitan/experiments/autoparallel/llama3/__init__.py diff --git a/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py b/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py similarity index 100% rename from torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py rename to torchtitan/experiments/autoparallel/llama3/parallelize_llama.py diff --git a/torchtitan/experiments/auto_parallel/tests/__init__.py b/torchtitan/experiments/autoparallel/tests/__init__.py similarity index 100% rename from torchtitan/experiments/auto_parallel/tests/__init__.py rename to torchtitan/experiments/autoparallel/tests/__init__.py diff --git a/torchtitan/experiments/auto_parallel/tests/integration_tests.py b/torchtitan/experiments/autoparallel/tests/integration_tests.py similarity index 87% rename from torchtitan/experiments/auto_parallel/tests/integration_tests.py rename to torchtitan/experiments/autoparallel/tests/integration_tests.py index 334aed86dd..8425d23254 100644 --- a/torchtitan/experiments/auto_parallel/tests/integration_tests.py +++ b/torchtitan/experiments/autoparallel/tests/integration_tests.py @@ -11,7 +11,7 @@ from tests.integration_tests.run_tests import run_tests -def build_auto_parallel_test_list() -> list[OverrideDefinitions]: +def build_autoparallel_test_list() -> list[OverrideDefinitions]: """ returns a list of OverrideDefinitions that is used to generate variations of integration tests based on the same root config file. @@ -21,10 +21,10 @@ def build_auto_parallel_test_list() -> list[OverrideDefinitions]: OverrideDefinitions( [ [ - "--model.name auto_parallel.llama3", + "--model.name autoparallel.llama3", "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", - "--job.custom_config_module=torchtitan.experiments.auto_parallel.job_config", + "--job.custom_config_module=torchtitan.experiments.autoparallel.job_config", ], ], "llama3 AutoParallel FSDP+TP", @@ -36,10 +36,10 @@ def build_auto_parallel_test_list() -> list[OverrideDefinitions]: # OverrideDefinitions( # [ # [ - # "--model.name auto_parallel.deepseek_v3", + # "--model.name autoparallel.deepseek_v3", # "--parallelism.data_parallel_shard_degree 2", # "--parallelism.expert_parallel_degree 2", - # "--job.custom_config_module=torchtitan.experiments.auto_parallel.job_config", + # "--job.custom_config_module=torchtitan.experiments.autoparallel.job_config", # "--activation_checkpoint.mode none", # ], # ], @@ -52,7 +52,7 @@ def build_auto_parallel_test_list() -> list[OverrideDefinitions]: _TEST_SUITES_FUNCTION = { - "auto_parallel": build_auto_parallel_test_list, + "autoparallel": build_autoparallel_test_list, } @@ -77,7 +77,7 @@ def main(): if os.listdir(args.output_dir): raise RuntimeError("Please provide an empty output directory.") - test_list = _TEST_SUITES_FUNCTION["auto_parallel"]() + test_list = _TEST_SUITES_FUNCTION["autoparallel"]() run_tests(args, test_list) From a632855c4e3770c7b99a97e9fd6f7b26a5206b44 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 11 Dec 2025 13:42:59 -0800 Subject: [PATCH 057/127] PyTorch depends on psutil (#2132) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * __->__ #2132 TorchTitan should also depends on psutil. --- .ci/docker/requirements.txt | 1 + pyproject.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index 9bf30b502c..89832abe65 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -8,3 +8,4 @@ fsspec tyro tokenizers >= 0.15.0 safetensors +psutil diff --git a/pyproject.toml b/pyproject.toml index 51d09420b4..efe74d3030 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "fsspec", "tyro", "tensorboard", + "psutil", ] dynamic = ["version"] From 4389efd06fc01a0141e08c51594bc85cac2235f1 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 11 Dec 2025 14:42:20 -0800 Subject: [PATCH 058/127] Remove caching for attention masks (#2117) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We remove the lru_cache for attention masks, because in get_attention_mask() function, `and_masks(*mask_mods)` will return different object id. `create_attention_mask` will use all parameters as cache key, and new object id will always cause cache miss. Before the change: (llama3 debugmodel_flex_attn) Screenshot 2025-12-09 at 1 27 45 PM After the change: Screenshot 2025-12-09 at 1 29 56 PM --- torchtitan/models/attention.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 663ce54010..819dbd57bc 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -6,7 +6,6 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. -import functools from collections.abc import Callable from typing import ClassVar, NamedTuple @@ -171,22 +170,19 @@ def forward( return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True) -# We cannot do inner function/closure because we won't be able to cache it -- -# if we an inner function, a new closure will be created every time -# `get_causal_mask_mod` is called. -def _causal_mask( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor -) -> torch.Tensor: - """Causal mask that prevents attention to future tokens.""" - return q_idx >= kv_idx - - def get_causal_mask_mod() -> _mask_mod_signature: """Returns a causal mask modifier for flex attention. Returns: A mask modifier function that implements causal masking. """ + + def _causal_mask( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ) -> torch.Tensor: + """Causal mask that prevents attention to future tokens.""" + return q_idx >= kv_idx + return _causal_mask @@ -275,13 +271,8 @@ def sliding_window_mod( _compiled_create_block_mask = torch.compile(create_block_mask) -@functools.lru_cache(4) def create_attention_mask(*args, **kwargs): - """Create an attention mask using compiled create_block_mask. - - This function is cached to avoid recreating BlockMasks for the same - arguments. - """ + """Create an attention mask using compiled create_block_mask.""" return _compiled_create_block_mask(*args, **kwargs) From 669845f434137401c79390ea03cf8f508ef2c73d Mon Sep 17 00:00:00 2001 From: Davide Italiano Date: Thu, 11 Dec 2025 18:33:51 -0800 Subject: [PATCH 059/127] Clarify contribution guidelines. (#2134) --- README.md | 9 +++++++-- torchtitan/experiments/README.md | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b68314f297..ec6ab30aad 100644 --- a/README.md +++ b/README.md @@ -40,9 +40,14 @@ The Guiding Principles when building `torchtitan` * Minimal changes to the model code when applying multi-dimensional parallelism. * Bias towards a clean, minimal codebase while providing basic reusable / swappable components. -`torchtitan` has been showcasing PyTorch's latest distributed training features, via pretraining Llama 3.1 LLMs of various sizes. -To accelerate contributions to and innovations around torchtitan, we host an [`experiments`](torchtitan/experiments) folder. We look forward to your contributions! +`torchtitan` has been showcasing PyTorch's latest distributed training features, via support for pretraining Llama 3.1 LLMs of various sizes. +## Contributing + +We look forward to your contributions! + +* To accelerate contributions to and innovations around torchtitan, we host an [`experiments`](torchtitan/experiments) folder. New ideas should start there. To contribute, follow the [`experiments guidelines`](torchtitan/experiments/README.md). +* For fixes and contributions to core, follow these [`guidelines`](CONTRIBUTING.md). ## Llama 3.1 training diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 6442ea71f8..10b90ac1d4 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -10,7 +10,7 @@ We provide this `experiments/` folder to host experiments that add significant v 3. An experiment should reuse existing `torchtitan` code as much as possible, such as modules in [`components/`](../components/) (via a new [`TrainSpec`](../protocols/train_spec.py)) and [`train.py`](../train.py). For a list of extension points we provide, please refer to [docs/extension.md](../../docs/extension.md). - The extension points are subject to change. We kindly request that contributors provide feedback if they encounter issues reusing any components, rather than simply using a copy-and-paste approach. - The degree to which existing components are reused and whether duplications are legit will also be a criteria of whether an experiment would be accepted. -4. Each experiment is independent from other experiments, and can have its own dependencies (on top of [core dependencies](../../requirements.txt)), and its own tests. +4. Each experiment is independent from other experiments, and can have its own dependencies (on top of [core dependencies](../../requirements.txt)), and its own tests. An experiment should not contain vendor-specific code, such as kernels written in a proprietary language. Those can be hosted outside as dependency. 5. The dependency from `experiments` to `core` is one-way. Anything in `experiments` is optional for `core` to run successfully. In particular, development in `core` is not blocked by breakage in `experiments`. We will utilize GitHub's [CI mechanism](https://docs.github.com/en/actions/writing-workflows/workflow-syntax-for-github-actions#onpushpull_requestpull_request_targetpathspaths-ignore) to help test an experiment periodically and only if the experiment itself is affected by a PR. 6. Each experiment needs to have an owner. The owner is responsible to work with `torchtitan` team to maintain the quality and healthiness of an experiment, which includes - adapting an experiment to changes in `core` and fix broken tests, no later than the next official `torchtitan` release; From fcc5643f4139a05f23d5d188fd5989324abb6bf3 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Fri, 12 Dec 2025 06:01:21 -0500 Subject: [PATCH 060/127] Enable PP and EP overlap for MoE (#1721) Option 2 of https://github.com/pytorch/torchtitan/issues/1682 These changes add a custom `overlap_callback` function to replace the OVERLAP_F_B action that is run during the schedule execution. In the custom function, we write `run_forward()` and `run_backward()`. `run_backward()` is run as a separate thread so that we can have both forward and backward running together side by side. Looks like this: image In order for these changes to work with Expert Parallel, we also need to add custom autograd functions to act as the boundary points at which we do communication. We added hooks before and after expert parallel dispatch and combine to signal boundary points, so our figure from before now turns into: image Now in each of these red blocks, we use a global coordinator. We need `threading.Barrier(2).wait()` so that the comm and compute from our forward and backward steps are scheduled in lock-step before continuing. DSv3 16B run command: ``` TORCH_NCCL_TRACE_BUFFER_SIZE=2000 TORCH_NCCL_DUMP_ON_TIMEOUT=true TORCH_FR_DUMP_TEMP_FILE=./nccl_trace_rank_ NGPU=8 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh ``` Trace examples: image Test command: `python -m tests.integration_tests.run_tests ./test-out --test_name pp_dualpipev --test_suite models` --------- Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com> --- tests/integration_tests/models.py | 15 + torchtitan/config/job_config.py | 8 + torchtitan/distributed/dual_pipe_v.py | 309 ++++++++++++++++++ torchtitan/distributed/expert_parallel.py | 57 +++- torchtitan/distributed/pipeline_parallel.py | 7 + .../gpt_oss/infra/expert_parallel.py | 13 +- .../experiments/gpt_oss/infra/parallelize.py | 12 + .../models/deepseek_v3/infra/parallelize.py | 4 + torchtitan/models/llama4/infra/parallelize.py | 12 + torchtitan/models/qwen3/infra/parallelize.py | 4 + 10 files changed, 417 insertions(+), 24 deletions(-) create mode 100644 torchtitan/distributed/dual_pipe_v.py diff --git a/tests/integration_tests/models.py b/tests/integration_tests/models.py index 37f588765b..606ecfe4bd 100755 --- a/tests/integration_tests/models.py +++ b/tests/integration_tests/models.py @@ -32,6 +32,21 @@ def build_model_tests_list() -> list[OverrideDefinitions]: "deepseek_v3_fsdp+ep+compile", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name deepseek_v3", + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.expert_parallel_degree 2", + "--parallelism.pipeline_parallel_schedule DualPipeV", + # AC is not supported for DualPipeV yet + "--activation_checkpoint.mode 'none'", + ], + ], + "PP dual pipe v schedule test", + "pp_dualpipev", + ngpu=4, + ), OverrideDefinitions( [ [ diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index c806041bb6..5ef99f6934 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -373,6 +373,14 @@ class Parallelism: The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size. """ + pipeline_parallel_expert_parallel_overlap: bool = True + """Whether to turn on the optimization to overlap expert parallel and pipeline parallel + communication. This is only effective when the pipeline parallel schedule is DualPipeV and + pipeline_parallel_degree > 1 and expert_parallel_degree > 1. + + TODO: Does not support activation_checkpoint, set mode="none" + """ + context_parallel_degree: int = 1 """Context parallelism degree. 1 means disabled.""" diff --git a/torchtitan/distributed/dual_pipe_v.py b/torchtitan/distributed/dual_pipe_v.py new file mode 100644 index 0000000000..5a4a5d9dd0 --- /dev/null +++ b/torchtitan/distributed/dual_pipe_v.py @@ -0,0 +1,309 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import threading +from typing import cast, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from torch.distributed.pipelining.schedules import ( + _Action, + _PipelineContext, + _PipelineScheduleRuntime, + _wait_batch_p2p, +) +from torch.distributed.pipelining.stage import _PipelineStageBase +from torch.distributed.tensor import DeviceMesh, distribute_module +from torch.profiler import record_function + +from torchtitan.distributed.expert_parallel import BaseExpertParallel + +from torchtitan.tools.utils import get_device_info + +""" +Below are optimizations related to pipeline parallelism with expert parallelism +""" + + +def get_dual_pipe_v_flag(job_config, parallel_dims) -> bool: + """ + Determine if DualPipeV should be enabled based on config and + validates that incompatible features (EP + DualPipeV + AC) are not used together. + """ + if not parallel_dims.ep_enabled or not parallel_dims.pp_enabled: + return False + + dual_pipe_v = ( + job_config.parallelism.pipeline_parallel_expert_parallel_overlap + and job_config.parallelism.pipeline_parallel_schedule.lower() == "dualpipev" + ) + + if dual_pipe_v and job_config.activation_checkpoint.mode != "none": + raise NotImplementedError( + "Expert Parallel with DualPipeV and Activation Checkpointing " + "cannot be used together. Please disable one of them." + ) + + return dual_pipe_v + + +class DualPipeExpertParallel(BaseExpertParallel): + """ + Wrapper that adds dual-pipe synchronization hooks to any BaseExpertParallel. + Wraps dispatch/combine with sync hooks for overlapping EP communication + with PP computation in DualPipe scheduling. + + The execution order becomes: + A -> dispatch -> B -> module -> C -> combine -> D + """ + + def __init__(self, inner_ep: BaseExpertParallel): + super().__init__() + self.inner_ep = inner_ep + + def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None: + return self.inner_ep._partition_fn(name, mod, device_mesh) + + def _token_dispatch( + self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh + ) -> tuple[Tensor, Tensor]: + """A -> dispatch -> B""" + inputs = (cast(Tensor, SyncHook.apply(inputs[0], "A")),) + inputs[1:] + outputs = self.inner_ep._token_dispatch(mod, inputs, device_mesh) + outputs = (cast(Tensor, SyncHook.apply(outputs[0], "B")),) + outputs[1:] + return outputs + + def _token_combine( + self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh + ) -> Tensor: + """C -> combine -> D""" + routed_output = cast(Tensor, SyncHook.apply(routed_output, "C")) + combine_output = self.inner_ep._token_combine(mod, routed_output, device_mesh) + combine_output = cast(Tensor, SyncHook.apply(combine_output, "D")) + return combine_output + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=self._partition_fn, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) + + +class HookCoordinator: + def __init__(self): + # Barrier for 2 threads (forward and backward) to synchronize + # This ensures that we always alternate at executing one compute and one comm op together + self._execution_barrier = threading.Barrier(2) + + self._coordination_enabled = False + self._cycle_count = 0 + self._num_layers = None + + def barrier(self): + """Barrier for 2 threads to synchronize""" + if not self.is_coordination_enabled(): + return + + try: + self._execution_barrier.wait() + except threading.BrokenBarrierError: + pass + + def enable_coordination(self, num_layers: Optional[int] = None): + if num_layers is not None and num_layers > 0: + self._coordination_enabled = True + self._cycle_count = 0 + + # Reset barrier + self._execution_barrier = threading.Barrier(2) + self._num_layers = num_layers + + def disable_coordination(self): + self._coordination_enabled = False + self._cycle_count = 0 + self._execution_barrier.abort() # Break barrier to unblock threads + + def check_should_continue_coordination(self): + if self._num_layers is not None and self._cycle_count >= self._num_layers: + return False + return True + + def is_coordination_enabled(self): + return self._coordination_enabled + + +# Global coordinator +_hook_coordinator = HookCoordinator() + + +class SyncHook(torch.autograd.Function): + @staticmethod + def forward(ctx, x, hook_name=""): + ctx.hook_name = hook_name + # handle edge case for transformer level boundary + if _hook_coordinator._coordination_enabled and hook_name == "D": + _hook_coordinator._cycle_count += 1 + if not _hook_coordinator.check_should_continue_coordination(): + _hook_coordinator.disable_coordination() + return x + + _hook_coordinator.barrier() + return x + + @staticmethod + def backward(ctx, grad_output): + hook_name = ctx.hook_name + + # Edge case, skip initial barrier, all subsequent backward hooks will acquire + if hook_name == "D" and _hook_coordinator._cycle_count == 0: + return grad_output, None + + _hook_coordinator.barrier() + return grad_output, None + + +def _count_moe_modules(model): + """Count MoE modules directly""" + from torchtitan.models.moe import MoE + + moe_count = 0 + for _, module in model.named_modules(): + if isinstance(module, MoE): + moe_count += 1 + return moe_count + + +device_type, device_module = get_device_info() + + +def overlap_callback(action: _Action, ctx: _PipelineContext): + """ + Custom callback for OVERLAP_F_B computation that allows expert parallel communication + and pipeline parallel computation to overlap. + """ + schedule = ctx.schedule_ref + assert isinstance(schedule, _PipelineScheduleRuntime) + stage_index_to_stage: dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in schedule._stages + } + assert action.sub_actions is not None + fwd_action = action.sub_actions[0] + bwd_action = action.sub_actions[1] + + # Get stages + forward_stage_index = fwd_action.stage_index + forward_mb_index = fwd_action.microbatch_index + assert forward_mb_index is not None + backward_stage_index = bwd_action.stage_index + backward_stage = stage_index_to_stage[backward_stage_index] + + # Forward setup + arg_mbs = ctx.arg_mbs + kwarg_mbs = ctx.kwarg_mbs + assert arg_mbs is not None and kwarg_mbs is not None + fwd_recv_ops = schedule.fwd_recv_ops + forward_stage = stage_index_to_stage[forward_stage_index] + forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage + forward_is_prev_stage_on_this_rank = forward_stage_index - 1 in stage_index_to_stage + + # Backward setup + backward_is_next_stage_on_this_rank = ( + backward_stage.stage_index + 1 in stage_index_to_stage + ) + backward_is_prev_stage_on_this_rank = ( + backward_stage.stage_index - 1 in stage_index_to_stage + ) + backward_mb_index = bwd_action.microbatch_index + assert backward_mb_index is not None + bwd_recv_ops = schedule.bwd_recv_ops + + # Fwd receives + if ( + not forward_stage.is_first + # no recv op expected for V-schedule special case + and not forward_is_prev_stage_on_this_rank + ): + assert ( + forward_stage_index, + forward_mb_index, + ) in fwd_recv_ops, f"Computing {action=} before receiving input" + _wait_batch_p2p(fwd_recv_ops.pop((forward_stage_index, forward_mb_index))) + + # Bwd receives + if ( + not backward_stage.is_last + # no recv op expected for V-schedule special case + and not backward_is_next_stage_on_this_rank + ): + assert ( + backward_stage_index, + backward_mb_index, + ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input" + _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index))) + + # We count num layers in case the stage layers differ + # If they differ than we only want coordination to happen for the min amount of layers + min_num_layers = min( + _count_moe_modules(forward_stage.submod), + _count_moe_modules(backward_stage.submod), + ) + # PP computation ======================================================== + _hook_coordinator.enable_coordination(num_layers=min_num_layers) + main_stream = torch.accelerator.current_stream(device_module) + + # Shared container for exception from backward thread + def run_backward(): + schedule._assert_unsharded(backward_stage) + # Set the backward thread to use the same stream as forward + device_module.set_stream(main_stream) + with record_function( + f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}" + ): + loss = schedule._maybe_get_loss(backward_stage, backward_mb_index) + schedule.backward_counter[backward_stage_index] += 1 + last_backward = ( + schedule.backward_counter[backward_stage_index] + == schedule._n_microbatches + ) + backward_stage.backward_one_chunk( + backward_mb_index, + loss=loss, + full_backward=True, + last_backward=last_backward, + ) + + if backward_is_prev_stage_on_this_rank: + stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input( + backward_stage.get_local_bwd_output(backward_mb_index), + backward_mb_index, + ) + + def run_forward(): + schedule._assert_unsharded(forward_stage) + output = forward_stage.forward_one_chunk( + forward_mb_index, + arg_mbs[forward_mb_index], + kwarg_mbs[forward_mb_index], + ) + schedule._maybe_compute_loss( + forward_stage, output, ctx.target_mbs, forward_mb_index + ) + if forward_is_next_stage_on_this_rank: + stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input( + output, forward_mb_index + ) + + # Run forward and backward in parallel + thread = threading.Thread(target=run_backward, daemon=True) + thread.start() + run_forward() + thread.join() + + _hook_coordinator.disable_coordination() diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index b78019e057..932a7e4aa1 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -4,8 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from abc import ABC, abstractmethod + import torch import torch.nn as nn +from torch import Tensor from torch.distributed._functional_collectives import ( all_to_all_single, all_to_all_single_autograd, @@ -24,6 +27,24 @@ from torchtitan.models.moe.utils import _permute, _unpermute +class BaseExpertParallel(ParallelStyle, ABC): + @abstractmethod + def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None: + ... + + @abstractmethod + def _token_dispatch( + self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh + ) -> tuple[Tensor, Tensor]: + ... + + @abstractmethod + def _token_combine( + self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh + ) -> Tensor: + ... + + # implementation of Tensor Parallel for the GroupedExperts in MoE class TensorParallel(ParallelStyle): def _prepare_input_fn(self, mod, inputs, device_mesh): @@ -64,7 +85,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ) -class ExpertParallel(ParallelStyle): +class ExpertParallel(BaseExpertParallel): def __init__(self): super().__init__() self.input_splits = None @@ -72,8 +93,14 @@ def __init__(self): self.input_shape = None self.permuted_indices = None - # performing all-to-all dispatch on the input - def _token_dispatch(self, mod, inputs, device_mesh): + def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None: + for param_name, param in mod.named_parameters(recurse=False): + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) + mod.register_parameter(param_name, dist_param) + + def _token_dispatch( + self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh + ) -> tuple[Tensor, Tensor]: # annotate module input placements/sharding with input_layouts routed_input, num_tokens_per_expert = inputs ep_degree = device_mesh.shape[0] @@ -137,15 +164,9 @@ def _token_dispatch(self, mod, inputs, device_mesh): return routed_input, num_tokens_per_expert_group - @staticmethod - def _partition_fn(name, mod, device_mesh): - # shard on the expert dimension - for name, param in mod.named_parameters(recurse=False): - dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) - mod.register_parameter(name, dist_param) - - # performing all-to-all combine on the output - def _token_combine(self, mod, routed_output, device_mesh): + def _token_combine( + self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh + ) -> Tensor: routed_output = _unpermute( routed_output, self.input_shape, self.permuted_indices ) @@ -162,7 +183,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: return distribute_module( module, device_mesh, - partition_fn=ExpertParallel._partition_fn, + partition_fn=self._partition_fn, input_fn=self._token_dispatch, output_fn=self._token_combine, ) @@ -185,23 +206,23 @@ def _token_dispatch(self, mod, inputs, device_mesh): # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh return super()._token_dispatch(mod, inputs, device_mesh["ep"]) - def _partition_fn_2d(self, name, mod, ep_tp_mesh): + def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None: # w1 shape = (experts, out_dim, in_dim) mod.register_parameter( "w1", - nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(1)])), + nn.Parameter(distribute_tensor(mod.w1, device_mesh, [Shard(0), Shard(1)])), ) # Column-wise sharding # w2 shape = (experts, in_dim, out_dim) mod.register_parameter( "w2", - nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(2)])), + nn.Parameter(distribute_tensor(mod.w2, device_mesh, [Shard(0), Shard(2)])), ) # Row-wise sharding # w3 shape = (experts, out_dim, in_dim) mod.register_parameter( "w3", - nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(1)])), + nn.Parameter(distribute_tensor(mod.w3, device_mesh, [Shard(0), Shard(1)])), ) # Column-wise sharding def _token_combine(self, mod, routed_output, device_mesh): @@ -212,7 +233,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: return distribute_module( module, device_mesh, - partition_fn=self._partition_fn_2d, + partition_fn=self._partition_fn, input_fn=self._token_dispatch, output_fn=self._token_combine, ) diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py index 06dba40d6f..bafefddbec 100644 --- a/torchtitan/distributed/pipeline_parallel.py +++ b/torchtitan/distributed/pipeline_parallel.py @@ -18,6 +18,7 @@ _PipelineSchedule, _PipelineScheduleRuntime, get_schedule_class, + OVERLAP_F_B, PipelineScheduleMulti, PipelineScheduleSingle, ScheduleDualPipeV, @@ -27,6 +28,7 @@ from torchtitan.components.loss import LossFunction, rescale_accumulated_loss from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims +from torchtitan.distributed.dual_pipe_v import overlap_callback from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction from torchtitan.tools.logging import logger @@ -209,6 +211,11 @@ def build_pipeline_schedule( f"with {n_microbatches} microbatches and {num_total_stages} stages." ) + if job_config.parallelism.pipeline_parallel_expert_parallel_overlap and isinstance( + schedule, ScheduleDualPipeV + ): + schedule.register_custom_function(OVERLAP_F_B, overlap_callback) + if pp_schedule_csv: assert schedule_class in [ PipelineScheduleSingle, diff --git a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py index 96ad157c2f..1e8054a481 100644 --- a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py +++ b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py @@ -6,9 +6,10 @@ import torch.nn as nn -from torch.distributed.tensor import distribute_tensor, Replicate, Shard +from torch.distributed.tensor import DeviceMesh, distribute_tensor, Replicate, Shard from torchtitan.distributed.expert_parallel import ExpertTensorParallel, TensorParallel + # implementation of Tensor Parallel for the GroupedExperts in MoE class GptossTensorParallel(TensorParallel): def _partition_fn(self, name, module, device_mesh): @@ -38,28 +39,28 @@ def _partition_fn(self, name, module, device_mesh): # This class is for dp2ep with TP (without TP we can just use GptossExpertParallel) class GptossExpertTensorParallel(ExpertTensorParallel): - def _partition_fn_2d(self, name, mod, ep_tp_mesh): + def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None: mod.register_parameter( "mlp1_weight", nn.Parameter( - distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(1)]) + distribute_tensor(mod.mlp1_weight, device_mesh, [Shard(0), Shard(1)]) ), ) # Column-wise sharding mod.register_parameter( "mlp1_bias", nn.Parameter( - distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)]) + distribute_tensor(mod.mlp1_bias, device_mesh, [Shard(0), Shard(1)]) ), ) # Column-wise sharding mod.register_parameter( "mlp2_weight", nn.Parameter( - distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(2)]) + distribute_tensor(mod.mlp2_weight, device_mesh, [Shard(0), Shard(2)]) ), ) # Row-wise sharding mod.register_parameter( "mlp2_bias", nn.Parameter( - distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Replicate()]) + distribute_tensor(mod.mlp2_bias, device_mesh, [Shard(0), Replicate()]) ), ) # Replicate diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 4d1177d1ab..9b2e75ac4f 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -21,7 +21,12 @@ from torchtitan.config.job_config import JobConfig from torchtitan.distributed import NoParallel, ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.dual_pipe_v import ( + DualPipeExpertParallel, + get_dual_pipe_v_flag, +) from torchtitan.distributed.expert_parallel import ( + BaseExpertParallel, ExpertParallel, ReordererSequenceParallel, ) @@ -93,6 +98,8 @@ def parallelize_gptoss( ) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) + apply_moe_ep_tp( model, tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, @@ -105,6 +112,7 @@ def parallelize_gptoss( else None ), etp_enabled=parallel_dims.etp_enabled, + dual_pipe_v=dual_pipe_v, ) model_compile_enabled = ( @@ -257,6 +265,7 @@ def apply_moe_ep_tp( ep_mesh: DeviceMesh | None, ep_tp_mesh: DeviceMesh | None, etp_enabled: bool, + dual_pipe_v: bool = False, ): assert ep_mesh is not None or tp_mesh is not None @@ -303,6 +312,9 @@ def apply_moe_ep_tp( experts_mesh = ep_tp_mesh experts_plan = GptossExpertTensorParallel() + if dual_pipe_v and isinstance(experts_plan, BaseExpertParallel): + experts_plan = DualPipeExpertParallel(experts_plan) + parallelize_module( module=transformer_block.moe.experts, device_mesh=experts_mesh, diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index d66a30a83d..c068e60a30 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -19,6 +19,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import NoParallel, ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.dual_pipe_v import get_dual_pipe_v_flag from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.models.llama3.infra.parallelize import apply_ddp from torchtitan.models.llama4.infra.parallelize import ( @@ -92,6 +93,8 @@ def parallelize_deepseekv3( maybe_enable_async_tp(job_config, world_mesh["tp"]) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) + apply_moe_ep_tp( model, tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, @@ -104,6 +107,7 @@ def parallelize_deepseekv3( else None ), etp_enabled=parallel_dims.etp_enabled, + dual_pipe_v=dual_pipe_v, ) model_compile_enabled = ( diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 0fb2b54eac..7440b3c3f5 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -24,8 +24,13 @@ from torchtitan.config.job_config import Compile as CompileConfig from torchtitan.distributed import NoParallel, ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.dual_pipe_v import ( + DualPipeExpertParallel, + get_dual_pipe_v_flag, +) from torchtitan.distributed.expert_parallel import ( + BaseExpertParallel, ExpertParallel, ExpertTensorParallel, ReordererSequenceParallel, @@ -100,6 +105,8 @@ def parallelize_llama( maybe_enable_async_tp(job_config, world_mesh["tp"]) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) + apply_moe_ep_tp( model, tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, @@ -112,6 +119,7 @@ def parallelize_llama( else None ), etp_enabled=parallel_dims.etp_enabled, + dual_pipe_v=dual_pipe_v, ) model_compile_enabled = ( @@ -444,6 +452,7 @@ def apply_moe_ep_tp( ep_mesh: DeviceMesh | None, ep_tp_mesh: DeviceMesh | None, etp_enabled: bool, + dual_pipe_v: bool = False, ): assert ep_mesh is not None or tp_mesh is not None @@ -500,6 +509,9 @@ def apply_moe_ep_tp( experts_mesh = ep_tp_mesh experts_plan = ExpertTensorParallel() + if dual_pipe_v and isinstance(experts_plan, BaseExpertParallel): + experts_plan = DualPipeExpertParallel(experts_plan) + parallelize_module( module=transformer_block.moe.experts, device_mesh=experts_mesh, diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 517435714b..5f9f0a73be 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -23,6 +23,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.dual_pipe_v import get_dual_pipe_v_flag from torchtitan.models.llama3.infra.parallelize import apply_ddp from torchtitan.models.llama4.infra.parallelize import ( apply_compile, @@ -98,6 +99,8 @@ def parallelize_qwen3( ) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) + apply_moe_ep_tp( model, tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, @@ -110,6 +113,7 @@ def parallelize_qwen3( else None ), etp_enabled=parallel_dims.etp_enabled, + dual_pipe_v=dual_pipe_v, ) if job_config.activation_checkpoint.mode != "none": From 7a398eabcde66d890746e8f63dda3b47433f9a25 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 12 Dec 2025 13:24:19 -0800 Subject: [PATCH 061/127] Fix apply_compile called multiple times in PP initialization (#2135) Stacked PRs: * __->__#2135 --- --- --- PP initialization calls apply_compile multiple times, once per pp stage. But apply_compile does some global patching. So I add `already_patched` to avoid patching the same method multiple times. If we patch multiple times, the second time will wrap `_run_experts_grouped_mm_dynamic` in a torch.compile(fullgraph=True) leading to the error in the issue below. FIXES https://github.com/pytorch/torchtitan/issues/2124 --- tests/unit_tests/test_compile_moe.py | 81 +++++++++++++++++++ torchtitan/models/llama4/infra/parallelize.py | 45 ++++++----- 2 files changed, 107 insertions(+), 19 deletions(-) create mode 100644 tests/unit_tests/test_compile_moe.py diff --git a/tests/unit_tests/test_compile_moe.py b/tests/unit_tests/test_compile_moe.py new file mode 100644 index 0000000000..52a6b99ef5 --- /dev/null +++ b/tests/unit_tests/test_compile_moe.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +import torch.nn as nn + +from torchtitan.config.job_config import Compile as CompileConfig +from torchtitan.models.llama4.infra.parallelize import apply_compile + + +class TransformerBlock(nn.Module): + def __init__(self, dim=512): + super().__init__() + self.attention = nn.Linear(dim, dim, bias=False) + self.mlp = nn.Linear(dim, dim, bias=False) + self.moe_enabled = False + + def forward(self, x): + x = self.attention(x) + x = self.mlp(x) + return x + + +class TinyModel(nn.Module): + def __init__(self, num_layers=2, dim=512): + super().__init__() + self.layers = nn.ModuleDict( + {str(i): TransformerBlock(dim) for i in range(num_layers)} + ) + + def forward(self, x): + for layer in self.layers.values(): + x = layer(x) + return x + + +class TestApplyCompile(unittest.TestCase): + def test_patched_once(self): + """ + Calls apply_compile multiple times, as in the case with PP. + But patches should only happen once + """ + unused_model1 = TinyModel(num_layers=2, dim=128) + unused_model2 = TinyModel(num_layers=2, dim=128) + compile_config = CompileConfig(backend="eager") + + apply_compile(unused_model1, compile_config, ep_enabled=True) + apply_compile(unused_model2, compile_config, ep_enabled=True) + + from torchtitan.models.moe import moe as moe_module + + # Generate sample inputs for _run_experts_grouped_mm + num_experts = 8 + dim = 128 + hidden_dim = 256 + w1 = torch.randn(num_experts, hidden_dim, dim) + w2 = torch.randn(num_experts, dim, hidden_dim) + w3 = torch.randn(num_experts, hidden_dim, dim) + num_tokens_per_expert = torch.tensor( + [10, 8, 12, 9, 11, 7, 10, 13], dtype=torch.int32 + ) + total_tokens = num_tokens_per_expert.sum().item() + x = torch.randn(total_tokens, dim) + + # Call the function, should not error + output = moe_module._run_experts_grouped_mm( + w1, w2, w3, x, num_tokens_per_expert + ) + + print(f"Input shape: {x.shape}") + print(f"Output shape: {output.shape}") + print(f"Num tokens per expert: {num_tokens_per_expert}") + + +if __name__ == "__main__": + unittest.main() diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 7440b3c3f5..b8b3470d37 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -584,27 +584,34 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b model.layers.register_module(layer_id, transformer_block) - moe_module._run_experts_grouped_mm = torch.compile( - moe_module._run_experts_grouped_mm, - backend=compile_config.backend, - fullgraph=True, + # Patch some globals only once (apply_compile is called multiple times for PP setup) + already_patched = ( + "_run_experts_grouped_mm_dynamic" + in moe_module._run_experts_grouped_mm.__qualname__ ) + if not already_patched: + moe_module._run_experts_grouped_mm = torch.compile( + moe_module._run_experts_grouped_mm, + backend=compile_config.backend, + fullgraph=True, + ) - if ep_enabled: - compiled_fn = moe_module._run_experts_grouped_mm - - def _run_experts_grouped_mm_dynamic( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor, - ) -> torch.Tensor: - # dynamic number of tokens in expert parallel - torch._dynamo.mark_dynamic(x, 0) - return compiled_fn(w1, w2, w3, x, num_tokens_per_expert) - - moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic + if ep_enabled: + compiled_fn = moe_module._run_experts_grouped_mm + + # keep function logic in sync with `already_patched` above + def _run_experts_grouped_mm_dynamic( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + ) -> torch.Tensor: + # dynamic number of tokens in expert parallel + torch._dynamo.mark_dynamic(x, 0) + return compiled_fn(w1, w2, w3, x, num_tokens_per_expert) + + moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic # NOTE: We don't compile for loop code path due to an issue with unbacked symints: # https://github.com/pytorch/pytorch/issues/166460 From 64dc922fc39ed60f1f5997e7866026b6f67b4ec0 Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Fri, 12 Dec 2025 15:26:25 -0800 Subject: [PATCH 062/127] Enable static type checking with Pyrefly (#2136) Enables static type checking of torchtitan with [pyrefly](https://github.com/facebook/pyrefly). Type checking the code helps catch bugs earlier in the development cycle. * Adds pyrefly to CI, as part of the linting workflow. * Addresses ~100 type errors that can be fixed via local code changes and updates to type annotations, and silences the rest with `# pyrefly: ignore` suppression comments. Note that https://github.com/pytorch/torchtitan/commit/325efd946f1cbea85e503f9e684b8c879891fc1a contains all of the non-comment changes. --- .ci/docker/requirements-dev.txt | 1 + .ci/docker/requirements-flux.txt | 2 - .ci/docker/requirements.txt | 2 + .github/workflows/lint.yaml | 3 +- .pre-commit-config.yaml | 8 ++++ CONTRIBUTING.md | 2 +- pyproject.toml | 6 +++ .../checkpoint_conversion/convert_from_hf.py | 4 +- .../checkpoint_conversion/convert_to_hf.py | 2 + .../numerical_tests_example.py | 10 ++++- scripts/download_hf_assets.py | 1 + scripts/estimate/estimation.py | 22 +++++++++- scripts/generate/test_generate.py | 9 ++++ scripts/loss_compare.py | 1 + torchtitan/components/checkpoint.py | 31 ++++++++++++-- torchtitan/components/dataloader.py | 3 +- torchtitan/components/ft/manager.py | 3 +- torchtitan/components/lr_scheduler.py | 2 + torchtitan/components/metrics.py | 10 ++++- torchtitan/components/optimizer.py | 14 +++++++ .../components/quantization/__init__.py | 2 +- torchtitan/components/quantization/float8.py | 1 + torchtitan/components/quantization/mx.py | 5 +++ torchtitan/components/tokenizer.py | 4 ++ torchtitan/components/validate.py | 20 ++++++--- torchtitan/config/manager.py | 4 ++ torchtitan/distributed/__init__.py | 5 ++- .../distributed/activation_checkpoint.py | 3 ++ torchtitan/distributed/dual_pipe_v.py | 18 +++++++- torchtitan/distributed/expert_parallel.py | 10 +++++ torchtitan/distributed/parallel_dims.py | 6 +-- torchtitan/distributed/pipeline_parallel.py | 7 +++- torchtitan/distributed/tensor_parallel.py | 1 + torchtitan/distributed/utils.py | 41 +++++++++++++++---- torchtitan/hf_datasets/text_datasets.py | 2 +- torchtitan/models/attention.py | 10 +++-- .../models/deepseek_v3/infra/parallelize.py | 6 +++ torchtitan/models/deepseek_v3/model/args.py | 9 ++-- torchtitan/models/deepseek_v3/model/model.py | 2 + .../deepseek_v3/model/state_dict_adapter.py | 9 +++- torchtitan/models/flux/__init__.py | 1 + torchtitan/models/flux/flux_datasets.py | 4 +- torchtitan/models/flux/inference/infer.py | 4 ++ torchtitan/models/flux/inference/sampling.py | 22 +++++++++- torchtitan/models/flux/infra/parallelize.py | 17 +++++++- torchtitan/models/flux/model/autoencoder.py | 18 +++++++- torchtitan/models/flux/model/hf_embedder.py | 2 + torchtitan/models/flux/model/layers.py | 7 +++- torchtitan/models/flux/model/model.py | 6 ++- .../models/flux/model/state_dict_adapter.py | 4 ++ torchtitan/models/flux/tokenizer.py | 10 +++++ torchtitan/models/flux/train.py | 10 +++++ torchtitan/models/flux/validate.py | 22 +++++++--- torchtitan/models/llama3/infra/parallelize.py | 11 +++++ torchtitan/models/llama3/model/args.py | 4 +- torchtitan/models/llama3/model/model.py | 6 +++ .../models/llama3/model/state_dict_adapter.py | 3 ++ torchtitan/models/llama4/infra/parallelize.py | 39 +++++++++++++++++- torchtitan/models/llama4/model/args.py | 4 +- torchtitan/models/llama4/model/model.py | 5 +++ .../models/llama4/model/state_dict_adapter.py | 4 ++ torchtitan/models/moe/kernels.py | 3 ++ torchtitan/models/moe/moe.py | 15 ++++--- torchtitan/models/qwen3/infra/parallelize.py | 8 ++++ torchtitan/models/qwen3/model/args.py | 4 +- torchtitan/models/qwen3/model/model.py | 9 ++++ .../models/qwen3/model/state_dict_adapter.py | 14 ++++++- torchtitan/models/utils.py | 16 ++++++-- torchtitan/protocols/model.py | 4 +- torchtitan/protocols/state_dict_adapter.py | 5 ++- torchtitan/tools/profiling.py | 1 + torchtitan/tools/utils.py | 2 +- torchtitan/train.py | 20 ++++++++- 73 files changed, 516 insertions(+), 89 deletions(-) diff --git a/.ci/docker/requirements-dev.txt b/.ci/docker/requirements-dev.txt index 6d53b2f817..0e5a6e491c 100644 --- a/.ci/docker/requirements-dev.txt +++ b/.ci/docker/requirements-dev.txt @@ -2,5 +2,6 @@ expecttest==0.1.6 pytest==7.3.2 pytest-cov pre-commit +pyrefly==0.45.1 tomli-w >= 1.1.0 transformers diff --git a/.ci/docker/requirements-flux.txt b/.ci/docker/requirements-flux.txt index daefd67ff0..8d6797a36b 100644 --- a/.ci/docker/requirements-flux.txt +++ b/.ci/docker/requirements-flux.txt @@ -1,4 +1,2 @@ transformers>=4.51.1 -einops sentencepiece -pillow diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index 89832abe65..b63653bb53 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -9,3 +9,5 @@ tyro tokenizers >= 0.15.0 safetensors psutil +einops +pillow diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 0a3976248f..327b0bec23 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -28,7 +28,8 @@ jobs: run: python -m pip install --upgrade pip - name: Install lint utilities run: | - python -m pip install pre-commit + python -m pip install -r requirements.txt -r requirements-dev.txt + python -m pip install --force-reinstall --pre --index-url https://download.pytorch.org/whl/nightly/cu126 torch pre-commit install-hooks - name: Get changed files id: changed-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc996e5046..6f8542fab4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,3 +61,11 @@ repos: types: [text] additional_dependencies: - tomli + +- repo: https://github.com/facebook/pyrefly-pre-commit + rev: 0.45.1 + hooks: + - id: pyrefly-check + name: Pyrefly (type checking) + pass_filenames: false + language: system diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8de2b9df9d..de6373236a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,7 +4,7 @@ possible. Contributions should follow the [Contributing Guidelines](#contributin ### Setup ``` -pip install -r requirements-dev.txt +pip install -r requirements.txt -r requirements-dev.txt ``` ### Pull Requests diff --git a/pyproject.toml b/pyproject.toml index efe74d3030..7a3687590c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ dependencies = [ "tyro", "tensorboard", "psutil", + "einops", + "pillow", ] dynamic = ["version"] @@ -62,3 +64,7 @@ include = ["torchtitan*"] [tool.pytest.ini_options] addopts = ["--showlocals"] # show local variables in tracebacks testpaths = ["tests"] + +[tool.pyrefly] +project-excludes = ["torchtitan/experiments", "**/tests/**"] +ignore-missing-imports = ["torchao.*", "torchft"] # optional dependencies diff --git a/scripts/checkpoint_conversion/convert_from_hf.py b/scripts/checkpoint_conversion/convert_from_hf.py index fae7eec17b..77bfeddd59 100644 --- a/scripts/checkpoint_conversion/convert_from_hf.py +++ b/scripts/checkpoint_conversion/convert_from_hf.py @@ -16,16 +16,16 @@ @torch.inference_mode() def convert_from_hf(input_dir, output_dir, model_name, model_flavor): - if model_name == "flux": - import torchtitan.experiments.flux # noqa: F401 # initialize model to allocate memory for state dict train_spec = train_spec_module.get_train_spec(model_name) model_args = train_spec.model_args[model_flavor] with torch.device("cpu"): model = train_spec.model_cls(model_args) + # pyrefly: ignore [bad-argument-type] model = ModelWrapper(model) + # pyrefly: ignore [not-callable] sd_adapter = train_spec.state_dict_adapter(model_args, None) assert ( sd_adapter is not None diff --git a/scripts/checkpoint_conversion/convert_to_hf.py b/scripts/checkpoint_conversion/convert_to_hf.py index ad13850b82..e68a6d2acc 100644 --- a/scripts/checkpoint_conversion/convert_to_hf.py +++ b/scripts/checkpoint_conversion/convert_to_hf.py @@ -30,8 +30,10 @@ def convert_to_hf( with torch.device("cpu"): model = train_spec.model_cls(model_args) + # pyrefly: ignore [bad-argument-type] model = ModelWrapper(model) + # pyrefly: ignore [not-callable] sd_adapter = train_spec.state_dict_adapter(model_args, hf_assets_path) assert ( sd_adapter is not None diff --git a/scripts/checkpoint_conversion/numerical_tests_example.py b/scripts/checkpoint_conversion/numerical_tests_example.py index 66eff8054e..f52851ef9b 100644 --- a/scripts/checkpoint_conversion/numerical_tests_example.py +++ b/scripts/checkpoint_conversion/numerical_tests_example.py @@ -25,7 +25,7 @@ def loss_fn(logits1, logits2): probs2 = F.softmax(logits2, dim=-1) # Calculate KL Divergence - kl_loss = F.kl_div(probs1, probs2, "mean") + kl_loss = F.kl_div(probs1, probs2, reduction="mean") return kl_loss @@ -75,10 +75,13 @@ def forward_tt(config_path, checkpoint_path, test_set): # materalize model device = torch.device(device_type) + # pyrefly: ignore [missing-attribute] model.to_empty(device=device) model.init_weights(buffer_device=device) + # pyrefly: ignore [missing-attribute] model.eval() + # pyrefly: ignore [bad-argument-type] modelWrapper = ModelWrapper(model) state_dict = modelWrapper._get_state_dict() @@ -94,6 +97,7 @@ def forward_tt(config_path, checkpoint_path, test_set): input_ids = input_ids.unsqueeze(0) # obtains the logits of only the last token in the predictions + # pyrefly: ignore [not-callable] predictions = model(input_ids)[:, -1, :].unsqueeze(1) output_list.append(predictions) @@ -120,6 +124,7 @@ def forward_tt(config_path, checkpoint_path, test_set): config_manager = ConfigManager() config = config_manager.parse_args([f"--job.config_file={config_path}"]) train_spec = get_train_spec(config.model.name) + # pyrefly: ignore [not-callable] tokenizer = train_spec.build_tokenizer_fn(config) # Build test set of randomly generated token ids @@ -150,10 +155,11 @@ def forward_tt(config_path, checkpoint_path, test_set): avg_losses = {} for test_name, (baseline_outputs, conversion_outputs) in test_configs.items(): - total_loss = 0 + total_loss: int | torch.Tensor = 0 for baseline, outputs in zip(baseline_outputs, conversion_outputs): total_loss += loss_fn(baseline, outputs) avg_loss = total_loss / len(test_set) + # pyrefly: ignore [missing-attribute] avg_losses[test_name] = avg_loss.item() for test_name, avg_loss in avg_losses.items(): diff --git a/scripts/download_hf_assets.py b/scripts/download_hf_assets.py index e1092b2d70..dbe8ba98b6 100644 --- a/scripts/download_hf_assets.py +++ b/scripts/download_hf_assets.py @@ -167,6 +167,7 @@ def should_download(patterns: list[str], filename: str) -> bool: missed_files = [] # Download files with progress bar + # pyrefly: ignore [bad-context-manager] with tqdm(total=len(files_found), desc="Downloading files", unit="file") as pbar: for filename in files_found: try: diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index e0a752d545..bfa9dddfd2 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -98,44 +98,58 @@ def estimate_memory(job_config: JobConfig): # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) + # pyrefly: ignore [bad-argument-type] model_converters.convert(model) # apply PT-D DP/TP parallelisms and activation checkpointing train_spec.parallelize_fn(model, parallel_dims, job_config) + # pyrefly: ignore [missing-attribute] model.to_empty(device="cuda") if not active_fake_mode(): model.init_weights() + # pyrefly: ignore [missing-attribute] model.train() # build optimizer after applying parallelisms to the model + # pyrefly: ignore [bad-argument-type] optimizers = build_optimizers([model], job_config.optimizer, parallel_dims) lr_schedulers = build_lr_schedulers( - optimizers.optimizers, job_config.lr_scheduler, job_config.training.steps + # pyrefly: ignore [bad-argument-type] + optimizers.optimizers, + job_config.lr_scheduler, + job_config.training.steps, ) # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 # where it issues a single all-reduce for all parameters at once for better performance optimizers.register_step_post_hook( + # pyrefly: ignore [bad-argument-type] lambda *args, **kwargs: model_converters.post_optimizer_hook(model) ) + # pyrefly: ignore [missing-attribute] logger.info(f"Vocab size: {model_args.vocab_size}") # Create a dummy batch instead of loading from a dataset batch = ( torch.randint( 0, + # pyrefly: ignore [missing-attribute] model_args.vocab_size, + # pyrefly: ignore [missing-attribute] (job_config.training.local_batch_size, model_args.max_seq_len), device="cuda", ), torch.randint( 0, + # pyrefly: ignore [missing-attribute] model_args.vocab_size, + # pyrefly: ignore [missing-attribute] (job_config.training.local_batch_size, model_args.max_seq_len), device="cuda", ), ) + # pyrefly: ignore [bad-argument-type] fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0]) fsdp_memtracker.track_inputs(batch) @@ -145,6 +159,7 @@ def estimate_memory(job_config: JobConfig): input_ids, labels = batch # train step with train_context(): + # pyrefly: ignore [not-callable] pred = model(input_ids) loss = loss_fn(pred, labels) del pred @@ -152,7 +167,10 @@ def estimate_memory(job_config: JobConfig): # clip gradients torch.nn.utils.clip_grad_norm_( - model.parameters(), job_config.training.max_norm, foreach=True + # pyrefly: ignore [missing-attribute] + model.parameters(), + job_config.training.max_norm, + foreach=True, ) # optimizer step optimizers.step() diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index b1d19ad17f..bff9c2aa7f 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -36,6 +36,7 @@ wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) +# pyrefly: ignore [missing-import] from generate._generation import generate @@ -49,6 +50,7 @@ def apply_tp_minus_sp(model: nn.Module, tp_mesh: DeviceMesh): }, ) + # pyrefly: ignore [missing-attribute] for _, transformer_block in model.layers.items(): layer_plan = { "attention.wq": ColwiseParallel(), @@ -63,6 +65,7 @@ def apply_tp_minus_sp(model: nn.Module, tp_mesh: DeviceMesh): parallelize_module( module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) @@ -95,6 +98,7 @@ def test_generate( world_size = int(os.environ.get("WORLD_SIZE", 1)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) device = torch.device(f"{device_type}:{local_rank}") + # pyrefly: ignore [missing-attribute] device_module.set_device(device) device_memory_monitor = build_device_memory_monitor() @@ -103,6 +107,7 @@ def test_generate( logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}") # Tokenizer setup + # pyrefly: ignore [not-callable] tokenizer = train_spec.build_tokenizer_fn(config) model_args = train_spec.model_args[config.model.flavor] @@ -131,6 +136,7 @@ def test_generate( # apply_tp (with Sequence Parallel) on unevenly sharded # sequences would require https://github.com/pytorch/torchtitan/pull/686 + # pyrefly: ignore [bad-argument-type] apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"]) debug_config = DebugConfig(seed=seed, deterministic=deterministic) @@ -142,11 +148,14 @@ def test_generate( ) # materalize model + # pyrefly: ignore [missing-attribute] model.to_empty(device=device_type) with torch.no_grad(): model.init_weights() + # pyrefly: ignore [missing-attribute] model.eval() + # pyrefly: ignore [missing-attribute] state_dict = model.state_dict() # Checkpoint Loading diff --git a/scripts/loss_compare.py b/scripts/loss_compare.py index 3479875036..e9761458a8 100644 --- a/scripts/loss_compare.py +++ b/scripts/loss_compare.py @@ -134,6 +134,7 @@ def run_with_realtime_output(cmd: str, logfile: str, env: dict[str, Any]) -> Non bufsize=1, ) + # pyrefly: ignore [not-iterable] for line in process.stdout: print(line, end="") log_f.write(line) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 79918d0046..7928f514ba 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -29,7 +29,10 @@ set_model_state_dict, StateDictOptions, ) -from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType +from torch.distributed.checkpoint.state_dict_saver import ( + AsyncCheckpointerType, + AsyncSaveResponse, +) from torch.distributed.checkpoint.stateful import Stateful from torchtitan.components.dataloader import BaseDataLoader @@ -174,6 +177,9 @@ class CheckpointManager: """ + mp_queue_send: queue.Queue + purge_thread: threading.Thread | None + def __init__( self, dataloader: BaseDataLoader | None, @@ -208,12 +214,14 @@ def __init__( ) if self.ft_manager and not self.enable_ft_dataloader_checkpoints: + # pyrefly: ignore [deprecated] logger.warn( "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. " "This means replicas can retrain over the same data multiple times, which can result in overfitting." ) if self.ft_manager: + # pyrefly: ignore [missing-attribute] optimizers.init_cache_state_dict() def state_dict(): @@ -233,7 +241,9 @@ def load_state_dict(state_dict): for k, v in state_dict.items(): self.states[k].load_state_dict(v) + # pyrefly: ignore [missing-attribute] self.ft_manager.set_state_dict_fns(load_state_dict, state_dict) + # pyrefly: ignore [missing-attribute] self.ft_replica_id = ft_manager.replica_id async_mode = checkpoint_config.async_mode.lower() @@ -344,7 +354,7 @@ def dcp_save( async_mode: AsyncMode, enable_garbage_collection: bool = False, to_hf: bool = False, - ) -> Future | None: + ) -> Future | AsyncSaveResponse | None: """Save the checkpoint with dcp. Args: state_dict (dict): The state dict to save. @@ -357,7 +367,7 @@ def dcp_save( Future: The future object if the checkpoint is async, otherwise None. """ - ret: Future | None = None + ret: Future | AsyncSaveResponse | None = None storage_writer: HuggingFaceStorageWriter | None = None checkpoint_save_id: str | None = None @@ -394,6 +404,7 @@ def dcp_save( state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_save_id, + # pyrefly: ignore [bad-argument-type] process_group=self.pg, ) elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: @@ -401,6 +412,7 @@ def dcp_save( state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_save_id, + # pyrefly: ignore [bad-argument-type] process_group=self.pg, async_checkpointer_type=AsyncCheckpointerType.PROCESS, async_stager=self.stager, @@ -412,10 +424,12 @@ def dcp_save( checkpoint_id=checkpoint_save_id, ) + # pyrefly: ignore [missing-attribute] if to_hf and self.sd_adapter.fqn_to_index_mapping: consolidate_safetensors_files_on_every_rank( input_dir=os.path.join(checkpoint_id, "sharded"), output_dir=checkpoint_id, + # pyrefly: ignore [bad-argument-type] fqn_to_index_mapping=self.sd_adapter.fqn_to_index_mapping, num_threads=5, ) @@ -489,7 +503,9 @@ def save(self, curr_step: int, last_step: bool = False) -> None: begin = time.monotonic() if not self.enable_ft_dataloader_checkpoints or ( - self.ft_manager and self.ft_manager.participating_rank() == 0 + self.ft_manager + # pyrefly: ignore [missing-attribute] + and self.ft_manager.participating_rank() == 0 ): logger.info("Saving the checkpoint (or staging if async is enabled).") checkpoint_id = self._create_checkpoint_id(curr_step) @@ -511,7 +527,9 @@ def save(self, curr_step: int, last_step: bool = False) -> None: checkpoint_id=checkpoint_id, async_mode=self.async_mode, ) + # pyrefly: ignore [missing-attribute] self.save_future = result.upload_completion + # pyrefly: ignore [missing-attribute] self.staging_future = result.staging_completion self.staging = True elif self.async_mode == AsyncMode.ASYNC: @@ -537,6 +555,7 @@ def save(self, curr_step: int, last_step: bool = False) -> None: assert self.ft_manager is not None logger.info( "Replica %d doesn't save checkpoint.", + # pyrefly: ignore [missing-attribute] self.ft_manager.participating_rank(), ) @@ -589,6 +608,7 @@ def load(self, step: int = -1) -> bool: f"loading from HF safetensors from --checkpoint.initial_load_path: {self.initial_load_path}" ) elif from_hf: + # pyrefly: ignore [missing-attribute] checkpoint_id = self.sd_adapter.hf_assets_path if not os.path.isdir(checkpoint_id): raise ValueError( @@ -596,6 +616,7 @@ def load(self, step: int = -1) -> bool: Either make sure hf_assets_path is correct or provide a valid checkpoint.initial_load_path" ) logger.info( + # pyrefly: ignore [missing-attribute] f"loading HF safetensors from --model.hf_assets_path: {self.sd_adapter.hf_assets_path}" ) else: @@ -644,6 +665,7 @@ def maybe_wait_for_staging(self) -> None: with ``async_checkpoint_with_pinned_memory``. """ if self.enable_staging and self.staging: + # pyrefly: ignore [missing-attribute] self.staging_future.result() self.staging = False @@ -828,6 +850,7 @@ def _purge_stale_checkpoints(self): and os.path.isdir(self.folder) and ( not self.enable_ft_dataloader_checkpoints + # pyrefly: ignore [missing-attribute] or (self.ft_manager and self.ft_manager.participating_rank() == 0) ) ): diff --git a/torchtitan/components/dataloader.py b/torchtitan/components/dataloader.py index 071af84d54..7a1c1fcad6 100644 --- a/torchtitan/components/dataloader.py +++ b/torchtitan/components/dataloader.py @@ -41,6 +41,7 @@ def __iter__(self): ... +# pyrefly: ignore [inconsistent-inheritance] class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader): """Dataloader that is aware of distributed data parallelism. @@ -58,7 +59,7 @@ class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader): dp_rank: int dp_world_size: int - batch_size: int + batch_size: int | None def __init__( self, diff --git a/torchtitan/components/ft/manager.py b/torchtitan/components/ft/manager.py index 5d64d34b09..d95470c47d 100644 --- a/torchtitan/components/ft/manager.py +++ b/torchtitan/components/ft/manager.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import importlib +import importlib.util from contextlib import nullcontext from datetime import timedelta from typing import Callable, ContextManager, Optional, TYPE_CHECKING, Union @@ -165,4 +165,5 @@ def maybe_semi_sync_training( raise ValueError( f"Unknown training method: {semi_sync_method}, only 'diloco' and 'local_sgd' are supported." ) + # pyrefly: ignore [no-matching-overload] return nullcontext() diff --git a/torchtitan/components/lr_scheduler.py b/torchtitan/components/lr_scheduler.py index 6384feb641..15a3fc6bd1 100644 --- a/torchtitan/components/lr_scheduler.py +++ b/torchtitan/components/lr_scheduler.py @@ -176,6 +176,8 @@ def linear_warmup_stable_decay( curr_adjustment = 1 - math.sqrt(progress) elif lr_decay_type == "cosine": curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress)) + else: + raise ValueError(f"Unknown lr_decay_type: {lr_decay_type}") curr_adjustment = min_lr_factor + (1 - min_lr_factor) * curr_adjustment return curr_adjustment diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 6905fb5b53..6f50337473 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -40,15 +40,21 @@ class DeviceMemoryMonitor: def __init__(self, device: str = f"{device_type}:0"): + # pyrefly: ignore [read-only] self.device = torch.device(device) # device object + # pyrefly: ignore [missing-attribute] self.device_name = device_module.get_device_name(self.device) + # pyrefly: ignore [missing-attribute] self.device_index = device_module.current_device() + # pyrefly: ignore [missing-attribute] self.device_capacity = device_module.get_device_properties( self.device ).total_memory self.device_capacity_gib = self._to_gib(self.device_capacity) + # pyrefly: ignore [missing-attribute] device_module.reset_peak_memory_stats() + # pyrefly: ignore [missing-attribute] device_module.empty_cache() def _to_gib(self, memory_in_bytes): @@ -61,6 +67,7 @@ def _to_pct(self, memory): return 100 * memory / self.device_capacity def get_peak_stats(self): + # pyrefly: ignore [missing-attribute] device_info = device_module.memory_stats(self.device) max_active = device_info.get("active_bytes.all.peak", -1) @@ -91,6 +98,7 @@ def get_peak_stats(self): ) def reset_peak_stats(self): + # pyrefly: ignore [missing-attribute] device_module.reset_peak_memory_stats() @@ -341,7 +349,7 @@ class MetricsProcessor: device_memory_monitor: DeviceMemoryMonitor color: utils.NoColor | utils.Color - gpu_peak_flops: int + gpu_peak_flops: float ntokens_since_last_log: int data_loading_times: list[float] time_last_log: float diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 80557366da..2b08142f97 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -8,6 +8,7 @@ from typing import Any, Generic, Iterator, TypeVar import torch +import torch.distributed.tensor import torch.nn as nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointImpl from torch.distributed.checkpoint.state_dict import ( @@ -88,6 +89,7 @@ def __iter__(self) -> Iterator[T]: def __len__(self) -> int: return len(self.optimizers) + # pyrefly: ignore [bad-override] def step(self, *args, **kwargs) -> None: for optimizer in self.optimizers: optimizer.step(*args, **kwargs) @@ -170,9 +172,11 @@ def optim_hook(param) -> None: ) self._post_init(all_params, optimizer_kwargs) + # pyrefly: ignore [bad-override] def step(self) -> None: pass + # pyrefly: ignore [bad-override] def zero_grad(self) -> None: pass @@ -343,9 +347,12 @@ def build_optimizers_with_moe_load_balancing( def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool: for model_part in model_parts: + # pyrefly: ignore [not-callable] for transformer_block in model_part.layers.values(): + # pyrefly: ignore [missing-attribute] if transformer_block.moe_enabled: # Assumption: load_balance_coeff is set universally on all moe blocks. + # pyrefly: ignore [missing-attribute] return bool(transformer_block.moe.load_balance_coeff) return False @@ -364,11 +371,15 @@ def _update_expert_bias( # default compute stream. Need to assess if this is OK performance-wise. tokens_per_expert_list = [] for model_part in model_parts: + # pyrefly: ignore [not-callable] for transformer_block in model_part.layers.values(): + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: continue + # pyrefly: ignore [missing-attribute] if transformer_block.moe.load_balance_coeff is None: return + # pyrefly: ignore [missing-attribute] tokens_per_expert = transformer_block.moe.tokens_per_expert if _is_recomputation_enabled(transformer_block): # TODO: This is a hack, we assume with full AC, the tokens_per_expert is counted twice. @@ -398,9 +409,12 @@ def _update_expert_bias( moe_layer_idx = 0 with torch.no_grad(): for model_part in model_parts: + # pyrefly: ignore [not-callable] for transformer_block in model_part.layers.values(): + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: continue + # pyrefly: ignore [missing-attribute] moe = transformer_block.moe tokens_per_expert = tokens_per_expert_by_layer[ diff --git a/torchtitan/components/quantization/__init__.py b/torchtitan/components/quantization/__init__.py index de94c37b3e..49faf60733 100644 --- a/torchtitan/components/quantization/__init__.py +++ b/torchtitan/components/quantization/__init__.py @@ -42,7 +42,7 @@ def _validate(job_config: JobConfig): # quantization converter format: # `quantize.[linear | grouped_mm].[float8 | mx]` quantization_type = lambda converter: converter.split(".")[-1] - existing_quantization_converter = None + existing_quantization_converter: str | None = None for converter in job_config.model.converters: if "quantize" in converter: if existing_quantization_converter is None: diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 86932a17bd..9b575876e7 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -6,6 +6,7 @@ from functools import partial import torch +import torch._inductor.config import torch.nn as nn from torchtitan.components.quantization import ( FP8_GROUP_ALIGNMENT_SIZE, diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index a474cc3918..f1c0e09574 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -57,14 +57,19 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): MXLinearConfig as TorchAOMXLinearConfig, ) + # pyrefly: ignore [bad-assignment] mx_job_config: TorchAOMXLinearConfig = job_config.quantize.linear.mx + # pyrefly: ignore [missing-attribute] config = TorchAOMXLinearConfig.from_recipe_name(mx_job_config.recipe_name) + # pyrefly: ignore [missing-attribute] config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[ mx_job_config.mxfp8_dim1_cast_kernel_choice.upper() ] + # pyrefly: ignore [missing-attribute] self.filter_fqns = mx_job_config.filter_fqns self.config = config self.enabled = True + # pyrefly: ignore [missing-attribute] logger.info(f"MX training active with recipe {mx_job_config.recipe_name}") def convert(self, model: nn.Module): diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index 022fcbc266..aca2300abe 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -56,6 +56,7 @@ def __init__( # Initialize BOS/EOS token attributes (frequently used) self.bos_id = None + # pyrefly: ignore [bad-assignment] self.eos_id = None self.bos_token = None self.eos_token = None @@ -144,10 +145,13 @@ def _load_tokenizer_from_path(self, tokenizer_path: str) -> Tokenizer: tokenizer = Tokenizer(bpe_model) # Configure GPT-2 style components for proper space handling + # pyrefly: ignore [read-only] tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel( add_prefix_space=False ) + # pyrefly: ignore [read-only] tokenizer.decoder = decoders.ByteLevel() + # pyrefly: ignore [read-only] tokenizer.post_processor = processors.ByteLevel(trim_offsets=True) return tokenizer diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 93fb68a3cc..4673807347 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -4,7 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Generator +from collections.abc import Callable +from contextlib import AbstractContextManager +from typing import TypeAlias import torch import torch.nn as nn @@ -19,6 +21,11 @@ from torchtitan.tools import utils from torchtitan.tools.logging import logger +ValidationContext: TypeAlias = Callable[ + [AbstractContextManager[None] | None], + AbstractContextManager[None], +] + class BaseValidator: def __init__(self, job_config: JobConfig): @@ -52,8 +59,8 @@ def __init__( tokenizer: BaseTokenizer, parallel_dims: ParallelDims, loss_fn: LossFunction, - validation_context: Generator[None, None, None], - maybe_enable_amp: Generator[None, None, None], + validation_context: ValidationContext, + maybe_enable_amp: AbstractContextManager[None], metrics_processor: MetricsProcessor, pp_schedule: _PipelineSchedule | None = None, pp_has_first_stage: bool | None = None, @@ -83,6 +90,7 @@ def __init__( ) @torch.no_grad() + # pyrefly: ignore [bad-override] def validate( self, model_parts: list[nn.Module], @@ -98,6 +106,7 @@ def validate( device_type = utils.device_type num_steps = 0 + # pyrefly: ignore [not-iterable] for input_dict, labels in self.validation_dataloader: if ( self.job_config.validation.steps != -1 @@ -186,8 +195,8 @@ def build_validator( tokenizer: BaseTokenizer, parallel_dims: ParallelDims, loss_fn: LossFunction, - validation_context: Generator[None, None, None], - maybe_enable_amp: Generator[None, None, None], + validation_context: ValidationContext, + maybe_enable_amp: AbstractContextManager[None], metrics_processor: MetricsProcessor | None = None, pp_schedule: _PipelineSchedule | None = None, pp_has_first_stage: bool | None = None, @@ -203,6 +212,7 @@ def build_validator( loss_fn=loss_fn, validation_context=validation_context, maybe_enable_amp=maybe_enable_amp, + # pyrefly: ignore [bad-argument-type] metrics_processor=metrics_processor, pp_schedule=pp_schedule, pp_has_first_stage=pp_has_first_stage, diff --git a/torchtitan/config/manager.py b/torchtitan/config/manager.py index 10f4440a4c..79d95c350e 100644 --- a/torchtitan/config/manager.py +++ b/torchtitan/config/manager.py @@ -16,6 +16,7 @@ try: import tomllib except ModuleNotFoundError: + # pyrefly: ignore [missing-import] import tomli as tomllib from torchtitan.tools.logging import logger @@ -253,7 +254,10 @@ def list_str_rule(type_info: tyro.constructors.PrimitiveTypeInfo): # # ----------------------------------------------------------------------------- + # pyrefly: ignore [missing-import] from rich import print as rprint + + # pyrefly: ignore [missing-import] from rich.pretty import Pretty config_manager = ConfigManager() diff --git a/torchtitan/distributed/__init__.py b/torchtitan/distributed/__init__.py index 63690a660b..f335916595 100644 --- a/torchtitan/distributed/__init__.py +++ b/torchtitan/distributed/__init__.py @@ -65,7 +65,10 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: device_mesh, None, partial( - self._prepare_input_fn, self.input_layout, self.desired_input_layout + # pyrefly: ignore [bad-argument-type] + self._prepare_input_fn, + self.input_layout, + self.desired_input_layout, ), partial(self._prepare_output_fn, self.output_layout, self.use_local_output), ) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 0eecde9052..c0b550a5c1 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -11,6 +11,7 @@ from collections import defaultdict import torch +import torch._functorch.config import torch.nn as nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as ptd_checkpoint_wrapper, @@ -221,6 +222,7 @@ def apply_ac( torch._functorch.config.activation_memory_budget = ac_config.memory_budget logger.info(f"Selected {ac_config.memory_budget} budget option") else: + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.named_children(): transformer_block = _apply_ac_to_transformer_block( transformer_block, @@ -229,6 +231,7 @@ def apply_ac( model_compile_enabled=model_compile_enabled, op_sac_save_list=op_sac_save_list, ) + # pyrefly: ignore [missing-attribute] model.layers.register_module(layer_id, transformer_block) logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") diff --git a/torchtitan/distributed/dual_pipe_v.py b/torchtitan/distributed/dual_pipe_v.py index 5a4a5d9dd0..5def0e40e6 100644 --- a/torchtitan/distributed/dual_pipe_v.py +++ b/torchtitan/distributed/dual_pipe_v.py @@ -91,7 +91,9 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: module, device_mesh, partition_fn=self._partition_fn, + # pyrefly: ignore [bad-argument-type] input_fn=self._token_dispatch, + # pyrefly: ignore [bad-argument-type] output_fn=self._token_combine, ) @@ -145,6 +147,7 @@ def is_coordination_enabled(self): class SyncHook(torch.autograd.Function): @staticmethod + # pyrefly: ignore [bad-override] def forward(ctx, x, hook_name=""): ctx.hook_name = hook_name # handle edge case for transformer level boundary @@ -158,6 +161,7 @@ def forward(ctx, x, hook_name=""): return x @staticmethod + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): hook_name = ctx.hook_name @@ -260,19 +264,24 @@ def overlap_callback(action: _Action, ctx: _PipelineContext): # Shared container for exception from backward thread def run_backward(): + # pyrefly: ignore [missing-attribute] schedule._assert_unsharded(backward_stage) # Set the backward thread to use the same stream as forward + # pyrefly: ignore [missing-attribute] device_module.set_stream(main_stream) with record_function( f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}" ): loss = schedule._maybe_get_loss(backward_stage, backward_mb_index) + # pyrefly: ignore [missing-attribute] schedule.backward_counter[backward_stage_index] += 1 last_backward = ( + # pyrefly: ignore [missing-attribute] schedule.backward_counter[backward_stage_index] == schedule._n_microbatches ) backward_stage.backward_one_chunk( + # pyrefly: ignore [bad-argument-type] backward_mb_index, loss=loss, full_backward=True, @@ -282,14 +291,19 @@ def run_backward(): if backward_is_prev_stage_on_this_rank: stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input( backward_stage.get_local_bwd_output(backward_mb_index), + # pyrefly: ignore [bad-argument-type] backward_mb_index, ) def run_forward(): + # pyrefly: ignore [missing-attribute] schedule._assert_unsharded(forward_stage) output = forward_stage.forward_one_chunk( + # pyrefly: ignore [bad-argument-type] forward_mb_index, + # pyrefly: ignore [bad-index, unsupported-operation] arg_mbs[forward_mb_index], + # pyrefly: ignore [bad-index, unsupported-operation] kwarg_mbs[forward_mb_index], ) schedule._maybe_compute_loss( @@ -297,7 +311,9 @@ def run_forward(): ) if forward_is_next_stage_on_this_rank: stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input( - output, forward_mb_index + output, + # pyrefly: ignore [bad-argument-type] + forward_mb_index, ) # Run forward and backward in parallel diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index 932a7e4aa1..60de27b276 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -81,6 +81,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: module, device_mesh, self._partition_fn, + # pyrefly: ignore [bad-argument-type] self._prepare_input_fn, ) @@ -184,7 +185,9 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: module, device_mesh, partition_fn=self._partition_fn, + # pyrefly: ignore [bad-argument-type] input_fn=self._token_dispatch, + # pyrefly: ignore [bad-argument-type] output_fn=self._token_combine, ) @@ -210,18 +213,21 @@ def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> N # w1 shape = (experts, out_dim, in_dim) mod.register_parameter( "w1", + # pyrefly: ignore [bad-argument-type] nn.Parameter(distribute_tensor(mod.w1, device_mesh, [Shard(0), Shard(1)])), ) # Column-wise sharding # w2 shape = (experts, in_dim, out_dim) mod.register_parameter( "w2", + # pyrefly: ignore [bad-argument-type] nn.Parameter(distribute_tensor(mod.w2, device_mesh, [Shard(0), Shard(2)])), ) # Row-wise sharding # w3 shape = (experts, out_dim, in_dim) mod.register_parameter( "w3", + # pyrefly: ignore [bad-argument-type] nn.Parameter(distribute_tensor(mod.w3, device_mesh, [Shard(0), Shard(1)])), ) # Column-wise sharding @@ -234,7 +240,9 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: module, device_mesh, partition_fn=self._partition_fn, + # pyrefly: ignore [bad-argument-type] input_fn=self._token_dispatch, + # pyrefly: ignore [bad-argument-type] output_fn=self._token_combine, ) @@ -296,6 +304,8 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: module, device_mesh, partition_fn=None, + # pyrefly: ignore [bad-argument-type] input_fn=self._prepare_inputput_fn, + # pyrefly: ignore [bad-argument-type] output_fn=self._prepare_output_fn, ) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 44822039a6..187a363097 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -26,7 +26,7 @@ class ParallelDims: etp: int world_size: int - _world_mesh: DeviceMesh = None + _world_mesh: DeviceMesh | None = None def __post_init__(self): self._validate() @@ -105,7 +105,7 @@ def _build_mesh_with_ep(self) -> DeviceMesh: names.append(name) logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + mesh = init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names)) # Create all the submesh here to ensure all required process groups are # initialized: @@ -156,7 +156,7 @@ def _build_mesh_without_ep(self) -> DeviceMesh: names.append(name) logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + mesh = init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names)) # Create all the submesh here to ensure all required process groups are # initialized: diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py index bafefddbec..bef597be24 100644 --- a/torchtitan/distributed/pipeline_parallel.py +++ b/torchtitan/distributed/pipeline_parallel.py @@ -200,7 +200,9 @@ def build_pipeline_schedule( f"of stages ({num_total_stages}) which may result in a bubble in the pipeline." ) + # pyrefly: ignore [bad-instantiation] schedule = schedule_class( + # pyrefly: ignore [bad-argument-type] stages if looped_schedule else stages[0], n_microbatches=n_microbatches, loss_fn=rescale_accumulated_loss(loss_fn, n_microbatches), @@ -225,6 +227,7 @@ def build_pipeline_schedule( "Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), " "and _PipelineScheduleRuntime support csv schedules" ) + # pyrefly: ignore [missing-attribute] schedule._load_csv(pp_schedule_csv) return schedule @@ -445,7 +448,7 @@ def _build_stage_from_modules( "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop" ) - def _get_stage_indices() -> tuple[int]: + def _get_stage_indices() -> tuple[int, ...]: """ Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule @@ -464,6 +467,8 @@ def _get_stage_indices() -> tuple[int]: zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1)) ) return stage_v_pairs[pp_rank] + else: + raise ValueError(f"Unknown style {style}") for stage_idx in _get_stage_indices(): module_names = module_names_per_stage[stage_idx] diff --git a/torchtitan/distributed/tensor_parallel.py b/torchtitan/distributed/tensor_parallel.py index 04e4e36c3a..59fffc86a2 100644 --- a/torchtitan/distributed/tensor_parallel.py +++ b/torchtitan/distributed/tensor_parallel.py @@ -6,6 +6,7 @@ import torch +import torch._inductor.config from torch.distributed.device_mesh import DeviceMesh from torchtitan.config import JobConfig diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 6a73ffd083..811e062958 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -7,12 +7,16 @@ import contextlib import math import os -from collections.abc import Generator, Iterable +from abc import abstractmethod +from collections.abc import Iterable from datetime import timedelta +from typing import Protocol import torch import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d +import torch.distributed.tensor._random +import torch.distributed.tensor.parallel from torch import distributed as dist from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor @@ -174,11 +178,15 @@ def set_determinism( # Filter out all distinct dimensions to get duplicate_seed_mesh duplicate_seed_mesh_dims = [ name + # pyrefly: ignore [not-iterable] for name in world_mesh.mesh_dim_names if name not in distinct_dims_in_mesh ] duplicate_seed_mesh = ( - world_mesh[duplicate_seed_mesh_dims] if duplicate_seed_mesh_dims else None + # pyrefly: ignore [bad-index] + world_mesh[duplicate_seed_mesh_dims] + if duplicate_seed_mesh_dims + else None ) else: duplicate_seed_mesh = world_mesh @@ -192,6 +200,7 @@ def set_determinism( # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh. # IF PP is also used, this seed is unique per PP rank. if duplicate_seed_mesh and duplicate_seed_mesh.get_coordinate() is not None: + # pyrefly: ignore [bad-argument-type] torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_mesh) @@ -205,11 +214,11 @@ def create_context_parallel_ctx( try: from torch.distributed.tensor.experimental import context_parallel from torch.distributed.tensor.experimental._attention import set_rotate_method - except ImportError: - print( + except ImportError as e: + raise ValueError( f"PyTorch version {torch.__version__} does not include the experimental " "Context Parallel API. Please update to a newer version." - ) + ) from e set_rotate_method(cp_rotate_method) return context_parallel( @@ -220,9 +229,18 @@ def create_context_parallel_ctx( ) -def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]: +class TrainContext(Protocol): + @abstractmethod + def __call__( + self, + cp_context: contextlib.AbstractContextManager[None] | None = None, + ) -> contextlib.AbstractContextManager[None]: + pass + + +def get_train_context(enable_loss_parallel: bool) -> TrainContext: @contextlib.contextmanager - def context(cp_context: Generator[None, None, None] | None = None): + def context(cp_context: contextlib.AbstractContextManager[None] | None = None): with contextlib.ExitStack() as stack: if enable_loss_parallel: stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) @@ -236,8 +254,8 @@ def context(cp_context: Generator[None, None, None] | None = None): def maybe_enable_amp( - parallel_dims: ParallelDims, mixed_precision_param: str, device_type: torch.device -) -> Generator[None, None, None]: + parallel_dims: ParallelDims, mixed_precision_param: str, device_type: str +) -> contextlib.AbstractContextManager[None]: if parallel_dims.fsdp_enabled: # FSDP handles mixed precision internally logger.info("Mixed precision training is handled by fully_shard") @@ -252,6 +270,7 @@ def maybe_enable_amp( else: # the following code will only be executed for DDP or single-device training logger.info("Mixed precision training is handled by AMP") + # pyrefly: ignore [bad-return] return torch.autocast( device_type, dtype=TORCH_DTYPE_MAP[mixed_precision_param], @@ -367,7 +386,9 @@ def set_pg_timeouts(timeout, world_mesh): # otherwise, some ranks may issue collectives with the new/shorter timeout and # those may time out, before other ranks have finished with initialization done # under the old/slow timeout. + # pyrefly: ignore [missing-attribute] torch.distributed.barrier(device_ids=[device_module.current_device()]) + # pyrefly: ignore [missing-attribute] device_module.synchronize() groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)] @@ -477,6 +498,7 @@ def _clip_grad_norm_with_ep( if p.grad is None: continue assert isinstance(p, DTensor) and isinstance(p.grad, DTensor) + # pyrefly: ignore [not-iterable] if "ep" in p.device_mesh.mesh_dim_names: ep_params.append(p) ep_grads.append(p.grad) @@ -491,6 +513,7 @@ def _clip_grad_norm_with_ep( if isinstance(ep_grads_total_norm, DTensor): ep_grads_total_norm = ep_grads_total_norm.full_tensor() + # pyrefly: ignore [missing-attribute] non_ep_grads_total_norm = torch.nn.utils.get_total_norm( non_ep_grads, norm_type, error_if_nonfinite, foreach ).full_tensor() diff --git a/torchtitan/hf_datasets/text_datasets.py b/torchtitan/hf_datasets/text_datasets.py index 493cd1abb4..63790b8862 100644 --- a/torchtitan/hf_datasets/text_datasets.py +++ b/torchtitan/hf_datasets/text_datasets.py @@ -153,7 +153,7 @@ def load_state_dict(self, state_dict): self._data.load_state_dict(state_dict["data"]) def state_dict(self): - _state_dict = {"token_buffer": self._token_buffer} + _state_dict: dict[str, Any] = {"token_buffer": self._token_buffer} if isinstance(self._data, Dataset): _state_dict["sample_idx"] = self._sample_idx diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 819dbd57bc..b04a6a136e 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -20,6 +20,7 @@ ) from torch.nn.attention.varlen import varlen_attn +from torch.types import Number __all__ = [ @@ -43,8 +44,8 @@ class VarlenMetadata(NamedTuple): cu_seq_q: torch.Tensor cu_seq_k: torch.Tensor - max_q: int - max_k: int + max_q: Number + max_k: Number class VarlenAttentionWrapper(torch.nn.Module): @@ -66,8 +67,11 @@ def forward( max_k = attention_masks.max_k n_local_heads = xq.shape[1] + # pyrefly: ignore [no-matching-overload] xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + # pyrefly: ignore [no-matching-overload] xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + # pyrefly: ignore [no-matching-overload] xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim) return VarlenAttentionWrapper._compiled_varlen_attn( @@ -146,7 +150,7 @@ class ScaledDotProductAttentionWrapper(torch.nn.Module): """ # TODO: remove sdpa_backends after PyTorch 2.9 is released. - sdpa_backends: ClassVar[list[SDPBackend]] = [] + sdpa_backends: list[SDPBackend] = [] def __init__(self) -> None: super().__init__() diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index c068e60a30..63fb910376 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -119,6 +119,7 @@ def parallelize_deepseekv3( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) @@ -225,6 +226,7 @@ def apply_non_moe_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), @@ -246,6 +248,7 @@ def apply_non_moe_tp( "ffn_norm": SequenceParallel(), } + # pyrefly: ignore [missing-attribute] if transformer_block.attention.q_lora_rank == 0: layer_plan.update( { @@ -263,6 +266,7 @@ def apply_non_moe_tp( } ) + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: layer_plan.update( { @@ -277,8 +281,10 @@ def apply_non_moe_tp( ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index e683905878..64a9d2bb81 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -101,16 +101,17 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args.use_grouped_mm = False - if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + if ( + job_config.parallelism.context_parallel_degree > 1 + and self.attn_type != "sdpa" + ): raise NotImplementedError("CP support is only supported for SDPA.") self.moe_args._debug_force_load_balance = ( job_config.debug.moe_force_load_balance ) - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: return get_moe_model_nparams_and_flops( self, model, diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 5b17ad0acf..26e0cff2f3 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -240,6 +240,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): case "flex": self.inner_attention = FlexAttentionWrapper() case _: + # pyrefly: ignore [bad-assignment] self.inner_attention = ScaledDotProductAttentionWrapper() def forward( @@ -433,6 +434,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights(buffer_device=buffer_device) if self.norm is not None: self.norm.reset_parameters() diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index fd4ec30284..7fd6743600 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -106,6 +106,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in state_dict.items(): if "moe.experts" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_abstract_key = to_hf_map[abstract_key] @@ -128,15 +129,19 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: else: # keep this path for offline conversion split_values = self._split_experts_weights( - value, self.model_args.moe_args.num_experts + value, + # pyrefly: ignore [missing-attribute] + self.model_args.moe_args.num_experts, ) + # pyrefly: ignore [missing-attribute] for expert_num in range(0, self.model_args.moe_args.num_experts): new_key = new_abstract_key.format(layer_num, expert_num) hf_state_dict[new_key] = split_values[expert_num].squeeze() elif "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = to_hf_map[abstract_key] new_key = new_key.format(layer_num) @@ -186,6 +191,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: expert_weights_by_layer, titan_abstract_key, layer_num, + # pyrefly: ignore [missing-attribute] self.model_args.moe_args.num_experts, ) @@ -194,6 +200,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: elif "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = self.from_hf_map[abstract_key] new_key = new_key.format(layer_num) diff --git a/torchtitan/models/flux/__init__.py b/torchtitan/models/flux/__init__.py index 0fee76e60d..d5ec94b1d6 100644 --- a/torchtitan/models/flux/__init__.py +++ b/torchtitan/models/flux/__init__.py @@ -20,6 +20,7 @@ __all__ = [ "FluxModelArgs", "FluxModel", + # pyrefly: ignore [missing-module-attribute] "flux_configs", "parallelize_flux", ] diff --git a/torchtitan/models/flux/flux_datasets.py b/torchtitan/models/flux/flux_datasets.py index f3cf283aa6..906b669001 100644 --- a/torchtitan/models/flux/flux_datasets.py +++ b/torchtitan/models/flux/flux_datasets.py @@ -9,7 +9,7 @@ from typing import Any, Callable, Optional import numpy as np -import PIL +import PIL.Image import torch from datasets import Dataset, load_dataset @@ -271,6 +271,7 @@ def __iter__(self): # skip low quality image or image with color channel = 1 if sample_dict["image"] is None: + # pyrefly: ignore [missing-attribute] sample = sample.get("__key__", "unknown") logger.warning( f"Low quality image {sample} is skipped in Flux Dataloader." @@ -279,6 +280,7 @@ def __iter__(self): # Classifier-free guidance: Replace some of the strings with empty strings. # Distinct random seed is initialized at the beginning of training for each FSDP rank. + # pyrefly: ignore [missing-attribute] dropout_prob = self.job_config.training.classifier_free_guidance_prob if dropout_prob > 0.0: if torch.rand(1).item() < dropout_prob: diff --git a/torchtitan/models/flux/inference/infer.py b/torchtitan/models/flux/inference/infer.py index b89887ad51..bffdb2a2e7 100644 --- a/torchtitan/models/flux/inference/infer.py +++ b/torchtitan/models/flux/inference/infer.py @@ -25,6 +25,7 @@ def inference(config: JobConfig): # Distributed processing setup: Each GPU/process handles a subset of prompts world_size = int(os.environ["WORLD_SIZE"]) global_rank = int(os.environ["RANK"]) + # pyrefly: ignore [missing-attribute] original_prompts = open(config.inference.prompts_path).readlines() total_prompts = len(original_prompts) @@ -45,10 +46,12 @@ def inference(config: JobConfig): if prompts: # Generate images for this process's assigned prompts + # pyrefly: ignore [missing-attribute] bs = config.inference.local_batch_size output_dir = os.path.join( config.job.dump_folder, + # pyrefly: ignore [missing-attribute] config.inference.save_img_folder, ) # Create mapping from local indices to global prompt indices @@ -59,6 +62,7 @@ def inference(config: JobConfig): device=trainer.device, dtype=trainer._dtype, job_config=trainer.job_config, + # pyrefly: ignore [bad-argument-type] model=trainer.model_parts[0], prompt=prompts[i : i + bs], autoencoder=trainer.autoencoder, diff --git a/torchtitan/models/flux/inference/sampling.py b/torchtitan/models/flux/inference/sampling.py index f43d0fc2c5..5ee48ab60f 100644 --- a/torchtitan/models/flux/inference/sampling.py +++ b/torchtitan/models/flux/inference/sampling.py @@ -93,10 +93,13 @@ def generate_image( prompt = [prompt] # allow for packing and conversion to latent space. Use the same resolution as training time. + # pyrefly: ignore [missing-attribute] img_height = 16 * (job_config.training.img_size // 16) + # pyrefly: ignore [missing-attribute] img_width = 16 * (job_config.training.img_size // 16) enable_classifier_free_guidance = ( + # pyrefly: ignore [missing-attribute] job_config.validation.enable_classifier_free_guidance ) @@ -104,7 +107,9 @@ def generate_image( clip_tokens = clip_tokenizer.encode(prompt) t5_tokens = t5_tokenizer.encode(prompt) if len(prompt) == 1: + # pyrefly: ignore [missing-attribute] clip_tokens = clip_tokens.unsqueeze(0) + # pyrefly: ignore [missing-attribute] t5_tokens = t5_tokens.unsqueeze(0) batch = preprocess_data( @@ -113,6 +118,7 @@ def generate_image( autoencoder=None, clip_encoder=clip_encoder, t5_encoder=t5_encoder, + # pyrefly: ignore [bad-argument-type] batch={ "clip_tokens": clip_tokens, "t5_tokens": t5_tokens, @@ -124,7 +130,9 @@ def generate_image( empty_clip_tokens = clip_tokenizer.encode("") empty_t5_tokens = t5_tokenizer.encode("") + # pyrefly: ignore [missing-attribute] empty_clip_tokens = empty_clip_tokens.repeat(num_images, 1) + # pyrefly: ignore [missing-attribute] empty_t5_tokens = empty_t5_tokens.repeat(num_images, 1) empty_batch = preprocess_data( @@ -145,16 +153,24 @@ def generate_image( model=model, img_width=img_width, img_height=img_height, + # pyrefly: ignore [missing-attribute] denoising_steps=job_config.validation.denoising_steps, clip_encodings=batch["clip_encodings"], t5_encodings=batch["t5_encodings"], enable_classifier_free_guidance=enable_classifier_free_guidance, empty_t5_encodings=( - empty_batch["t5_encodings"] if enable_classifier_free_guidance else None + # pyrefly: ignore [unbound-name] + empty_batch["t5_encodings"] + if enable_classifier_free_guidance + else None ), empty_clip_encodings=( - empty_batch["clip_encodings"] if enable_classifier_free_guidance else None + # pyrefly: ignore [unbound-name] + empty_batch["clip_encodings"] + if enable_classifier_free_guidance + else None ), + # pyrefly: ignore [missing-attribute] classifier_free_guidance_scale=job_config.validation.classifier_free_guidance_scale, ) @@ -190,7 +206,9 @@ def denoise( if enable_classifier_free_guidance: # Double batch size for CFG: [unconditional, conditional] latents = torch.cat([latents, latents], dim=0) + # pyrefly: ignore [no-matching-overload] t5_encodings = torch.cat([empty_t5_encodings, t5_encodings], dim=0) + # pyrefly: ignore [no-matching-overload] clip_encodings = torch.cat([empty_clip_encodings, clip_encodings], dim=0) bsz *= 2 diff --git a/torchtitan/models/flux/infra/parallelize.py b/torchtitan/models/flux/infra/parallelize.py index fc9c926af0..b27fa93a31 100644 --- a/torchtitan/models/flux/infra/parallelize.py +++ b/torchtitan/models/flux/infra/parallelize.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Any import torch import torch.nn as nn @@ -77,7 +78,7 @@ def apply_fsdp( cpu_offload (bool): Whether to offload model parameters to CPU. Defaults to False. """ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + fsdp_config: dict[str, Any] = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() @@ -88,21 +89,27 @@ def apply_fsdp( model.txt_in, ] for layer in linear_layers: + # pyrefly: ignore [no-matching-overload] fully_shard(layer, **fsdp_config) + # pyrefly: ignore [not-iterable] for block in model.double_blocks: + # pyrefly: ignore [no-matching-overload] fully_shard( block, **fsdp_config, ) + # pyrefly: ignore [not-iterable] for block in model.single_blocks: + # pyrefly: ignore [no-matching-overload] fully_shard( block, **fsdp_config, ) # apply FSDP to last layer. Set reshard_after_forward=False for last layer to avoid gather right after reshard + # pyrefly: ignore [no-matching-overload] fully_shard(model.final_layer, **fsdp_config, reshard_after_forward=False) # Wrap all the rest of model @@ -112,12 +119,16 @@ def apply_fsdp( def apply_ac(model: nn.Module, ac_config): """Apply activation checkpointing to the model.""" + # pyrefly: ignore [missing-attribute] for layer_id, block in model.double_blocks.named_children(): block = ptd_checkpoint_wrapper(block, preserve_rng_state=False) + # pyrefly: ignore [missing-attribute] model.double_blocks.register_module(layer_id, block) + # pyrefly: ignore [missing-attribute] for layer_id, block in model.single_blocks.named_children(): block = ptd_checkpoint_wrapper(block, preserve_rng_state=False) + # pyrefly: ignore [missing-attribute] model.single_blocks.register_module(layer_id, block) logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") @@ -139,7 +150,7 @@ def parallelize_encoders( param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) - fsdp_config = { + fsdp_config: dict[str, Any] = { "mesh": parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], "mp_policy": mp_policy, } @@ -148,8 +159,10 @@ def parallelize_encoders( # NOTE: only apply FSDP to the T5 encoder, not the CLIP text encoder. # CLIP Text encoder has low computation / communication ratio, so it's not necessary to apply FSDP to it. + # pyrefly: ignore [missing-attribute] for block in t5_model.hf_module.encoder.block: fully_shard(block, **fsdp_config) + # pyrefly: ignore [no-matching-overload] fully_shard(t5_model.hf_module, **fsdp_config) if parallel_dims.dp_replicate_enabled: diff --git a/torchtitan/models/flux/model/autoencoder.py b/torchtitan/models/flux/model/autoencoder.py index dc6fb1d061..9ca46dff96 100644 --- a/torchtitan/models/flux/model/autoencoder.py +++ b/torchtitan/models/flux/model/autoencoder.py @@ -19,7 +19,7 @@ class AutoEncoderParams: in_channels: int = 3 ch: int = 128 out_ch: int = 3 - ch_mult: tuple[int] = (1, 2, 4, 4) + ch_mult: tuple[int, ...] = (1, 2, 4, 4) num_res_blocks: int = 2 z_channels: int = 16 scale_factor: float = 0.3611 @@ -191,17 +191,24 @@ def forward(self, x: Tensor) -> Tensor: hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): + # pyrefly: ignore [bad-index, not-callable] h = self.down[i_level].block[i_block](hs[-1]) + # pyrefly: ignore [bad-argument-type] if len(self.down[i_level].attn) > 0: + # pyrefly: ignore [bad-index, not-callable] h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: + # pyrefly: ignore [not-callable] hs.append(self.down[i_level].downsample(hs[-1])) # middle h = hs[-1] + # pyrefly: ignore [not-callable] h = self.mid.block_1(h) + # pyrefly: ignore [not-callable] h = self.mid.attn_1(h) + # pyrefly: ignore [not-callable] h = self.mid.block_2(h) # end h = self.norm_out(h) @@ -276,8 +283,11 @@ def forward(self, z: Tensor) -> Tensor: h = self.conv_in(z) # middle + # pyrefly: ignore [not-callable] h = self.mid.block_1(h) + # pyrefly: ignore [not-callable] h = self.mid.attn_1(h) + # pyrefly: ignore [not-callable] h = self.mid.block_2(h) # cast to proper dtype @@ -285,10 +295,14 @@ def forward(self, z: Tensor) -> Tensor: # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): + # pyrefly: ignore [bad-index, not-callable] h = self.up[i_level].block[i_block](h) + # pyrefly: ignore [bad-argument-type] if len(self.up[i_level].attn) > 0: + # pyrefly: ignore [bad-index, not-callable] h = self.up[i_level].attn[i_block](h) if i_level != 0: + # pyrefly: ignore [not-callable] h = self.up[i_level].upsample(h) # end @@ -321,6 +335,7 @@ def __init__(self, params: AutoEncoderParams): resolution=params.resolution, in_channels=params.in_channels, ch=params.ch, + # pyrefly: ignore [bad-argument-type] ch_mult=params.ch_mult, num_res_blocks=params.num_res_blocks, z_channels=params.z_channels, @@ -330,6 +345,7 @@ def __init__(self, params: AutoEncoderParams): in_channels=params.in_channels, ch=params.ch, out_ch=params.out_ch, + # pyrefly: ignore [bad-argument-type] ch_mult=params.ch_mult, num_res_blocks=params.num_res_blocks, z_channels=params.z_channels, diff --git a/torchtitan/models/flux/model/hf_embedder.py b/torchtitan/models/flux/model/hf_embedder.py index 90be8767a9..89bed4d248 100644 --- a/torchtitan/models/flux/model/hf_embedder.py +++ b/torchtitan/models/flux/model/hf_embedder.py @@ -19,6 +19,7 @@ def __init__(self, version: str, random_init=False, **hf_kwargs): if random_init: # Initialize CLIP model with random weights for test purpose only self.hf_module = CLIPTextModel._from_config( + # pyrefly: ignore [missing-attribute] CLIPTextModel.config_class.from_pretrained( os.path.join(version, "config.json"), **hf_kwargs ) @@ -31,6 +32,7 @@ def __init__(self, version: str, random_init=False, **hf_kwargs): if random_init: # Initialize T5 model with random weights for test purpose only self.hf_module = T5EncoderModel._from_config( + # pyrefly: ignore [missing-attribute] T5EncoderModel.config_class.from_pretrained( os.path.join(version, "config.json"), **hf_kwargs ) diff --git a/torchtitan/models/flux/model/layers.py b/torchtitan/models/flux/model/layers.py index 923c5a422c..30ba52d3a3 100644 --- a/torchtitan/models/flux/model/layers.py +++ b/torchtitan/models/flux/model/layers.py @@ -6,6 +6,7 @@ # imported from black-forest-labs/FLUX import math +from collections.abc import Sequence from dataclasses import dataclass import torch @@ -34,7 +35,7 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso class EmbedND(nn.Module): - def __init__(self, dim: int, theta: int, axes_dim: list[int]): + def __init__(self, dim: int, theta: int, axes_dim: Sequence[int]): super().__init__() self.dim = dim self.theta = theta @@ -213,7 +214,9 @@ def init_weights(self): self.txt_mlp[0], self.txt_mlp[2], ): + # pyrefly: ignore [bad-argument-type] nn.init.xavier_uniform_(layer.weight) + # pyrefly: ignore [bad-argument-type] nn.init.constant_(layer.bias, 0) # initialize Modulation layers, SelfAttention layers @@ -346,7 +349,9 @@ def __init__(self, hidden_size: int, patch_size: int, out_channels: int): ) def init_weights(self): + # pyrefly: ignore [bad-argument-type] nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + # pyrefly: ignore [bad-argument-type] nn.init.constant_(self.adaLN_modulation[-1].bias, 0) nn.init.constant_(self.linear.weight, 0) nn.init.constant_(self.linear.bias, 0) diff --git a/torchtitan/models/flux/model/model.py b/torchtitan/models/flux/model/model.py index 6cfb02c9c0..d0f5592871 100644 --- a/torchtitan/models/flux/model/model.py +++ b/torchtitan/models/flux/model/model.py @@ -51,7 +51,9 @@ def __init__(self, model_args: FluxModelArgs): self.hidden_size = model_args.hidden_size self.num_heads = model_args.num_heads self.pe_embedder = EmbedND( - dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim + dim=pe_dim, + theta=model_args.theta, + axes_dim=model_args.axes_dim, ) self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) @@ -95,8 +97,10 @@ def init_weights(self, buffer_device=None): # Initialize transformer blocks: for block in self.single_blocks: + # pyrefly: ignore [not-callable] block.init_weights() for block in self.double_blocks: + # pyrefly: ignore [not-callable] block.init_weights() # Zero-out output layers: diff --git a/torchtitan/models/flux/model/state_dict_adapter.py b/torchtitan/models/flux/model/state_dict_adapter.py index c976df6919..2526bcd521 100644 --- a/torchtitan/models/flux/model/state_dict_adapter.py +++ b/torchtitan/models/flux/model/state_dict_adapter.py @@ -58,6 +58,7 @@ def __init__(self, model_args: FluxModelArgs, hf_assets_path: str | None): if hf_safetensors_indx: self.fqn_to_index_mapping = {} for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items(): + # pyrefly: ignore [missing-attribute] indx = re.search(r"\d+", raw_indx).group(0) self.fqn_to_index_mapping[hf_key] = indx else: @@ -173,6 +174,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in state_dict.items(): # Extract layer_num and abstract key if necessary if "blocks" in key: + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) key = re.sub(r"(\d+)", "{}", key, count=1) else: @@ -242,6 +244,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in hf_state_dict.items(): # extract layer_num and abstract key if necessary if "blocks" in key: + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) key = re.sub(r"(\d+)", "{}", key, count=1) else: @@ -273,6 +276,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: # combine collected values for tt_fqn, hf_fqn_map in to_combine.items(): + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", tt_fqn).group(0) tt_abstract_key = re.sub(r"(\d+)", "{}", tt_fqn, count=1) combine_values = [] diff --git a/torchtitan/models/flux/tokenizer.py b/torchtitan/models/flux/tokenizer.py index b5cca546b9..06fbde2bbb 100644 --- a/torchtitan/models/flux/tokenizer.py +++ b/torchtitan/models/flux/tokenizer.py @@ -46,6 +46,7 @@ def _pad_and_chunk_tokens( def get_vocab_size(self) -> int: return self.tiktokenizer.vocab_size + # pyrefly: ignore [bad-override] def encode(self, text: str | list[str]) -> torch.Tensor: """ Use TikTokenizer to encode the text into tokens, and then pad and chunk the tokens to max_length. @@ -72,6 +73,7 @@ def encode(self, text: str | list[str]) -> torch.Tensor: tokens = self._pad_and_chunk_tokens(tokens, self._max_length, self.pad_id) return torch.tensor(tokens) + # pyrefly: ignore [bad-override] def decode(self, t: List[int]) -> str: """ Decode function. This function will not be called. @@ -96,10 +98,12 @@ def __init__(self, model_path: str = "t5-small", max_length: int = 77, **hf_kwar self.is_clip = "clip" in model_path.lower() if self.is_clip: + # pyrefly: ignore [bad-assignment] self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained( model_path, max_length=max_length, **hf_kwargs ) else: + # pyrefly: ignore [bad-assignment] self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained( model_path, max_length=max_length, **hf_kwargs ) @@ -107,6 +111,7 @@ def __init__(self, model_path: str = "t5-small", max_length: int = 77, **hf_kwar def get_vocab_size(self) -> int: return self._tokenizer.vocab_size + # pyrefly: ignore [bad-override] def encode( self, s: str | list[str], @@ -125,6 +130,7 @@ def encode( )["input_ids"] return tokens + # pyrefly: ignore [bad-override] def decode(self, t: List[int]) -> str: """ Decode function. This function will not be called. @@ -136,11 +142,15 @@ def build_flux_tokenizer(job_config: JobConfig) -> tuple[BaseTokenizer, BaseToke """ Build the tokenizer for Flux. """ + # pyrefly: ignore [missing-attribute] t5_tokenizer_path = job_config.encoder.t5_encoder + # pyrefly: ignore [missing-attribute] clip_tokenzier_path = job_config.encoder.clip_encoder + # pyrefly: ignore [missing-attribute] max_t5_encoding_len = job_config.encoder.max_t5_encoding_len # NOTE: This tokenizer is used for offline CI and testing only, borrowed from llama3 tokenizer + # pyrefly: ignore [missing-attribute] if job_config.training.test_mode: tokenizer_class = FluxTestTokenizer t5_tokenizer_path = clip_tokenzier_path = job_config.model.hf_assets_path diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py index 5af9959050..3e008fba59 100644 --- a/torchtitan/models/flux/train.py +++ b/torchtitan/models/flux/train.py @@ -48,23 +48,31 @@ def __init__(self, job_config: JobConfig): model_args = self.train_spec.model_args[job_config.model.flavor] self.autoencoder = load_ae( + # pyrefly: ignore [missing-attribute] job_config.encoder.autoencoder_path, + # pyrefly: ignore [missing-attribute] model_args.autoencoder_params, device=self.device, dtype=self._dtype, + # pyrefly: ignore [missing-attribute] random_init=job_config.training.test_mode, ) self.clip_encoder = FluxEmbedder( + # pyrefly: ignore [missing-attribute] version=job_config.encoder.clip_encoder, + # pyrefly: ignore [missing-attribute] random_init=job_config.training.test_mode, ).to(device=self.device, dtype=self._dtype) self.t5_encoder = FluxEmbedder( + # pyrefly: ignore [missing-attribute] version=job_config.encoder.t5_encoder, + # pyrefly: ignore [missing-attribute] random_init=job_config.training.test_mode, ).to(device=self.device, dtype=self._dtype) # Apply FSDP to the T5 model / CLIP model + # pyrefly: ignore [bad-assignment] self.t5_encoder, self.clip_encoder = parallelize_encoders( t5_model=self.t5_encoder, clip_model=self.clip_encoder, @@ -73,6 +81,7 @@ def __init__(self, job_config: JobConfig): ) if job_config.validation.enable: + # pyrefly: ignore [missing-attribute] self.validator.flux_init( device=self.device, _dtype=self._dtype, @@ -164,6 +173,7 @@ def forward_backward_step( loss = self.loss_fn(latent_noise_pred, target) # latent_noise_pred.shape=(bs, seq_len, vocab_size) # need to free to before bwd to avoid peaking memory + # pyrefly: ignore [unsupported-delete] del (latent_noise_pred, noise, target) loss.backward() diff --git a/torchtitan/models/flux/validate.py b/torchtitan/models/flux/validate.py index 189385e0f2..32fa7b9f55 100644 --- a/torchtitan/models/flux/validate.py +++ b/torchtitan/models/flux/validate.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import os -from typing import Generator +from contextlib import AbstractContextManager import torch import torch.nn as nn @@ -15,7 +15,7 @@ from torchtitan.components.loss import LossFunction from torchtitan.components.metrics import MetricsProcessor from torchtitan.components.tokenizer import BaseTokenizer -from torchtitan.components.validate import Validator +from torchtitan.components.validate import ValidationContext, Validator from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.models.flux.flux_datasets import build_flux_validation_dataloader @@ -53,8 +53,8 @@ def __init__( tokenizer: BaseTokenizer, parallel_dims: ParallelDims, loss_fn: LossFunction, - validation_context: Generator[None, None, None], - maybe_enable_amp: Generator[None, None, None], + validation_context: ValidationContext, + maybe_enable_amp: AbstractContextManager[None], metrics_processor: MetricsProcessor | None = None, pp_schedule: _PipelineSchedule | None = None, pp_has_first_stage: bool | None = None, @@ -63,6 +63,7 @@ def __init__( self.job_config = job_config self.parallel_dims = parallel_dims self.loss_fn = loss_fn + # pyrefly: ignore [missing-attribute] self.all_timesteps = self.job_config.validation.all_timesteps self.validation_dataloader = build_flux_validation_dataloader( job_config=job_config, @@ -74,6 +75,7 @@ def __init__( ) self.validation_context = validation_context self.maybe_enable_amp = maybe_enable_amp + # pyrefly: ignore [bad-assignment] self.metrics_processor = metrics_processor self.t5_tokenizer, self.clip_tokenizer = build_flux_tokenizer(self.job_config) @@ -91,6 +93,7 @@ def flux_init( t5_encoder: FluxEmbedder, clip_encoder: FluxEmbedder, ): + # pyrefly: ignore [read-only] self.device = device self._dtype = _dtype self.autoencoder = autoencoder @@ -109,9 +112,12 @@ def validate( model.eval() # Disable cfg dropout during validation + # pyrefly: ignore [missing-attribute] training_cfg_prob = self.job_config.training.classifier_free_guidance_prob + # pyrefly: ignore [missing-attribute] self.job_config.training.classifier_free_guidance_prob = 0.0 + # pyrefly: ignore [missing-attribute] save_img_count = self.job_config.validation.save_img_count parallel_dims = self.parallel_dims @@ -120,6 +126,7 @@ def validate( device_type = dist_utils.device_type num_steps = 0 + # pyrefly: ignore [not-iterable] for input_dict, labels in self.validation_dataloader: if ( self.job_config.validation.steps != -1 @@ -137,6 +144,7 @@ def validate( device=self.device, dtype=self._dtype, job_config=self.job_config, + # pyrefly: ignore [bad-argument-type] model=model, prompt=p, autoencoder=self.autoencoder, @@ -150,6 +158,7 @@ def validate( name=f"image_rank{str(torch.distributed.get_rank())}_{step}.png", output_dir=os.path.join( self.job_config.job.dump_folder, + # pyrefly: ignore [missing-attribute] self.job_config.validation.save_img_folder, ), x=image, @@ -270,6 +279,7 @@ def validate( model.train() # re-enable cfg dropout for training + # pyrefly: ignore [missing-attribute] self.job_config.training.classifier_free_guidance_prob = training_cfg_prob @@ -280,8 +290,8 @@ def build_flux_validator( tokenizer: BaseTokenizer, parallel_dims: ParallelDims, loss_fn: LossFunction, - validation_context: Generator[None, None, None], - maybe_enable_amp: Generator[None, None, None], + validation_context: ValidationContext, + maybe_enable_amp: AbstractContextManager[None], metrics_processor: MetricsProcessor | None = None, pp_schedule: _PipelineSchedule | None = None, pp_has_first_stage: bool | None = None, diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 13a968be96..63bbc19ff6 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -101,6 +101,7 @@ def parallelize_llama( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) @@ -202,6 +203,7 @@ def apply_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), @@ -226,8 +228,10 @@ def apply_tp( } parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) @@ -242,10 +246,12 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig): Apply torch.compile to each TransformerBlock, which makes compilation efficient due to repeated structure. Alternatively one can compile the whole model (after applying DP). """ + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.named_children(): transformer_block = torch.compile( transformer_block, backend=compile_config.backend, fullgraph=True ) + # pyrefly: ignore [missing-attribute] model.layers.register_module(layer_id, transformer_block) logger.info("Compiling each TransformerBlock with torch.compile") @@ -280,6 +286,7 @@ def apply_fsdp( mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: + # pyrefly: ignore [bad-typed-dict-key] fsdp_config["offload_policy"] = CPUOffloadPolicy() match reshard_after_forward_policy: @@ -297,11 +304,13 @@ def apply_fsdp( ) if model.tok_embeddings is not None: + # pyrefly: ignore [no-matching-overload] fully_shard( model.tok_embeddings, **fsdp_config, reshard_after_forward=reshard_after_forward, ) + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.items(): fully_shard( transformer_block, @@ -311,6 +320,7 @@ def apply_fsdp( # As an optimization, do not reshard_after_forward the last layers by default # since FSDP would prefetch them immediately after the forward pass if model.norm is not None and model.output is not None: + # pyrefly: ignore [no-matching-overload] fully_shard( [model.norm, model.output], **fsdp_config, @@ -327,6 +337,7 @@ def apply_ddp( if enable_compile: torch._dynamo.config.optimize_ddp = "ddp_optimizer" + # pyrefly: ignore [invalid-param-spec] replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) logger.info("Applied DDP to the model") diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index 81680074eb..79e97dab4c 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -62,9 +62,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: "CP support for FlexAttention is still in progress." ) - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: return get_dense_model_nparams_and_flops( self, model, diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 8982fcca9f..cafd58a52e 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -223,8 +223,10 @@ def __init__(self, model_args: TransformerModelArgs): case "flex": self.inner_attention = FlexAttentionWrapper() case "varlen": + # pyrefly: ignore [bad-assignment] self.inner_attention = VarlenAttentionWrapper() case _: + # pyrefly: ignore [bad-assignment] self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): @@ -474,6 +476,7 @@ def init_weights( nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights() if self.norm is not None: self.norm.reset_parameters() @@ -569,12 +572,15 @@ def forward( """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + # pyrefly: ignore [not-callable] h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): h = layer( h, self.freqs_cis, attention_masks=attention_masks, positions=positions ) + # pyrefly: ignore [not-callable] h = self.norm(h) if self.norm else h + # pyrefly: ignore [not-callable] output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py index 2c386ece0d..f951edd75a 100644 --- a/torchtitan/models/llama3/model/state_dict_adapter.py +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -81,6 +81,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in state_dict.items(): if "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = to_hf_map[abstract_key] # We need to permute the weights in wq and wk layer in order to account for the difference between @@ -115,6 +116,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in hf_state_dict.items(): if "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = self.from_hf_map[abstract_key] @@ -132,5 +134,6 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: else: new_key = self.from_hf_map[key] + # pyrefly: ignore [unsupported-operation] state_dict[new_key] = value return state_dict diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index b8b3470d37..112153390f 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Any + import torch import torch.nn as nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -130,6 +132,7 @@ def parallelize_llama( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) @@ -245,6 +248,7 @@ def apply_non_moe_tp( ) # Apply tensor + sequence parallelism to every transformer block + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), @@ -260,6 +264,7 @@ def apply_non_moe_tp( "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), } + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: layer_plan.update( { @@ -274,8 +279,10 @@ def apply_non_moe_tp( ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) @@ -315,7 +322,7 @@ def apply_fsdp( """ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + fsdp_config: dict[str, Any] = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() @@ -334,12 +341,14 @@ def apply_fsdp( ) if model.tok_embeddings is not None: + # pyrefly: ignore [no-matching-overload] fully_shard( model.tok_embeddings, **fsdp_config, reshard_after_forward=reshard_after_forward, ) + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.items(): # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping # - the router and the shared experts are sharded together with the TransformerBlock @@ -386,6 +395,7 @@ def apply_fsdp( # As an optimization, do not reshard_after_forward the last layers by default # since FSDP would prefetch them immediately after the forward pass if model.norm is not None and model.output is not None: + # pyrefly: ignore [no-matching-overload] fully_shard( [model.norm, model.output], **fsdp_config, @@ -400,49 +410,65 @@ def apply_fsdp( return # forward + # pyrefly: ignore [not-callable] transformer_blocks = list(model.layers.values()) next_transformer_blocks = transformer_blocks[1:] + [None] + # pyrefly: ignore [bad-argument-type] if model.tok_embeddings is not None and len(model.layers) > 0: + # pyrefly: ignore [missing-attribute] model.tok_embeddings.set_modules_to_forward_prefetch([transformer_blocks[0]]) for transformer_block, next_transformer_block in zip( transformer_blocks, next_transformer_blocks ): if next_transformer_block is not None: + # pyrefly: ignore [missing-attribute] if next_transformer_block.moe_enabled: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_forward_prefetch( + # pyrefly: ignore [missing-attribute] [next_transformer_block, next_transformer_block.moe.experts] ) else: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_forward_prefetch( [next_transformer_block] ) elif model.norm is not None and model.output is not None: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_forward_prefetch( [model.norm, model.output] ) # backward + # pyrefly: ignore [not-callable] reversed_transformer_blocks = list(reversed(model.layers.values())) prev_transformer_blocks = reversed_transformer_blocks[1:] + [None] + # pyrefly: ignore [bad-argument-type] if model.norm is not None and model.output is not None and len(model.layers) > 0: + # pyrefly: ignore [missing-attribute] model.output.set_modules_to_backward_prefetch([reversed_transformer_blocks[0]]) for transformer_block, prev_transformer_block in zip( reversed_transformer_blocks, prev_transformer_blocks ): if prev_transformer_block is not None: + # pyrefly: ignore [missing-attribute] if prev_transformer_block.moe_enabled: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_backward_prefetch( + # pyrefly: ignore [missing-attribute] [prev_transformer_block, prev_transformer_block.moe.experts] ) else: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_backward_prefetch( [prev_transformer_block] ) elif model.tok_embeddings is not None: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_backward_prefetch([model.tok_embeddings]) @@ -456,7 +482,9 @@ def apply_moe_ep_tp( ): assert ep_mesh is not None or tp_mesh is not None + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: continue @@ -478,9 +506,12 @@ def apply_moe_ep_tp( # If TP is borrowed for EP, then split the tokens across TP ranks so that # the reorderer, the all-to-all comms, and routed experts computation # are effectively running Sequence Parallel (split along the folded bs*slen dim) + # pyrefly: ignore [no-matching-overload] moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()}) + # pyrefly: ignore [missing-attribute] if transformer_block.moe.shared_experts is not None: # input Replicate, output Partial + # pyrefly: ignore [no-matching-overload] moe_layer_plan.update( { "moe.shared_experts.w1": ColwiseParallel(), @@ -491,8 +522,10 @@ def apply_moe_ep_tp( } ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=moe_layer_plan, ) @@ -513,6 +546,7 @@ def apply_moe_ep_tp( experts_plan = DualPipeExpertParallel(experts_plan) parallelize_module( + # pyrefly: ignore [missing-attribute] module=transformer_block.moe.experts, device_mesh=experts_mesh, parallelize_plan=experts_plan, @@ -528,7 +562,9 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b # but it is experimental. torch._dynamo.config.capture_scalar_outputs = True # Workaround for https://github.com/pytorch/pytorch/issues/166926 + # pyrefly: ignore [missing-attribute] torch._C._dynamo.eval_frame._set_lru_cache(False) + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.named_children(): if transformer_block.moe_enabled: # If it is a MoE layer, FSDP(GroupedExperts) will cause a graph break @@ -582,6 +618,7 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b fullgraph=True, ) + # pyrefly: ignore [missing-attribute] model.layers.register_module(layer_id, transformer_block) # Patch some globals only once (apply_compile is called multiple times for PP setup) diff --git a/torchtitan/models/llama4/model/args.py b/torchtitan/models/llama4/model/args.py index a277ca382e..3520e7e519 100644 --- a/torchtitan/models/llama4/model/args.py +++ b/torchtitan/models/llama4/model/args.py @@ -86,9 +86,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: job_config.debug.moe_force_load_balance ) - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: return get_moe_model_nparams_and_flops( self, model, diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index 7c4f073e19..e08f733f28 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -231,6 +231,7 @@ def __init__( case "flex": self.inner_attention = FlexAttentionWrapper() case _: + # pyrefly: ignore [bad-assignment] self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): @@ -513,6 +514,7 @@ def init_weights( nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights(buffer_device=buffer_device) if self.norm is not None: self.norm.reset_parameters() @@ -590,11 +592,14 @@ def forward( """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + # pyrefly: ignore [not-callable] h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): h = layer(h, self.freqs_cis, attention_masks, positions) + # pyrefly: ignore [not-callable] h = self.norm(h) if self.norm else h + # pyrefly: ignore [not-callable] output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/llama4/model/state_dict_adapter.py b/torchtitan/models/llama4/model/state_dict_adapter.py index 182981c665..c272b2ac10 100644 --- a/torchtitan/models/llama4/model/state_dict_adapter.py +++ b/torchtitan/models/llama4/model/state_dict_adapter.py @@ -52,6 +52,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: to_combine = defaultdict(dict) for key, value in state_dict.items(): if "layers" in key: + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) key = re.sub(r"(\d+)", "{}", key, count=1) else: @@ -77,6 +78,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: hf_abstract_key = ( "language_model.model.layers.{}.feed_forward.experts.gate_up_proj" ) + # pyrefly: ignore [unnecessary-comparison] if hf_abstract_key is None: continue to_combine[hf_abstract_key.format(layer_num)][ @@ -85,6 +87,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: # combine collected values for hf_fqn, tt_fqn_map in to_combine.items(): + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", hf_fqn).group(0) combine_values = [] # put into correct order to combine @@ -106,6 +109,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in hf_state_dict.items(): if "layers" in key: + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) key = re.sub(r"(\d+)", "{}", key, count=1) else: diff --git a/torchtitan/models/moe/kernels.py b/torchtitan/models/moe/kernels.py index 7aac7b3ac4..a1b1d17771 100644 --- a/torchtitan/models/moe/kernels.py +++ b/torchtitan/models/moe/kernels.py @@ -92,8 +92,11 @@ def fill_indices_wrapper( start_index_values, write_offsets, permuted_indices, + # pyrefly: ignore [bad-argument-type] experts_per_rank, + # pyrefly: ignore [bad-argument-type] num_ranks, + # pyrefly: ignore [bad-argument-type] BLOCK_SIZE=block_size, ) return permuted_indices diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 741c908eab..da58c68b03 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -77,20 +77,20 @@ def _run_experts_for_loop( num_tokens_per_expert: torch.Tensor, ) -> torch.Tensor: # NOTE: this would incur a synchronization between device and host - num_tokens_per_expert = num_tokens_per_expert.tolist() + num_tokens_per_expert_list = num_tokens_per_expert.tolist() # side-effect code due to the usage of generate_permute_indices - num_padding = x.shape[0] - sum(num_tokens_per_expert) + num_padding = x.shape[0] - sum(num_tokens_per_expert_list) # a tuple of tensors indexed by experts # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x[: sum(num_tokens_per_expert)], - split_size_or_sections=num_tokens_per_expert, + x_splits = torch.split( + x[: sum(num_tokens_per_expert_list)], + split_size_or_sections=num_tokens_per_expert_list, dim=0, ) out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): + for expert_idx, x_expert in enumerate(x_splits): h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1))) h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1)) h = torch.matmul(h, w2[expert_idx].transpose(-2, -1)) @@ -148,7 +148,9 @@ def forward( # Convert parameters from DTensors to plain Tensors, to work with # dynamic-shape inputs in EP which cannot be easily expressed as DTensors. w1 = self.w1.to_local() + # pyrefly: ignore [missing-attribute] w2 = self.w2.to_local() + # pyrefly: ignore [missing-attribute] w3 = self.w3.to_local() else: w1 = self.w1 @@ -161,6 +163,7 @@ def forward( # otherwise, EP will handle the padding. if ( not isinstance(self.w1, DTensor) + # pyrefly: ignore [not-iterable] or "ep" not in self.w1.device_mesh.mesh_dim_names ): run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm) diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 5f9f0a73be..c2eaed8de6 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -8,6 +8,7 @@ # training techniques (e.g. activation checkpointing and compile) to the Llama model. import torch +import torch._inductor.config import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh @@ -121,6 +122,7 @@ def parallelize_qwen3( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) @@ -181,7 +183,9 @@ def parallelize_qwen3( ) # Enable weight tying after applying parallelisms + # pyrefly: ignore [missing-attribute] if model.model_args.enable_weight_tying: + # pyrefly: ignore [missing-attribute] model.output.weight = model.tok_embeddings.weight return model @@ -242,6 +246,7 @@ def apply_non_moe_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), @@ -260,6 +265,7 @@ def apply_non_moe_tp( "ffn_norm": SequenceParallel(), } + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: layer_plan.update( { @@ -274,8 +280,10 @@ def apply_non_moe_tp( ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) diff --git a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py index 2def3a949a..d0a0556bf1 100644 --- a/torchtitan/models/qwen3/model/args.py +++ b/torchtitan/models/qwen3/model/args.py @@ -59,7 +59,5 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: job_config.debug.moe_force_load_balance ) - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: return get_moe_model_nparams_and_flops(self, model, 2 * self.head_dim, seq_len) diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 62b5d0c381..0683b4c42d 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -160,6 +160,9 @@ class Attention(nn.Module): """ + q_norm: nn.RMSNorm | None + k_norm: nn.RMSNorm | None + def __init__(self, model_args: Qwen3ModelArgs): super().__init__() self.n_heads = model_args.n_heads @@ -199,8 +202,10 @@ def __init__(self, model_args: Qwen3ModelArgs): case "flex": self.inner_attention = FlexAttentionWrapper() case "varlen": + # pyrefly: ignore [bad-assignment] self.inner_attention = VarlenAttentionWrapper() case "sdpa": + # pyrefly: ignore [bad-assignment] self.inner_attention = ScaledDotProductAttentionWrapper() case _: raise ValueError(f"Unknown attention type: {self.attn_type}") @@ -476,6 +481,7 @@ def init_weights( nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights(buffer_device) if self.norm is not None: self.norm.reset_parameters() @@ -567,11 +573,14 @@ def forward( """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + # pyrefly: ignore [not-callable] h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): h = layer(h, self.rope_cache, attention_masks, positions) + # pyrefly: ignore [not-callable] h = self.norm(h) if self.norm else h + # pyrefly: ignore [not-callable] output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/qwen3/model/state_dict_adapter.py b/torchtitan/models/qwen3/model/state_dict_adapter.py index 11bb8058c0..8dfe4d5aa7 100644 --- a/torchtitan/models/qwen3/model/state_dict_adapter.py +++ b/torchtitan/models/qwen3/model/state_dict_adapter.py @@ -63,6 +63,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) if abstract_key not in to_hf_map: continue + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_abstract_key = to_hf_map[abstract_key] @@ -85,9 +86,12 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: else: # keep this path for offline conversion split_values = self._split_experts_weights( - value, self.model_args.moe_args.num_experts + value, + # pyrefly: ignore [missing-attribute] + self.model_args.moe_args.num_experts, ) + # pyrefly: ignore [missing-attribute] for expert_num in range(self.model_args.moe_args.num_experts): new_key = new_abstract_key.format(layer_num, expert_num) hf_state_dict[new_key] = split_values[expert_num].squeeze() @@ -96,6 +100,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) if abstract_key not in to_hf_map: continue + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = to_hf_map[abstract_key] new_key = new_key.format(layer_num) @@ -104,6 +109,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: else: if key not in to_hf_map: continue + # pyrefly: ignore [missing-attribute] if self.model_args.enable_weight_tying and key == "output.weight": continue new_key = to_hf_map[key] @@ -121,6 +127,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} if ( + # pyrefly: ignore [missing-attribute] self.model_args.enable_weight_tying and "lm_head.weight" not in hf_state_dict ): @@ -132,6 +139,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key = re.sub(r"(\d+)", "{}", key, count=2) layer_num, expert_num = re.findall(r"\d+", key) titan_abstract_key = self.from_hf_map[abstract_key] + assert titan_abstract_key is not None new_key = titan_abstract_key.format(layer_num) # Store the expert's weight in expert_weights_by_layer for concatenating later. @@ -155,6 +163,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: expert_weights_by_layer, titan_abstract_key, layer_num, + # pyrefly: ignore [missing-attribute] self.model_args.moe_args.num_experts, ) @@ -163,13 +172,16 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: elif "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = self.from_hf_map[abstract_key] + # pyrefly: ignore [missing-attribute] new_key = new_key.format(layer_num) state_dict[new_key] = value else: new_key = self.from_hf_map[key] + # pyrefly: ignore [unsupported-operation] state_dict[new_key] = value return state_dict diff --git a/torchtitan/models/utils.py b/torchtitan/models/utils.py index addfa17421..5bf73fbb7e 100644 --- a/torchtitan/models/utils.py +++ b/torchtitan/models/utils.py @@ -96,12 +96,13 @@ def _caculate_indices_from_placements( dim_size: int, dtensor_placements: tuple, device_mesh: DeviceMesh, - ) -> tuple[int, int]: + ) -> tuple[int | None, int | None]: mesh_names = [] dim_i_placements = [] # Find all the device mesh dimensios that shard on dim-i + # pyrefly: ignore [bad-argument-type] for i, name in enumerate(device_mesh.mesh_dim_names): placement = dtensor_placements[i] if placement.dim == dim: @@ -181,7 +182,9 @@ def _get_local_experts_weights( Returns: Dictionary mapping individual expert keys to their DTensor weights """ + # pyrefly: ignore [missing-attribute] device_mesh = grouped_expert_weight.device_mesh + # pyrefly: ignore [missing-attribute] dtensor_placements = grouped_expert_weight.placements # Step 1: Extract dimension-0 placement information @@ -212,6 +215,7 @@ def _get_local_experts_weights( elif isinstance(placement, _StridedShard): # Keep strided shard with same parameters new_placements.append( + # pyrefly: ignore [unexpected-positional-argument] _StridedShard(placement.dim, placement.split_factor) ) else: @@ -284,6 +288,7 @@ def _concatenate_expert_weights_dtensor( sorted_expert_ids = sorted(experts.keys()) sorted_experts = [experts[i] for i in sorted_expert_ids] + # pyrefly: ignore [missing-attribute] local_tensor = torch.stack(sorted_experts, dim=0)._local_tensor assert ( @@ -306,7 +311,7 @@ def _concatenate_expert_weights_dtensor( def _split_experts_weights( self, weight: torch.Tensor, n_experts: int - ) -> list[torch.Tensor]: + ) -> tuple[torch.Tensor, ...]: """ Split the weights of the experts into a list of tensors. Used for offline conversion. @@ -365,7 +370,7 @@ def get_dense_model_nparams_and_flops( model: nn.Module, head_dims: int, seq_len: int, -) -> tuple[int, float]: +) -> tuple[int, int]: """ Args: model_args: BaseModelArgs object containing model configuration parameters. @@ -395,6 +400,7 @@ def get_dense_model_nparams_and_flops( # 4. we follow the convention and do not account for sparsity in causal attention num_flops_per_token = ( 6 * (nparams - nparams_embedding) + # pyrefly: ignore [missing-attribute] + 6 * model_args.n_layers * model_args.n_heads * head_dims * seq_len ) @@ -410,7 +416,7 @@ def get_moe_model_nparams_and_flops( model: nn.Module, head_dims: int, seq_len: int, -) -> tuple[int, float]: +) -> tuple[int, int]: """ Calculate nparams and nflops for MoE models. @@ -450,6 +456,7 @@ def get_moe_model_nparams_and_flops( nparams_sparse_active = ( nparams_moe_router + nparams_shared_experts + # pyrefly: ignore [missing-attribute] + nparams_experts * model_args.moe_args.top_k // model_args.moe_args.num_experts ) @@ -460,6 +467,7 @@ def get_moe_model_nparams_and_flops( num_flops_per_token = ( 6 * (nparams_dense - nparams_embedding + nparams_sparse_active) + # pyrefly: ignore [missing-attribute] + 6 * model_args.n_layers * model_args.n_heads * head_dims * seq_len ) diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py index 4cb193c31a..712449f2f6 100644 --- a/torchtitan/protocols/model.py +++ b/torchtitan/protocols/model.py @@ -37,9 +37,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: pass @abstractmethod - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: pass diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index e22692bd52..7b2b3ef3ad 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -8,7 +8,7 @@ import os import re from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Dict from torch.distributed.checkpoint import HuggingFaceStorageReader @@ -27,6 +27,8 @@ class BaseStateDictAdapter(ABC): hf_assets_path: path to HF assets folder containing tokenizer, model weights, etc. """ + fqn_to_index_mapping: Dict[Any, int] | None + @abstractmethod def __init__( self, @@ -98,6 +100,7 @@ def __init__( if hf_safetensors_indx: self.fqn_to_index_mapping = {} for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items(): + # pyrefly: ignore [missing-attribute] indx = re.search(r"\d+", raw_indx).group(0) self.fqn_to_index_mapping[hf_key] = int(indx) else: diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index f398dba9b5..5c2b40b217 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -69,6 +69,7 @@ def trace_handler(prof): elif torch.xpu.is_available(): gpu_device_profiled = torch.profiler.ProfilerActivity.XPU with torch.profiler.profile( + # pyrefly: ignore [bad-argument-type] activities=[ torch.profiler.ProfilerActivity.CPU, gpu_device_profiled, diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index 0b1c78d0d6..d2fa409223 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -65,7 +65,7 @@ def collect(reason: str, generation: int = 1): # hardcoded BF16 type peak flops for NVIDIA A100, H100, H200, B200 GPU and AMD MI250, MI300X, MI325X, MI355X and Intel PVC -def get_peak_flops(device_name: str) -> int: +def get_peak_flops(device_name: str) -> float: try: # Run the lspci command and capture the output result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True) diff --git a/torchtitan/train.py b/torchtitan/train.py index c897ee3c8a..8c597cd608 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -10,10 +10,11 @@ import os import time from datetime import timedelta -from typing import Any, Generator, Iterable +from typing import Any, Iterable import torch +import torch.distributed.checkpoint.stateful from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module @@ -60,7 +61,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): # runtime utilities device: torch.device gc_handler: utils.GarbageCollection - train_context: Generator[None, None, None] + train_context: dist_utils.TrainContext gradient_accumulation_steps: int pp_has_first_stage: bool pp_has_last_stage: bool @@ -82,8 +83,10 @@ def __init__(self, job_config: JobConfig): importlib.import_module(job_config.experimental.custom_import) device_module, device_type = utils.device_module, utils.device_type + # pyrefly: ignore [read-only] self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") # Device has to be set before creating TorchFT manager. + # pyrefly: ignore [missing-attribute] device_module.set_device(self.device) # init distributed and build meshes @@ -99,6 +102,7 @@ def __init__(self, job_config: JobConfig): else: dp_degree, dp_rank = 1, 0 + # pyrefly: ignore [bad-argument-type] self.ft_manager = FTManager(job_config.fault_tolerance) dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) @@ -149,6 +153,7 @@ def __init__(self, job_config: JobConfig): # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) + # pyrefly: ignore [bad-argument-type] model_converters.convert(model) # metrics logging @@ -166,6 +171,7 @@ def __init__(self, job_config: JobConfig): ( model_param_count, self.metrics_processor.num_flops_per_token, + # pyrefly: ignore [bad-argument-type] ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) logger.info( @@ -242,10 +248,12 @@ def __init__(self, job_config: JobConfig): for m in self.model_parts: m.to_empty(device=init_device) with torch.no_grad(): + # pyrefly: ignore [not-callable] m.init_weights(buffer_device=buffer_device) m.train() # confirm that user will be able to view loss metrics on the console + # pyrefly: ignore [bad-argument-type] ensure_pp_loss_visible(parallel_dims, job_config, color) else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel @@ -253,6 +261,7 @@ def __init__(self, job_config: JobConfig): model.to_empty(device=init_device) with torch.no_grad(): + # pyrefly: ignore [not-callable] model.init_weights(buffer_device=buffer_device) model.train() @@ -458,6 +467,7 @@ def post_dataloading_process( attn_type = getattr(self.model_args, "attn_type", "sdpa") if attn_type in ["flex", "varlen"]: + # pyrefly: ignore [not-callable] extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, @@ -486,6 +496,7 @@ def forward_backward_step( optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( cp_mesh=parallel_dims.world_mesh["cp"], + # pyrefly: ignore [bad-argument-type] cp_buffers=cp_buffers, cp_seq_dims=cp_seq_dims, cp_no_restore_buffers={inputs, labels}, @@ -556,6 +567,7 @@ def train_step( # If data runs out during gradient accumulation, that # entire step will not be executed. for _microbatch in range(self.gradient_accumulation_steps): + # pyrefly: ignore [no-matching-overload] input_dict, labels = next(data_iterator) loss = self.forward_backward_step(input_dict, labels) accumulated_losses.append(loss.detach()) @@ -636,6 +648,7 @@ def train(self): leaf_folder=leaf_folder, ) as memory_profiler, maybe_semi_sync_training( + # pyrefly: ignore [bad-argument-type] job_config.fault_tolerance, ft_manager=self.ft_manager, model=self.model_parts[0], @@ -652,6 +665,7 @@ def train(self): ), ), ): + # pyrefly: ignore [bad-argument-type] data_iterator = self.batch_generator(self.dataloader) while self.should_continue_training(): self.step += 1 @@ -671,7 +685,9 @@ def train(self): self.job_config.validation.enable and self.validator.should_validate(self.step) ): + # pyrefly: ignore [missing-attribute] with self.loss_fn.no_rescale(): + # pyrefly: ignore [bad-argument-count] self.validator.validate(self.model_parts, self.step) # signal the profiler that the next profiling step has started From 995154f3c3476882c6e55d0a5a667cf5d26c48f3 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 12 Dec 2025 18:21:46 -0800 Subject: [PATCH 063/127] [Autoparallel] Add local_map variant of DSv3 and 2D mesh AP (#2129) Stacked PRs: * __->__#2129 --- --- --- [Autoparallel] Add local_map variant of DSv3 and 2D mesh AP Currently, the AP experiment monkey patches Titan's main DSv3 implementation. But this is prone to breakage from both model definition changes in titan and from HOP/partitioner related changes in core. When these breaks happen, people are usually blocked until I find the root cause. I'm going on PTO for the rest of the year, so I'm adding an integration to AP's DSv3 model in an attempt to make the development more stable for the upcoming PP integration. Test: https://gist.github.com/xmfan/db15fda1e1bc1df7cd523005fe0baf33 --- torchtitan/experiments/__init__.py | 1 + torchtitan/experiments/autoparallel/README.md | 6 + .../deepseek_v3/parallelize_deepseekv3.py | 2 - .../local_map_deepseek_v3/__init__.py | 57 ++++++ .../local_map_deepseek_v3/args.py | 49 +++++ .../local_map_deepseek_v3/model.py | 18 ++ .../parallelize_deepseekv3.py | 182 ++++++++++++++++++ 7 files changed, 313 insertions(+), 2 deletions(-) create mode 100644 torchtitan/experiments/autoparallel/local_map_deepseek_v3/__init__.py create mode 100644 torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py create mode 100644 torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py create mode 100644 torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 7d7f4da41a..10f9030c1d 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -15,5 +15,6 @@ "transformers_modeling_backend", "autoparallel.llama3", "autoparallel.deepseek_v3", + "autoparallel.local_map_deepseek_v3", ] ) diff --git a/torchtitan/experiments/autoparallel/README.md b/torchtitan/experiments/autoparallel/README.md index 3be86b9bc3..570237b4d9 100644 --- a/torchtitan/experiments/autoparallel/README.md +++ b/torchtitan/experiments/autoparallel/README.md @@ -17,3 +17,9 @@ Requires installing [git@github.com:meta-pytorch/autoparallel.git](https://githu **DeepSeekv3** `CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name autoparallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config` + +**DeepSeekv3 local_map** + +This is a variant of titan's DSv3, which uses a local_map for the expert parallel region. This only supports 2D mesh right now. NOTE: the mesh provided are just to reuse torchtitan's trainer mesh setup code. Autoparallel is not bound to use dp2ep. + +`NGPU=2 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml tlp ./run_train.sh --model.name autoparallel.local_map_deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config --parallelism.data_parallel_shard_degree 2 --parallelism.expert_parallel_degree 2` diff --git a/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py index 0f718a389b..80dfcac9a3 100644 --- a/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py +++ b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py @@ -257,8 +257,6 @@ def set_torchtitan_fields(orig, new): block.moe_enabled = hasattr(block, "moe") -# Run workflow with: -# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_autoparallel def parallelize_deepseekv3( model, parallel_dims: ParallelDims, diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/__init__.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/__init__.py new file mode 100644 index 0000000000..fdd8435ebc --- /dev/null +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/__init__.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +import copy + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.hf_datasets.text_datasets import build_text_dataloader + +from torchtitan.models.deepseek_v3 import deepseekv3_args +from torchtitan.models.deepseek_v3.model.state_dict_adapter import ( + DeepSeekV3StateDictAdapter, +) +from torchtitan.protocols.train_spec import TrainSpec + +from .args import DeepSeekV3ModelArgs, get_sample_config + +from .model import DeepSeekV3Model +from .parallelize_deepseekv3 import parallelize_deepseekv3 + + +def get_model_args() -> DeepSeekV3ModelArgs: + model_args = copy.deepcopy(deepseekv3_args) + # TODO: Align configs between AP and Titan + for config in model_args.keys(): + # Just override the configs + override = get_sample_config() + override.update_from_config = model_args[config].update_from_config + override.get_nparams_and_flops = model_args[config].get_nparams_and_flops + model_args[config] = override + + return model_args + + +def get_train_spec() -> TrainSpec: + model_args = get_model_args() + + return TrainSpec( + model_cls=DeepSeekV3Model, + model_args=model_args, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=DeepSeekV3StateDictAdapter, + ) diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py new file mode 100644 index 0000000000..7f1f84f45a --- /dev/null +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from dataclasses import dataclass + +from autoparallel._testing.models.dsv3 import ( + DeepSeekV3ModelArgs as _DeepSeekV3ModelArgs, + MoEArgs as _MoEArgs, +) +from torchtitan.protocols.model import BaseModelArgs + + +# Need to share same base class with torchtitan models +@dataclass +class DeepSeekV3ModelArgs(_DeepSeekV3ModelArgs, BaseModelArgs): + pass + + +def get_sample_config() -> DeepSeekV3ModelArgs: + return DeepSeekV3ModelArgs( + vocab_size=2048, + max_seq_len=2048, + dim=256, + inter_dim=1024, + moe_inter_dim=256, + n_layers=4, + n_dense_layers=0, + n_heads=16, + moe_args=_MoEArgs( + num_experts=4, + num_shared_experts=2, + top_k=2, + score_func="softmax", + route_norm=False, + score_before_experts=False, + mesh=None, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ) diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py new file mode 100644 index 0000000000..f4915fb708 --- /dev/null +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from autoparallel._testing.models.dsv3 import DeepSeekV3Model as _DeepSeekV3Model +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import DeepSeekV3ModelArgs + + +# Need to share same base class with torchtitan models +class DeepSeekV3Model(_DeepSeekV3Model, ModelProtocol): + def __init__(self, model_args: DeepSeekV3ModelArgs): + super().__init__(model_args) diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py new file mode 100644 index 0000000000..eb400484f6 --- /dev/null +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import torch +from autoparallel.api import AutoParallel +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing + +from torch.distributed.tensor.placement_types import Shard +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +# TODO: Autoparallel should transparently wrap the original nn.Module +# but I don't know how to do that. +def set_torchtitan_fields(orig, new): + assert isinstance(new.layers, torch.nn.ModuleDict) + for block in new.layers.values(): + block.moe_enabled = hasattr(block, "moe") + + +def parallelize_deepseekv3( + model, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply Autoparallel to the model + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + # TODO(whc) + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. + torch._inductor.config.allow_buffer_reuse = False + + # allow configuring inductor comms optimizations from torchtitan commandline + configure_inductor_for_autobucketing( + job_config.experimental.comms_bucket_reorder_strategy + ) + + world_mesh = parallel_dims.world_mesh + + # Update me when changing dsv3.py + assert world_mesh.ndim == 2, "AP dsv3.py's local_map is specialized on 2 dims" + assert world_mesh.mesh_dim_names == ( + "dp_shard_mod_ep", + "dp_shard_in_ep", + ), "Current setup assumes these specific meshes" + + # Provide AP MoE with mesh + for layer in model.layers.values(): + if layer.moe_enabled: + layer.moe.mesh = world_mesh + layer.moe.axis_name = "dp_shard_in_ep" + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard + global_batch_size = job_config.training.local_batch_size * dp_degree + return ( + torch.randint( + 0, + model.model_args.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) + + should_compile = job_config.compile.enable + if should_compile: + # TODO: support more options in AP API + assert job_config.compile.components == ["model"] + assert job_config.compile.backend == "inductor" + + mp_policy = None + with AutoParallel( + model, + input_fn, + world_mesh, + mp_policy=mp_policy, + compile=should_compile, + dynamic=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + x_sharding = (Shard(0), Shard(0)) + loss_parallel_enabled = ( + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel + ) + assert not loss_parallel_enabled + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + set_torchtitan_fields(model, parallel_mod) + + if loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. + # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, world_mesh["tp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + + _preserve_moe_attributes(model, parallel_mod) + + return parallel_mod + + +def _preserve_moe_attributes(original_model, parallel_model): + """ + Preserve MoE custom attributes from the original model to the parallel model. + This is only needed for attributes that aren't used in the graph, so they aren't + lifted as graph inputs and fetched by the pre-graph runtime wrapper. + + `moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify + this block as a moe block. This should be safe as they are read-only. + """ + + def get_moe_modules(model): + """Extract all MoE modules from the model.""" + moe_modules = [] + if hasattr(model, "layers"): + if isinstance(model.layers, torch.nn.ModuleDict): + # regular torchtitan structure + blocks = model.layers.values() + else: + # autoparallel might change structure + blocks = ( + model.layers.children() if hasattr(model.layers, "children") else [] + ) + + for block in blocks: + if ( + hasattr(block, "moe_enabled") + and block.moe_enabled + and hasattr(block, "moe") + ): + moe_modules.append(block.moe) + elif hasattr(block, "moe"): # fallback for autoparallel + moe_modules.append(block.moe) + return moe_modules + + original_moe_modules = get_moe_modules(original_model) + parallel_moe_modules = get_moe_modules(parallel_model) + + # Copy custom attributes from original to parallel MoE modules + # This is fine to do since these attributes are read only + for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules): + if hasattr(orig_moe, "moe_enabled"): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff + + # Copy load_balance_coeff + if hasattr(orig_moe, "load_balance_coeff"): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff From 9bc50ea8349864e376b52de366597282d16bb6c4 Mon Sep 17 00:00:00 2001 From: akashveramd Date: Sat, 13 Dec 2025 00:34:35 -0800 Subject: [PATCH 064/127] Implement ciflow/rocm on Torchtitan (#2114) In this PR, I implemented ciflow/rocm on Torchtitan. The changes are part of integration_test_8gpu_features.yaml. The workflow still supports running on pull_request (without any PR label) for CUDA. However, along with push to main and cron schedule, with the ciflow/8gpu label added to PR, the workflow runs for both CUDA & ROCm. --------- Co-authored-by: Huy Do --- .github/labeler.yml | 6 ++ .github/pytorch-probot.yml | 3 + .../integration_test_8gpu_features.yaml | 30 +------- .github/workflows/set-matrix.yaml | 76 +++++++++++++++++++ 4 files changed, 88 insertions(+), 27 deletions(-) create mode 100644 .github/labeler.yml create mode 100644 .github/pytorch-probot.yml create mode 100644 .github/workflows/set-matrix.yaml diff --git a/.github/labeler.yml b/.github/labeler.yml new file mode 100644 index 0000000000..ed5a23bf4e --- /dev/null +++ b/.github/labeler.yml @@ -0,0 +1,6 @@ +"ciflow/8gpu": + - .ci/docker/** + - .github/workflows/** + - scripts/** + - tests/** + - torchtitan/** diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml new file mode 100644 index 0000000000..9eae404af4 --- /dev/null +++ b/.github/pytorch-probot.yml @@ -0,0 +1,3 @@ +ciflow_push_tags: + - ciflow/8gpu +labeler_config: labeler.yml diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index a20cd22545..e8b2fe63ea 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -3,6 +3,8 @@ name: 8 GPU Feature Tests on: push: branches: [ main ] + tags: + - ciflow/8gpu/* paths-ignore: - 'torchtitan/experiments/**' pull_request: @@ -27,33 +29,7 @@ permissions: jobs: # Step 1: Dynamically compute the matrix based on conditions set-matrix: - runs-on: ubuntu-latest - outputs: - matrix: ${{ steps.set.outputs.matrix }} - steps: - - id: set - run: | - # Decide which matrix entries to include based on event type - if [[ "${{ github.event_name }}" == "push" && "${{ github.ref }}" == "refs/heads/main" ]] || [[ "${{ github.event_name }}" == "schedule" ]]; then - # Include both CUDA and ROCm - echo '{"include":[ - {"name":"cuda","runner":"linux.g5.48xlarge.nvidia.gpu","gpu-arch-type":"cuda","gpu-arch-version":"12.6","docker-image":"torchtitan-ubuntu-20.04-clang12","index-url":"https://download.pytorch.org/whl/nightly/cu126"}, - {"name":"rocm","runner":"linux.rocm.gpu.gfx942.8","gpu-arch-type":"rocm","gpu-arch-version":"7.0","docker-image":"torchtitan-rocm-ubuntu-22.04-clang12","index-url":"https://download.pytorch.org/whl/nightly/rocm7.0"} - ]}' > matrix.json - else - # Include only CUDA - echo '{"include":[ - {"name":"cuda","runner":"linux.g5.48xlarge.nvidia.gpu","gpu-arch-type":"cuda","gpu-arch-version":"12.6","docker-image":"torchtitan-ubuntu-20.04-clang12","index-url":"https://download.pytorch.org/whl/nightly/cu126"} - ]}' > matrix.json - fi - - # Export matrix to job outputs - { - echo 'matrix<> $GITHUB_OUTPUT - + uses: ./.github/workflows/set-matrix.yaml # Step 2: Use the dynamic matrix in the build-test job build-test: diff --git a/.github/workflows/set-matrix.yaml b/.github/workflows/set-matrix.yaml new file mode 100644 index 0000000000..5564d8d70b --- /dev/null +++ b/.github/workflows/set-matrix.yaml @@ -0,0 +1,76 @@ +name: Set Matrix + +on: + workflow_call: + outputs: + matrix: + description: dynamically set matrix + value: ${{ jobs.set.outputs.matrix }} + +jobs: + set: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set.outputs.matrix }} + env: + # Event flags evaluated by github actions before the step runs: + IS_MAIN_PUSH: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} + IS_SCHEDULE: ${{ github.event_name == 'schedule' }} + IS_8GPU_TAG: ${{ startsWith(github.ref, 'refs/tags/ciflow/8gpu/') }} + TRIGGERED_8GPU_LABEL: ${{ github.event_name == 'pull_request' && github.event.action == 'labeled' }} + + steps: + - id: set + run: | + # Define ROCm matrix + ROCM_MATRIX='{ + "name": "rocm", + "runner": "linux.rocm.gpu.gfx942.8", + "gpu-arch-type": "rocm", + "gpu-arch-version": "7.0", + "docker-image": "torchtitan-rocm-ubuntu-22.04-clang12", + "index-url": "https://download.pytorch.org/whl/nightly/rocm7.0" + }' + + # Define CUDA matrix + CUDA_MATRIX='{ + "name": "cuda", + "runner": "linux.g5.48xlarge.nvidia.gpu", + "gpu-arch-type": "cuda", + "gpu-arch-version": "12.6", + "docker-image": "torchtitan-ubuntu-20.04-clang12", + "index-url": "https://download.pytorch.org/whl/nightly/cu126" + }' + + # Use default value as 'false' for unset environment variables + IS_MAIN_PUSH="${IS_MAIN_PUSH:-false}" + IS_SCHEDULE="${IS_SCHEDULE:-false}" + IS_8GPU_TAG="${IS_8GPU_TAG:-false}" + TRIGGERED_8GPU_LABEL="${TRIGGERED_8GPU_LABEL:-false}" + + # Decide which matrix entries to include based on event type + # Runs ROCm only for push tag OR when PR label gets triggered + if [[ "$IS_8GPU_TAG" == "true" || "$TRIGGERED_8GPU_LABEL" == "true" ]]; then + cat > matrix.json < matrix.json < matrix.json <> $GITHUB_OUTPUT From 2aac20a95f2bbc1dfc2dd54aab22821adf9c3b18 Mon Sep 17 00:00:00 2001 From: Shuhua Yu <18108279+shuhuayu@users.noreply.github.com> Date: Sat, 13 Dec 2025 21:40:16 -0800 Subject: [PATCH 065/127] [MoE] Add node limited routing support (#2111) As titled, added node-limited routing support via two-layer routing. First, group experts into `num_groups` groups, and experts in the same group should reside on the same node to utilize fast intra-node communication. Second, pick the `top_k_group` by the top 2 expert scores' sum in each group. Third, pick `top_k` experts within the selected `top_k_groups`. Reference: https://github.com/huggingface/transformers/blob/4c9fde2a2a3aece0bcf1be93f696e88297da9397/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py#L212 Test on one node using DeepSeek V3 debug model with MoE arguments `num_experts=8, num_shared_experts=2, num_groups=4, top_k_group=2, top_k=3`. Pasted Graphic --- torchtitan/models/deepseek_v3/__init__.py | 8 +-- torchtitan/models/deepseek_v3/model/args.py | 5 -- torchtitan/models/moe/moe.py | 77 ++++++++++++++++++--- 3 files changed, 71 insertions(+), 19 deletions(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 7e2d35a5d9..31e450eb04 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -112,13 +112,13 @@ num_experts=160, num_shared_experts=2, top_k=6, + num_expert_groups=8, + num_limited_groups=3, score_func="softmax", route_norm=False, route_scale=16.0, score_before_experts=False, ), - n_expert_groups=8, - n_limited_groups=3, q_lora_rank=1536, kv_lora_rank=512, qk_nope_head_dim=128, @@ -139,13 +139,13 @@ num_experts=256, num_shared_experts=1, top_k=8, + num_expert_groups=8, + num_limited_groups=4, score_func="sigmoid", route_norm=True, route_scale=2.5, score_before_experts=False, ), - n_expert_groups=8, - n_limited_groups=4, q_lora_rank=1536, kv_lora_rank=512, qk_nope_head_dim=128, diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 64a9d2bb81..6609e6fa4e 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -37,8 +37,6 @@ class DeepSeekV3ModelArgs(BaseModelArgs): n_heads (int): Number of attention heads. norm_eps (float): Epsilon value used for RMSNorm. moe_args (MoEArgs): MoE configuration. - n_expert_groups (int): Number of expert groups. - n_limited_groups (int): Number of limited groups for MoE routing. q_lora_rank (int): LoRA rank for query projections. kv_lora_rank (int): LoRA rank for key-value projections. qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. @@ -66,9 +64,6 @@ class DeepSeekV3ModelArgs(BaseModelArgs): # MoE moe_args: MoEArgs = field(default_factory=MoEArgs) - # TODO: node-limited routing is not supported yet - n_expert_groups: int = 1 - n_limited_groups: int = 1 # Multi-Head Latent Attention (MLA) q_lora_rank: int = 0 diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index da58c68b03..bde48abaa1 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -26,8 +26,10 @@ class MoEArgs: route_scale: float = 1.0 score_before_experts: bool = True - # token-choice + # token-choice with optional node limited routing top_k: int = 1 + num_expert_groups: int | None = None # must be a divisor of num_experts + num_limited_groups: int | None = None use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation load_balance_coeff: float | None = 1e-3 @@ -183,9 +185,17 @@ class TokenChoiceTopKRouter(nn.Module): """This class implements token-choice routing. In token-choice top-K routing, each token is routed to top K experts based on the router scores. + Optionally supports node-limited (group-limited) routing where experts are divided into groups + (e.g., by node), and only num_limited_groups groups are considered before selecting top_k experts. + This reduces cross-node communication in distributed settings. + Args: dim (int): Dimension of input tokens. num_experts (int): Number of experts in each moe layer. + num_expert_groups (int | None): Number of expert groups for node-limited routing. If None, standard + top-k routing is used. Must be a divisor of num_experts. + num_limited_groups (int | None): Number of groups to select in node-limited routing. Required when + num_expert_groups is set. top_k (int): Number of experts each token will be routed to in token-choice routing. score_func (Literal["softmax", "sigmoid"]): Whether to use sigmoid or softmax for router scores. route_norm (bool): Whether to normalize the routing scores when using sigmoid. @@ -196,6 +206,8 @@ def __init__( self, dim: int, num_experts: int, + num_expert_groups: int | None, + num_limited_groups: int | None, top_k: int, score_func: Literal["softmax", "sigmoid"], route_norm: bool, @@ -205,6 +217,8 @@ def __init__( super().__init__() self.gate = nn.Linear(dim, num_experts, bias=False) self.num_experts = num_experts + self.num_expert_groups = num_expert_groups + self.num_limited_groups = num_limited_groups self.top_k = top_k self.score_func = score_func self.route_norm = route_norm @@ -228,6 +242,47 @@ def _debug_force_load_balance_routing( top_scores = scores.gather(dim=1, index=selected_experts_indices) # [N,K] return selected_experts_indices, top_scores + def _get_node_limited_routing_scores( + self, + scores_for_choice: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Select num_limited_groups groups based on group scores, + and set expert scores in non-selected groups as -inf + + Args: + scores_for_choice: Router scores with expert_bias (if any), shape (bs*slen, num_experts) + + Returns: + scores_for_choice: shape (bs*slen, num_experts) + """ + if self.num_limited_groups is None: + raise ValueError( + "num_limited_groups must be set when num_expert_groups is set" + ) + if self.num_experts % self.num_expert_groups != 0: + raise ValueError( + f"num_experts ({self.num_experts}) must be divisible by num_expert_groups ({self.num_expert_groups})" + ) + experts_per_group = self.num_experts // self.num_expert_groups + if experts_per_group < 2: + raise ValueError(f"experts_per_group ({experts_per_group}) must be >= 2") + scores_grouped = scores_for_choice.view( + -1, self.num_expert_groups, experts_per_group + ) + top2_scores_in_group, _ = scores_grouped.topk(2, dim=-1) + group_scores = top2_scores_in_group.sum(dim=-1) + _, group_idx = torch.topk( + group_scores, k=self.num_limited_groups, dim=-1, sorted=False + ) + group_mask = torch.ones_like(group_scores, dtype=torch.bool) + group_mask.scatter_(1, group_idx, False) # False = selected groups (keep) + # Mask out experts from non-selected groups + scores_for_choice = scores_grouped.masked_fill( + group_mask.unsqueeze(-1), float("-inf") + ).view(-1, self.num_experts) + + return scores_for_choice + def forward( self, x: torch.Tensor, expert_bias: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -257,18 +312,18 @@ def forward( else: raise NotImplementedError(f"Unknown score function {self.score_func}") + scores_for_choice = scores if expert_bias is None else scores + expert_bias + # Apply node-limited routing if configured + if self.num_expert_groups is not None: + scores_for_choice = self._get_node_limited_routing_scores(scores_for_choice) + _, selected_experts_indices = torch.topk( + scores_for_choice, k=self.top_k, dim=-1, sorted=False + ) + # top scores shape (bs*slen, top_k) # NOTE: The expert_bias is only used for routing. The gating value # top_scores is still derived from the original scores. - if expert_bias is not None: - _, selected_experts_indices = torch.topk( - scores + expert_bias, k=self.top_k, dim=1 - ) - top_scores = scores.gather(dim=1, index=selected_experts_indices) - else: - top_scores, selected_experts_indices = torch.topk( - scores, k=self.top_k, dim=1 - ) + top_scores = scores.gather(dim=1, index=selected_experts_indices) # debug override: balanced round-robin routing if self._debug_force_load_balance: @@ -370,6 +425,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): self.router = TokenChoiceTopKRouter( dim=dim, num_experts=num_experts, + num_expert_groups=moe_args.num_expert_groups, + num_limited_groups=moe_args.num_limited_groups, top_k=moe_args.top_k, score_func=moe_args.score_func, route_norm=moe_args.route_norm, From c1f4e9491e224acadd69b8e59f8abd9b0aae7145 Mon Sep 17 00:00:00 2001 From: Salman Chishti Date: Sun, 14 Dec 2025 05:54:31 +0000 Subject: [PATCH 066/127] Upgrade GitHub Actions to latest versions (#2152) ## Summary Upgrade GitHub Actions to their latest versions for improved features, bug fixes, and security updates. ## Changes | Action | Old Version(s) | New Version | Release | Files | |--------|---------------|-------------|---------|-------| | `pypa/gh-action-pypi-publish` | [`release/v1`](https://github.com/pypa/gh-action-pypi-publish/releases/tag/release/v1) | [`v1`](https://github.com/pypa/gh-action-pypi-publish/releases/tag/v1) | [Release](https://github.com/pypa/gh-action-pypi-publish/releases/tag/v1) | release.yml | ## Why upgrade? Keeping GitHub Actions up to date ensures: - **Security**: Latest security patches and fixes - **Features**: Access to new functionality and improvements - **Compatibility**: Better support for current GitHub features - **Performance**: Optimizations and efficiency improvements ### Security Note Actions that were previously pinned to commit SHAs remain pinned to SHAs (updated to the latest release SHA) to maintain the security benefits of immutable references. ### Testing These changes only affect CI/CD workflow configurations and should not impact application functionality. The workflows should be tested by running them on a branch before merging. --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 92528856a8..50d51d3f9f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -46,4 +46,4 @@ jobs: path: dist/ - name: Publish release distributions to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 + uses: pypa/gh-action-pypi-publish@v1 From f3748d8e9656049019a85f3e67c898a70c8e0e78 Mon Sep 17 00:00:00 2001 From: Salman Chishti Date: Sun, 14 Dec 2025 05:54:35 +0000 Subject: [PATCH 067/127] Upgrade GitHub Actions for Node 24 compatibility (#2151) ## Summary Upgrade GitHub Actions to their latest versions to ensure compatibility with Node 24, as Node 20 will reach end-of-life in April 2026. ## Changes | Action | Old Version(s) | New Version | Release | Files | |--------|---------------|-------------|---------|-------| | `actions/checkout` | [`v3`](https://github.com/actions/checkout/releases/tag/v3), [`v4`](https://github.com/actions/checkout/releases/tag/v4) | [`v6`](https://github.com/actions/checkout/releases/tag/v6) | [Release](https://github.com/actions/checkout/releases/tag/v6) | docker-builds.yml, release.yml | | `actions/download-artifact` | [`v4`](https://github.com/actions/download-artifact/releases/tag/v4) | [`v7`](https://github.com/actions/download-artifact/releases/tag/v7) | [Release](https://github.com/actions/download-artifact/releases/tag/v7) | release.yml | | `actions/setup-python` | [`v5`](https://github.com/actions/setup-python/releases/tag/v5) | [`v6`](https://github.com/actions/setup-python/releases/tag/v6) | [Release](https://github.com/actions/setup-python/releases/tag/v6) | release.yml | | `actions/upload-artifact` | [`v4`](https://github.com/actions/upload-artifact/releases/tag/v4) | [`v6`](https://github.com/actions/upload-artifact/releases/tag/v6) | [Release](https://github.com/actions/upload-artifact/releases/tag/v6) | release.yml | ## Context Per [GitHub's announcement](https://github.blog/changelog/2025-09-19-deprecation-of-node-20-on-github-actions-runners/), Node 20 is being deprecated and runners will begin using Node 24 by default starting March 4th, 2026. ### Why this matters - **Node 20 EOL**: April 2026 - **Node 24 default**: March 4th, 2026 - **Action**: Update to latest action versions that support Node 24 ### Security Note Actions that were previously pinned to commit SHAs remain pinned to SHAs (updated to the latest release SHA) to maintain the security benefits of immutable references. ### Testing These changes only affect CI/CD workflow configurations and should not impact application functionality. The workflows should be tested by running them on a branch before merging. --- .github/workflows/docker-builds.yml | 2 +- .github/workflows/release.yml | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index d5f52824d5..f30da0c223 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -48,7 +48,7 @@ jobs: github-secret: ${{ secrets.GITHUB_TOKEN }} - name: Checkout the repo - uses: actions/checkout@v3 + uses: actions/checkout@v6 - name: Setup Linux uses: pytorch/test-infra/.github/actions/setup-linux@main diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 50d51d3f9f..cdde1f92bc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -11,9 +11,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: "3.x" @@ -23,7 +23,7 @@ jobs: python -m build - name: upload windows dists - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: release-dists path: dist/ @@ -40,7 +40,7 @@ jobs: steps: - name: Retrieve release distributions - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: release-dists path: dist/ From c283a848742198c5f610bc5842322f53e9f02c36 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Sun, 14 Dec 2025 23:53:42 -0800 Subject: [PATCH 068/127] Improve the loss_compare.sh logic (#2143) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * #2145 * #2144 * __->__ #2143 1. Accept one "." (meaning the current commit) case to simplify the command line. 2. Ignore the untracked files. --- scripts/loss_compare.py | 175 +++++++++++++++++++++++++--------------- 1 file changed, 109 insertions(+), 66 deletions(-) diff --git a/scripts/loss_compare.py b/scripts/loss_compare.py index e9761458a8..a084880e0c 100644 --- a/scripts/loss_compare.py +++ b/scripts/loss_compare.py @@ -174,18 +174,6 @@ def validate_arguments( import_result: str | None, ) -> None: """Validate command line arguments.""" - # Validate commit arguments - if one is ".", both must be "." - if (baseline_commit == "." and test_commit != ".") or ( - baseline_commit != "." and test_commit == "." - ): - log_print("Error: If one commit is '.', both commits must be '.'") - log_print(f" Got baseline: '{baseline_commit}', test: '{test_commit}'") - log_print( - " Use '.' for both commits to compare different " - "configurations on current working directory" - ) - sys.exit(1) - # Validate that we are comparing different settings commits_differ = baseline_commit != test_commit configs_differ = baseline_config != test_config @@ -336,7 +324,8 @@ def print_configuration( def check_git_clean_state() -> None: """Check if git working directory is clean before switching commits. - Raises SystemExit if there are uncommitted changes or untracked files. + Raises SystemExit if there are uncommitted changes to tracked files. + Untracked files are ignored. """ result = subprocess.run( ["git", "status", "--porcelain"], @@ -345,12 +334,20 @@ def check_git_clean_state() -> None: check=True, ) - if result.stdout.strip(): - log_print("Error: Git working directory is not clean") + # Filter out untracked files (lines starting with "??") + modified_tracked_files = [] + for line in result.stdout.strip().split("\n"): + if line and not line.startswith("??"): + modified_tracked_files.append(line) + + if modified_tracked_files: + log_print( + "Error: Git working directory has uncommitted changes to tracked files" + ) log_print(" Cannot switch commits with uncommitted changes") log_print("") - log_print("Modified/untracked files:") - for line in result.stdout.strip().split("\n"): + log_print("Modified tracked files:") + for line in modified_tracked_files: log_print(f" {line}") log_print("") log_print( @@ -371,6 +368,39 @@ def checkout_commit(commit: str, commit_name: str) -> None: log_print(f"Using current working directory for {commit_name} (commit: '.')") +def get_current_commit() -> str: + """Get the current git commit hash or branch name. + + Returns the current branch name if on a branch, otherwise returns the commit hash. + """ + # Try to get current branch name + result = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + capture_output=True, + text=True, + check=True, + ) + ref = result.stdout.strip() + + # If in detached HEAD state, ref will be "HEAD", so get the commit hash instead + if ref == "HEAD": + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + capture_output=True, + text=True, + check=True, + ) + ref = result.stdout.strip() + + return ref + + +def restore_original_commit(original_commit: str) -> None: + """Restore the original git commit/branch.""" + log_print(f"Restoring original commit/branch: {original_commit}") + subprocess.run(["git", "checkout", original_commit], check=True) + + # ============================================================================= # TRAINING OPERATIONS # ============================================================================= @@ -1003,61 +1033,74 @@ def main() -> None: ) # Check if git working directory is clean before switching commits - # Skip check if both commits are "." (comparing configs on same commit) + # Skip check only if both commits are "." (comparing configs on same commit) needs_git_checkout = args.baseline_commit != "." or args.test_commit != "." if needs_git_checkout: check_git_clean_state() - create_seed_checkpoint( - enable_seed_checkpoint, - args.baseline_config, - args.baseline_train_file, - args.output_folder, - args.job_dump_folder, - ) - # Run baseline and test training - baseline_log = run_scenario( - "baseline", - args.baseline_commit, - args.baseline_config, - args.baseline_train_file, - args.baseline_options, - args.steps, - enable_seed_checkpoint, - args.output_folder, - args.job_dump_folder, - args.baseline_ngpus, - ) + # Save original commit if we're going to do checkouts + original_commit = None + if needs_git_checkout: + original_commit = get_current_commit() + log_print(f"Saving original commit/branch: {original_commit}") + log_print() - test_log = run_scenario( - "test", - args.test_commit, - args.test_config, - args.test_train_file, - args.test_options, - args.steps, - enable_seed_checkpoint, - args.output_folder, - args.job_dump_folder, - args.test_ngpus, - ) - log_print() + try: + create_seed_checkpoint( + enable_seed_checkpoint, + args.baseline_config, + args.baseline_train_file, + args.output_folder, + args.job_dump_folder, + ) + # Run baseline and test training + baseline_log = run_scenario( + "baseline", + args.baseline_commit, + args.baseline_config, + args.baseline_train_file, + args.baseline_options, + args.steps, + enable_seed_checkpoint, + args.output_folder, + args.job_dump_folder, + args.baseline_ngpus, + ) + + test_log = run_scenario( + "test", + args.test_commit, + args.test_config, + args.test_train_file, + args.test_options, + args.steps, + enable_seed_checkpoint, + args.output_folder, + args.job_dump_folder, + args.test_ngpus, + ) + log_print() - # Assert losses are equal if requested - if args.assert_equal: - # Pass import_result if provided for 3-way comparison - assert_losses_equal(baseline_log, test_log, args.import_result) - - # Export losses if requested (only after assertion passes) - if args.export_result: - # Extract baseline losses (they equal test losses since assertion passed) - baseline_losses = extract_losses_from_log(baseline_log) - export_losses_to_file(baseline_losses, args.export_result) - - # Analysis and reporting - perform_loss_analysis(baseline_log, test_log, stats_file) - cleanup_temp_files(args.output_folder) - print_completion_summary(args.output_folder, enable_seed_checkpoint) + # Assert losses are equal if requested + if args.assert_equal: + # Pass import_result if provided for 3-way comparison + assert_losses_equal(baseline_log, test_log, args.import_result) + + # Export losses if requested (only after assertion passes) + if args.export_result: + # Extract baseline losses (they equal test losses since assertion passed) + baseline_losses = extract_losses_from_log(baseline_log) + export_losses_to_file(baseline_losses, args.export_result) + + # Analysis and reporting + perform_loss_analysis(baseline_log, test_log, stats_file) + cleanup_temp_files(args.output_folder) + print_completion_summary(args.output_folder, enable_seed_checkpoint) + finally: + # Restore original commit if we did checkouts + if original_commit is not None: + log_print() + restore_original_commit(original_commit) if __name__ == "__main__": From 64997d21d030cdecb6853ef82a1e1b8d44d29ab6 Mon Sep 17 00:00:00 2001 From: Shuhua Yu <18108279+shuhuayu@users.noreply.github.com> Date: Mon, 15 Dec 2025 10:40:53 -0800 Subject: [PATCH 069/127] [GPT-OSS] Add HF state dict adapter to support loading from HF checkpoints (#2021) As titled, this PR adds HF state dict adapter to support loading from GPT-OSS HF checkpoint. GPT-OSS checkpoint is quantized in MXPF4 format. The de-quantization steps are offloaded to the `QuantizedHuggingFaceStorageReader` in `dcp`, so this feature depends on this PR to update `QuantizedHuggingFaceStorageReader` (https://github.com/pytorch/pytorch/pull/167672). 1. Test 1. We use `dcp.load(hf_state_dict, storage_reader=QuantizedHuggingFaceStorageReader(path=input_dir))` to load from GPT-OSS HF checkpoint, and map the `hf_state_dict` back to TorchTitan state dict. We build one test input, and compare two outputs: 1. Using `transformer` library to load GPT-OSS HF checkpoint and run inference on the test input; 2. We use the converted TorchTitan model to run inference on the test input. We compare the outputs by comparing the KL divergence of two output probability distributions. The result shows two models are very similar. Pasted
Graphic 2. Test 2. We load the model directly from quantized GPT-OSS HF checkpoint, and do a test training. Pasted Graphic 1 --- docs/checkpoint.md | 2 +- torchtitan/experiments/gpt_oss/__init__.py | 2 + .../gpt_oss/model/state_dict_adapter.py | 115 ++++++++++++++++++ 3 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 torchtitan/experiments/gpt_oss/model/state_dict_adapter.py diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 6e3112309b..8aca58eb06 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -68,7 +68,7 @@ NGPU=1 CONFIG_FILE= ./run_train.sh --checkpoint.enable --c ### HuggingFace `torchtitan` offers two ways to work with Hugging Face models: either by directly saving and loading a Hugging Face checkpoint during training, or by using an example conversion script to directly reformat the model weights on cpu. -1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set. +1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_hf` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set. 2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is `torchtitan`-centric, e.g. convert_from_hf means convert hf->tt. diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py index c12ad13a5c..0ebc20645f 100644 --- a/torchtitan/experiments/gpt_oss/__init__.py +++ b/torchtitan/experiments/gpt_oss/__init__.py @@ -16,6 +16,7 @@ from .infra.parallelize import parallelize_gptoss from .model.args import GptOssModelArgs from .model.model import GptOssModel +from .model.state_dict_adapter import GptOssStateDictAdapter __all__ = [ "parallelize_gptoss", @@ -84,4 +85,5 @@ def get_train_spec() -> TrainSpec: build_dataloader_fn=build_text_dataloader, build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=GptOssStateDictAdapter, ) diff --git a/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py b/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py new file mode 100644 index 0000000000..ca85789baf --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import re +from typing import Any + +from torch.distributed.checkpoint import HuggingFaceStorageReader +from torchtitan.models.utils import MoEStateDictAdapter + +from .args import GptOssModelArgs + + +class GptOssStateDictAdapter(MoEStateDictAdapter): + def __init__(self, model_args: GptOssModelArgs, hf_assets_path: str | None): + super().__init__(model_args, hf_assets_path) + self.from_hf_map = { + "model.embed_tokens.weight": "tok_embeddings.weight", + # Attention module + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias", + "model.layers.{}.self_attn.sinks": "layers.{}.attention.sinks", + # Transformer layer + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + # MoE + "model.layers.{}.mlp.experts.gate_up_proj_blocks": "layers.{}.moe.experts.mlp1_weight", + "model.layers.{}.mlp.experts.gate_up_proj_bias": "layers.{}.moe.experts.mlp1_bias", + "model.layers.{}.mlp.experts.down_proj_blocks": "layers.{}.moe.experts.mlp2_weight", + "model.layers.{}.mlp.experts.down_proj_bias": "layers.{}.moe.experts.mlp2_bias", + "model.layers.{}.mlp.router.weight": "layers.{}.moe.router.gate.weight", + "model.layers.{}.mlp.router.bias": "layers.{}.moe.router.gate.bias", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + + def get_hf_storage_reader( + self, path: str, from_quantized: bool = False + ) -> HuggingFaceStorageReader: + """ + Override default get_hf_storage_reader function to return QuantizedHFStorageReader. + """ + if from_quantized: + from torch.distributed.checkpoint.quantized_hf_storage import ( + QuantizedHuggingFaceStorageReader, + ) + + # NOTE: Now we use Quantized HF storage reader to read GPT-OSS model where + # expert weights are saved in MXFP4 format. + # If loading checkpoints without quantization, use HuggingFaceStorageReader instead + return QuantizedHuggingFaceStorageReader( + path=path, + thread_count=4, + ) + else: + return HuggingFaceStorageReader(path) + + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: + """ + Convert from a tt model state dict to a hf format state dict. + + Only map keys without changing shapes to the same as MXFP4 checkpoint. + For loading from quantized checkpoints, the QuantizedHuggingFaceStorageReader + will handle dequantization during load. + + Warning: Conversion does not support saving to mxfp4 quantization format. + One can save into unquantized hf checkpoints with last_save_in_hf = true. + """ + to_hf_map = {v: k for k, v in self.from_hf_map.items()} + hf_state_dict = {} + + for key, value in state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + if abstract_key not in to_hf_map: + continue + layer_num = re.search(r"\d+", key).group(0) + hf_key = to_hf_map[abstract_key] + hf_key = hf_key.format(layer_num) + hf_state_dict[hf_key] = value + else: + if key not in to_hf_map: + continue + hf_key = to_hf_map[key] + hf_state_dict[hf_key] = value + + return hf_state_dict + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + """ + Convert from hf format state dict to tt model state dict. + """ + + state_dict = {} + + for key, value in hf_state_dict.items(): + if "layers" in key: + layer_num = re.search(r"\d+", key).group(0) + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + tt_key = self.from_hf_map[abstract_key] + tt_key = tt_key.format(layer_num) + state_dict[tt_key] = value + else: + tt_key = self.from_hf_map[key] + state_dict[tt_key] = value + + return state_dict From c08fa570a623d1823f6669efc66a9759444e38e3 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 15 Dec 2025 12:47:27 -0800 Subject: [PATCH 070/127] Add local built pytorch path for pyrefly (#2155) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * #2156 * __->__ #2155 This assumes that the local built version has the same parent folder as torchtitan. Also fixes some pyrefly errors for moe.py --- pyproject.toml | 1 + torchtitan/models/moe/moe.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7a3687590c..aa5a93fd7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,3 +68,4 @@ testpaths = ["tests"] [tool.pyrefly] project-excludes = ["torchtitan/experiments", "**/tests/**"] ignore-missing-imports = ["torchao.*", "torchft"] # optional dependencies +search-path = ["../pytorch"] # local built pytorch diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index bde48abaa1..c5dd59ab29 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -245,7 +245,7 @@ def _debug_force_load_balance_routing( def _get_node_limited_routing_scores( self, scores_for_choice: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: """Select num_limited_groups groups based on group scores, and set expert scores in non-selected groups as -inf @@ -259,6 +259,7 @@ def _get_node_limited_routing_scores( raise ValueError( "num_limited_groups must be set when num_expert_groups is set" ) + assert self.num_expert_groups is not None if self.num_experts % self.num_expert_groups != 0: raise ValueError( f"num_experts ({self.num_experts}) must be divisible by num_expert_groups ({self.num_expert_groups})" From e36d02752791327b6ba3b45cef3951c24bfd6512 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 15 Dec 2025 21:00:02 -0800 Subject: [PATCH 071/127] Run vLLM inference using torchtitan model definition (single GPU) (#2119) As titled, put it in deterministic RL folder --- torchtitan/experiments/README.md | 3 +- torchtitan/experiments/rl/README.md | 12 + torchtitan/experiments/rl/unified/README.md | 68 ++++ torchtitan/experiments/rl/unified/__init__.py | 93 +++++ .../experiments/rl/unified/attention.py | 93 +++++ torchtitan/experiments/rl/unified/infer.py | 115 ++++++ torchtitan/experiments/rl/unified/utils.py | 63 ++++ .../experiments/rl/unified/vllm_wrapper.py | 329 ++++++++++++++++++ .../vllm_compat}/README.md | 8 +- .../vllm_compat}/__init__.py | 15 +- .../vllm_compat}/batch_invariant_backward.py | 0 .../vllm_compat}/models/__init__.py | 7 +- .../vllm_compat}/models/attention.py | 5 +- .../vllm_compat}/models/qwen3/__init__.py | 0 .../models/qwen3/model_vllm_compat.py | 2 +- .../vllm_compat}/simple_rl.py | 16 +- .../vllm_compat}/tests/__init__.py | 0 .../tests/test_batch_invariant_backward.py | 4 +- .../tests/test_exact_determinism.py | 4 +- .../vllm_compat}/weights/README.md | 0 .../vllm_compat}/weights/__init__.py | 0 .../vllm_compat}/weights/converter.py | 0 .../vllm_compat}/weights_vllm_compat.py | 0 23 files changed, 805 insertions(+), 32 deletions(-) create mode 100644 torchtitan/experiments/rl/README.md create mode 100644 torchtitan/experiments/rl/unified/README.md create mode 100644 torchtitan/experiments/rl/unified/__init__.py create mode 100644 torchtitan/experiments/rl/unified/attention.py create mode 100755 torchtitan/experiments/rl/unified/infer.py create mode 100644 torchtitan/experiments/rl/unified/utils.py create mode 100644 torchtitan/experiments/rl/unified/vllm_wrapper.py rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/README.md (97%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/__init__.py (53%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/batch_invariant_backward.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/models/__init__.py (74%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/models/attention.py (98%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/models/qwen3/__init__.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/models/qwen3/model_vllm_compat.py (99%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/simple_rl.py (99%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/tests/__init__.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/tests/test_batch_invariant_backward.py (97%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/tests/test_exact_determinism.py (98%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/weights/README.md (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/weights/__init__.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/weights/converter.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl/vllm_compat}/weights_vllm_compat.py (100%) diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 10b90ac1d4..53df45dd84 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -29,7 +29,8 @@ We provide this `experiments/` folder to host experiments that add significant v | [forge](./forge/) | TBA | [@allenwang28](https://github.com/allenwang28) [@ebsmothers](https://github.com/ebsmothers) [@joecummings](https://github.com/joecummings) [@pbontrager](https://github.com/pbontrager) | | [torchcomms](./torchcomms/) | [![TorchComms 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_torchcomms.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_torchcomms.yaml?query=branch%3Amain) | [@d4l3k](https://https://github.com/d4l3k) [@fduwjj](https://github.com/fduwjj) [@mori360 ](https://github.com/mori360) | | [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) | -| [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | +| [gpt_oss](./gpt_oss/) | TBA | [@wwwjn](https://github.com/wwwjn) | | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) | | [transformers_modeling_backend](./transformers_modeling_backend/) | [![Transformers modeling backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | +| [rl](./rl/) | TBA | [@bwasti](https://github.com/bwasti) [@wwwjn](https://github.com/wwwjn) | | [autoparallel](./autoparallel/) | [![Auto Parallel 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml?query=branch%3Amain) | [@wconstab](https://github.com/wconstab) [@xmfan](https://github.com/xmfan) | diff --git a/torchtitan/experiments/rl/README.md b/torchtitan/experiments/rl/README.md new file mode 100644 index 0000000000..72b3d2ad11 --- /dev/null +++ b/torchtitan/experiments/rl/README.md @@ -0,0 +1,12 @@ +# Deterministic RL Training with vLLM + +This package provides two approaches for integrating TorchTitan models with vLLM: + +1. vllm_compat/ - vLLM-Compatible approach + - Separate model definition matching vLLM's weight format + - Support batch-invariant and bit-wise identity between train and inference + - Custom backward passes for attention gradient computation + +2. unified/ - Unified approach + - Uses canonical TorchTitan model definition for inference directly + - Replaces attention with vLLM Compatible attention for inference diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md new file mode 100644 index 0000000000..5cea3918ae --- /dev/null +++ b/torchtitan/experiments/rl/unified/README.md @@ -0,0 +1,68 @@ +# Run vLLM inference with TorchTitan Qwen3 Model + +This directory contains code to run a single canonical model definition (TorchTitan model definition) with vLLM inference engine (not batch-invariant yet, working in progress). This work is actively developing and only supports inference for now. + +This work is inspired by https://github.com/vllm-project/vllm/pull/28685. + +## Overview +The integration consists of two main components: + +1. **Model Adapter** (`model/qwen3.py`): A custom model class that extends vLLM's `Qwen3ForCausalLM` to handle TorchTitan checkpoint naming conventions +2. **Inference Script** (`infer.py`): A simple script to register the model and run inference + + +## Quick Start +### Prerequisites + +1. Install PyTorch nightly for torchtitan: +``` +pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall +``` + + +2. Install vLLM from source [vllm-use-an-existing-pytorch-installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation): +```bash +# install PyTorch first, either from PyPI or from source +git clone https://github.com/vllm-project/vllm.git +cd vllm +python use_existing_torch.py +uv pip install -r requirements/build.txt +uv pip install --no-build-isolation -e . +``` + + +NOTE: If `flash_attn_varlen_func` hits error "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain" during forward path, this is due to GPU driver version is not compatible with vLLM/PyTorch compiled version. Use the following command to recompile vLLM. + +``` +# Set CUDA version environment variable +export CUDA_HOME=/usr/local/cuda-12.4 +export PATH=/usr/local/cuda-12.4/bin:$PATH +export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH + +# Clean previous build +rm -rf build dist *.egg-info +uv pip uninstall -y vllm + +# Rebuild vLLM from source with CUDA 12.4 +pip install -e . + +``` + +3. Download Qwen3/Qwen3-0.6b checkpoint from HuggingFace and put into `example_checkpoint` folder. + + +4. Run inference: +``` +python torchtitan/experiments/rl/unified/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B +``` + +Run with TP: (work in progress) +``` +python torchtitan/experiments/rl/unified/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B --tensor-parallel-size 2 + +``` + +## TODO +1. Rewrite attention part to use vllm.Attention() with backward as the only attention path. +2. Integrate with simple_rl.py to run end-to-end RL with one canonical model definition. +3. Leverage batch-invariant kernels into model definition. diff --git a/torchtitan/experiments/rl/unified/__init__.py b/torchtitan/experiments/rl/unified/__init__.py new file mode 100644 index 0000000000..6c34556112 --- /dev/null +++ b/torchtitan/experiments/rl/unified/__init__.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unified approach for running TorchTitan models with vLLM inference. + +This module automatically registers TorchTitan models with vLLM when imported. +Uses the canonical TorchTitan model definition directly with vLLM inference engine. +""" + +from torchtitan.protocols.train_spec import get_train_spec, TrainSpec +from vllm.logger import init_logger + +from .utils import create_parallel_dims_from_vllm_config +from .vllm_wrapper import TorchTitanVLLMModelWrapper + + +logger = init_logger(__name__) + + +def register_torchtitan_model_from_train_spec( + train_spec: TrainSpec, + model_name: str, + model_flavor: str, +) -> None: + """ + Register a TorchTitan model with vLLM using a TrainSpec. + + Args: + train_spec: TorchTitan TrainSpec containing model components + model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM") + model_flavor: Model flavor key (e.g., "0.6B") to select from qwen3_args + + """ + from vllm.model_executor.models.registry import ModelRegistry + + # Get model_args directly from TrainSpec.model_args dict using flavor key + if isinstance(train_spec.model_args, dict): + if model_flavor not in train_spec.model_args: + raise ValueError( + f"Model flavor '{model_flavor}' not found in train_spec.model_args. " + f"Available flavors: {list(train_spec.model_args.keys())}" + ) + model_args = train_spec.model_args[model_flavor] + else: + raise ValueError( + "train_spec.model_args must be a dict mapping flavor names to ModelArgs" + ) + + # Create dynamic model class directly from TrainSpec components + class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModelWrapper): + def __init__(self, *, vllm_config, prefix=""): + super().__init__( + model_cls=train_spec.model_cls, + model_args=model_args, + state_dict_adapter=train_spec.state_dict_adapter, + parallelize_fn=train_spec.parallelize_fn, + vllm_config=vllm_config, + prefix=prefix, + ) + + # Set the class name + TorchTitanVLLMModelFromSpec.__name__ = model_name + TorchTitanVLLMModelFromSpec.__qualname__ = model_name + + # Register with vLLM + ModelRegistry.register_model(model_name, TorchTitanVLLMModelFromSpec) + + logger.info( + f"Successfully registered {model_name} with vLLM using TrainSpec " + f"(model_cls={train_spec.model_cls.__name__}, flavor={model_flavor})" + ) + + +# Auto-register TorchTitan models with vLLM when this module is imported +register_torchtitan_model_from_train_spec( + train_spec=get_train_spec("qwen3"), + model_name="Qwen3TorchTitanForCausalLM", + # TODO: Remove the model_flavor args when registering model, + # allow passing model flavor option from config system. Now we have to specify + # model_flavor during registration because we can not pass torchtitan job_config from LLM() Api + model_flavor="0.6B", +) + + +__all__ = [ + "TorchTitanVLLMModelWrapper", + "create_parallel_dims_from_vllm_config", + "register_torchtitan_model_from_train_spec", +] diff --git a/torchtitan/experiments/rl/unified/attention.py b/torchtitan/experiments/rl/unified/attention.py new file mode 100644 index 0000000000..1a03b882cb --- /dev/null +++ b/torchtitan/experiments/rl/unified/attention.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from vllm.attention.layer import Attention + + +class VLLMAttention(torch.nn.Module): + """ + Wrapper around vLLM's Attention. Compatible with TorchTitan input shape. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + layer_name: str, + scale: float | None = None, + ) -> None: + super().__init__() + + self.hidden_size = hidden_size + self.layer_name = layer_name + + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + + if scale is None: + self.scale = head_dim**-0.5 + else: + self.scale = scale + + cache_config = ( + vllm_config.cache_config if hasattr(vllm_config, "cache_config") else None + ) + + self.vllm_attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=self.scale, + num_kv_heads=num_kv_heads, + cache_config=cache_config, + quant_config=None, + prefix=f"model.layers.{layer_name}.attention.inner_attention", + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + scale: float | None = None, + ) -> torch.Tensor: + """ + Forward pass using vLLM's Attention layer for inference. + + Args: + q: Query tensor [batch, num_heads, seq_len, head_dim] + k: Key tensor [batch, num_kv_heads, seq_len, head_dim] + v: Value tensor [batch, num_kv_heads, seq_len, head_dim] + scale: Optional attention scale override (unused, vLLM uses internal scale) + + Returns: + output: [batch, num_heads, seq_len, head_dim] + """ + # Input is (batch, num_heads, seq_len, head_dim) + batch_size, num_heads, seq_len, head_dim = q.shape + + # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + output_varlen = self.vllm_attn(q, k, v) + + # Reshape back to batch format + output = output_varlen.view(batch_size, seq_len, num_heads, head_dim) + + # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim) + output = output.transpose(1, 2) + + return output diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py new file mode 100755 index 0000000000..19770ecc22 --- /dev/null +++ b/torchtitan/experiments/rl/unified/infer.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + +# Import unified module - this automatically registers TorchTitan models with vLLM +from torchtitan.experiments.deterministic_vllm_rl import unified # noqa: F401 + +from vllm import LLM, SamplingParams +from vllm.logger import init_logger + + +logger = init_logger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run TorchTitan model inference with vLLM Engine", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--model_ckpt_path", + type=str, + default="torchtitan/experiments/deterministic_vllm_rl/example_checkpoint", + help="Path to TorchTitan checkpoint directory", + ) + parser.add_argument( + "--prompt", + type=str, + default="Hello, my name is", + help="Prompt text for generation", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="Sampling temperature", + ) + parser.add_argument( + "--tensor-parallel-size", + type=int, + default=1, + help="Number of GPUs for tensor parallelism (default: 1 for single GPU)", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + logger.info("Initializing vLLM with TorchTitan model") + logger.info(f"Model: {args.model_ckpt_path}") + logger.info(f"Tensor Parallel Size: {args.tensor_parallel_size}") + + # Initialize vLLM with custom TorchTitan model + # The LLM initialization will internally: + # 1. Load TrainSpec for Qwen3 (from models/__init__.py register()) + # 2. Create TorchTitanVLLMModel instance + # 3. Create JobConfig and ParallelDims from vLLM config + # 4. Apply parallelization using parallelize_qwen3 + # 5. Load model weights and prepare for inference + logger.info("Creating vLLM LLM engine...") + + llm = LLM( + model=args.model_ckpt_path, # Model checkpoint path + hf_overrides={ + # Override architectures to use our registered TorchTitan model class + "architectures": ["Qwen3TorchTitanForCausalLM"], + }, + dtype="bfloat16", + trust_remote_code=True, + enforce_eager=True, # Use eager mode + tensor_parallel_size=args.tensor_parallel_size, + ) + + logger.info("vLLM engine initialized successfully") + logger.info(f"Prompt: {args.prompt}") + + # Prepare prompt and sampling parameters + prompts = [args.prompt] + sampling_params = SamplingParams( + temperature=args.temperature, + top_p=0.95, + max_tokens=args.max_tokens, + ) + + # Generate text + logger.info("Generating text...") + outputs = llm.generate( + prompts=prompts, + sampling_params=sampling_params, + ) + + # Print results + logger.info("Generation complete") + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + + print(f"\nPrompt: {prompt}") + print(f"Generated text: {generated_text!r}\n") + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/rl/unified/utils.py b/torchtitan/experiments/rl/unified/utils.py new file mode 100644 index 0000000000..e997c387d9 --- /dev/null +++ b/torchtitan/experiments/rl/unified/utils.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Parallelization utilities for vLLM + TorchTitan models. + +This module provides functions for setting up device mesh and applying +tensor parallelism to TorchTitan models in vLLM using TorchTitan's ParallelDims. +""" + +import torch.distributed as dist + +from torchtitan.distributed.parallel_dims import ParallelDims +from vllm.config import VllmConfig +from vllm.logger import init_logger + + +logger = init_logger(__name__) + + +def create_parallel_dims_from_vllm_config(vllm_config: VllmConfig) -> ParallelDims: + """ + Create ParallelDims from vLLM config and maps vLLM parallelism settings to TorchTitan's ParallelDims dataclass. + + This function is needed because vLLM doesn't separate model creation and + parallelism application - it requires parallelization to be done inside + the model constructor, so we are creating parallel_dims and apply parallelism + in TorchTitanVLLMModelWrapper.__init__ function. + + Args: + vllm_config: vLLM configuration object + + Returns: + ParallelDims object with parallelism settings validated + + Note: + vLLM doesn't use FSDP sharding (dp_shard=1) or expert parallelism (ep=1, etp=1) + in inference. These are set to default values. + """ + world_size = dist.get_world_size() + + # Map vLLM config to TorchTitan ParallelDims + parallel_dims = ParallelDims( + dp_replicate=vllm_config.parallel_config.data_parallel_size, + dp_shard=1, # vLLM doesn't use FSDP sharding + cp=vllm_config.parallel_config.decode_context_parallel_size, + tp=vllm_config.parallel_config.tensor_parallel_size, + pp=vllm_config.parallel_config.pipeline_parallel_size, + ep=1, # Expert parallelism not used in vLLM inference yet + etp=1, # Expert tensor parallelism not used in vLLM inference yet + world_size=world_size, + ) + + logger.info( + f"Created ParallelDims from vLLM config: " + f"DP={parallel_dims.dp_replicate}, TP={parallel_dims.tp}, " + f"CP={parallel_dims.cp}, PP={parallel_dims.pp}" + ) + + return parallel_dims diff --git a/torchtitan/experiments/rl/unified/vllm_wrapper.py b/torchtitan/experiments/rl/unified/vllm_wrapper.py new file mode 100644 index 0000000000..e92903c744 --- /dev/null +++ b/torchtitan/experiments/rl/unified/vllm_wrapper.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Base wrapper for TorchTitan models to work with vLLM V1 engine. + +This module provides TorchTitanVLLMModel: Core model class that adapts +TorchTitan models for vLLM. +""" + +from functools import partial + +import torch +import torch.nn as nn +from torch.distributed._tensor import DTensor, Replicate +from torch.distributed.checkpoint.state_dict import ( + set_model_state_dict, + StateDictOptions, +) + +from torchtitan.experiments.deterministic_vllm_rl.unified.attention import VLLMAttention +from torchtitan.models.qwen3.model.model import precompute_rope_cache +from torchtitan.protocols.model import BaseModelArgs, ModelProtocol +from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter +from torchtitan.protocols.train_spec import ParallelizeFunction + +from vllm.config import VllmConfig +from vllm.logger import init_logger + +from .utils import create_parallel_dims_from_vllm_config + + +logger = init_logger(__name__) + + +class TorchTitanVLLMModelWrapper(nn.Module): + """ + Generic vLLM-compatible model wrapper for TorchTitan models. + + The wrapper handles: + - Direct usage of TorchTitan model args (no HF config mapping needed) + - Attention replacement with vLLM paged attention + - Tensor parallelism setup + - Weight loading from HF checkpoints + - vLLM forward/compute_logits interface + """ + + is_text_generation_model = True # Required for vLLM runner validation + supports_pp = False # Pipeline parallelism not supported yet + supports_multimodal = False + + def __init__( + self, + *, + model_cls: type[ModelProtocol], + model_args: BaseModelArgs, + state_dict_adapter: type[BaseStateDictAdapter], + parallelize_fn: ParallelizeFunction, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + + assert vllm_config is not None, "vllm_config is required" + + # Store components + self.model_cls = model_cls + self.state_dict_adapter = state_dict_adapter + self.parallelize_fn = parallelize_fn + + # Use TorchTitan model args directly (no HF config mapping) + self.config = model_args + logger.info(f"Creating {self.model_cls.__name__} with config: {model_args}") + self.model = self.model_cls(model_args) + + # Setup RoPE cache extension function if provided + self.rope_cache_extension_fn = partial( + precompute_rope_cache, + dim=self.config.head_dim, + base=self.config.rope_theta, + ) + # Replace attention with vLLM paged attention + self._replace_with_vllm_attention(model_args) + + # Create ParallelDims from vLLM config and apply parallelization + # NOTE: We need to apply parallelize within model.__init__ because w + parallel_dims = create_parallel_dims_from_vllm_config(vllm_config) + if parallel_dims.tp_enabled: + self.world_mesh = parallel_dims.world_mesh + tp_mesh = self.world_mesh["tp"] + parallelize_fn( + model=self.model, + tp_mesh=tp_mesh, + loss_parallel=False, + enable_float8_tensorwise_tp=False, + enable_async_tp=False, + ) + logger.info( + f"Successfully initialized model with with TP={parallel_dims.tp}" + ) + else: + logger.info("Single GPU mode - no parallelization needed") + + def _replace_with_vllm_attention(self, model_args): + """ + Replace TorchTitan attention with vLLM paged attention. + + Assumes model has .layers dict with .attention.inner_attention structure. + Override in subclass if different structure. + """ + assert hasattr( + self.model, "layers" + ), f"Model {type(self.model).__name__} must have .layers attribute" + + for layer_name, layer in self.model.layers.items(): + assert hasattr( + layer, "attention" + ), f"Layer {layer_name} must have .attention attribute" + + vllm_attn = VLLMAttention( + hidden_size=model_args.dim, + num_heads=model_args.n_heads, + num_kv_heads=model_args.n_heads, # Use n_heads (already replicated) + head_dim=model_args.head_dim, + layer_name=layer_name, + scale=model_args.head_dim**-0.5, + ) + + # Replace inner attention + layer.attention.inner_attention = vllm_attn + + logger.info( + f"Successfully replaced TorchTitan attention with VLLMAttention " + f"({len(self.model.layers)} layers)" + ) + + def _extend_rope_cache_if_needed( + self, rope_cache: torch.Tensor, max_position: int + ) -> torch.Tensor: + """ + Extend RoPE cache if needed during vLLM profiling stage. + + Args: + rope_cache: Current RoPE cache tensor + max_position: Maximum position index needed + + Returns: + Extended RoPE cache if needed, otherwise original cache + """ + from torch.distributed._tensor import DTensor, Replicate + + required_len = max_position + 1 + + # No extension needed + if required_len <= rope_cache.shape[0]: + return rope_cache + + # If no extension function provided, return original cache + if self.rope_cache_extension_fn is None: + logger.warning( + f"RoPE cache extension needed (required_len={required_len}, " + f"current_len={rope_cache.shape[0]}) but no rope_cache_extension_fn provided. " + "Returning original cache." + ) + return rope_cache + + # Handle DTensor case + is_dtensor = isinstance(rope_cache, DTensor) + if is_dtensor: + device_mesh = rope_cache.device_mesh + local_rope_cache = rope_cache.to_local() + device = local_rope_cache.device + dtype = local_rope_cache.dtype + else: + device = rope_cache.device + dtype = rope_cache.dtype + + # Use provided extension function + try: + extended_cache = self.rope_cache_extension_fn(self.config, required_len) + extended_cache = extended_cache.to(device=device, dtype=dtype) + except Exception as e: + logger.warning( + f"Failed to extend RoPE cache using rope_cache_extension_fn: {e}. " + "Returning original cache." + ) + return rope_cache + + # Convert back to DTensor if needed + if is_dtensor: + rope_cache = DTensor.from_local( + extended_cache, + device_mesh=device_mesh, + placements=[Replicate()], + ) + else: + rope_cache = extended_cache + + return rope_cache + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings.""" + return self.model.tok_embeddings(input_ids) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings (deprecated vLLM interface).""" + return self.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass with vLLM interface. + + Args: + input_ids: Token IDs [total_tokens] (1D varlen format) + positions: Position indices [total_tokens] (1D varlen format) + inputs_embeds: Pre-computed embeddings (optional) + **kwargs: Additional vLLM kwargs + + Returns: + hidden_states: Final hidden states [total_tokens, hidden_size] + """ + if inputs_embeds is not None: + raise NotImplementedError("inputs_embeds not yet supported") + + if input_ids is None: + raise ValueError("Either input_ids or inputs_embeds must be provided") + + # Convert vLLM interface to TorchTitan interface + # vLLM: [total_tokens] → TorchTitan: [batch_size, seq_len] + tokens_2d = input_ids.unsqueeze(0) + + # Get embeddings + h = self.model.tok_embeddings(tokens_2d) + + # Get RoPE cache (handle model-specific attribute names) + # Use hasattr to avoid ambiguous boolean value error with tensors + if hasattr(self.model, "rope_cache"): + rope_attr = self.model.rope_cache + elif hasattr(self.model, "freqs_cis"): + rope_attr = self.model.freqs_cis + else: + rope_attr = None + + # Extend RoPE cache if needed (vLLM profiling may use 2x max_seq_len) + if positions is not None: + max_position = positions.max().item() + else: + max_position = 0 + + rope_cache = self._extend_rope_cache_if_needed(rope_attr, max_position) + positions = positions.unsqueeze(0) + + # Pass through transformer layers + for layer in self.model.layers.values(): + h = layer(h, rope_cache, attention_masks=None, positions=positions) + + # Convert to vLLM format: [total_tokens, hidden_size] + if h.dim() == 3: + batch_size, seq_len, hidden_size = h.shape + h = h.view(batch_size * seq_len, hidden_size) + + return h + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata=None, + ) -> torch.Tensor | None: + """Compute logits from hidden states.""" + h = self.model.norm(hidden_states) + logits = self.model.output(h) + + return logits + + def load_weights(self, weights_iter): + """ + Load weights from HF checkpoint using the provided state dict adapter. + vLLM engine would call this function to load model weights. + + Args: + weights_iter: Iterator of (name, tensor) pairs from HF checkpoint + + Returns: + Set of loaded parameter names + """ + # Collect weights from iterator + hf_state_dict = {} + for name, tensor in weights_iter: + hf_state_dict[name] = tensor + + # Use adapter to convert HF → TorchTitan format + adapter = self.state_dict_adapter( + model_args=self.config, + hf_assets_path=None, + ) + + torchtitan_state_dict = adapter.from_hf(hf_state_dict) + model_state_dict = {k: v for k, v in self.model.state_dict().items()} + + # Convert to DTensor if target is DTensor + for name, tensor in torchtitan_state_dict.items(): + if name in model_state_dict and isinstance(model_state_dict[name], DTensor): + target_dtensor = model_state_dict[name] + device_mesh = target_dtensor.device_mesh + torchtitan_state_dict[name] = DTensor.from_local( + tensor.to(device_mesh.device_type), + device_mesh=device_mesh, + placements=[Replicate()], + ) + + # Load state dict + set_model_state_dict( + model=self.model, + model_state_dict=torchtitan_state_dict, + options=StateDictOptions(strict=False), + ) + + loaded_params = {f"model.{name}" for name in torchtitan_state_dict.keys()} + + return loaded_params diff --git a/torchtitan/experiments/deterministic_vllm_rl/README.md b/torchtitan/experiments/rl/vllm_compat/README.md similarity index 97% rename from torchtitan/experiments/deterministic_vllm_rl/README.md rename to torchtitan/experiments/rl/vllm_compat/README.md index d2ef719c0d..bf56f4afbe 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/README.md +++ b/torchtitan/experiments/rl/vllm_compat/README.md @@ -77,7 +77,7 @@ init_batch_invariance() ```python import torch from vllm.model_executor.layers.batch_invariant import init_batch_invariance -from torchtitan.experiments.deterministic_vllm_rl import ( +from torchtitan.experiments.rl.vllm_compat import ( enable_batch_invariant_backward_mode, Qwen3VLLMCompatModel, ) @@ -111,7 +111,7 @@ loss.backward() Run the RL training loop: ```bash -VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.deterministic_vllm_rl.simple_rl +VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.rl.vllm_compat.simple_rl ``` This will: @@ -177,7 +177,7 @@ assert torch.equal(vllm_logprobs, titan_logprobs) Run the test suite: ```bash -cd torchtitan/experiments/deterministic_vllm_rl/tests +cd torchtitan/experiments/rl/vllm_compat/tests # Test backward passes python test_batch_invariant_backward.py @@ -214,7 +214,7 @@ This implementation uses the same kernels for both rollouts (vLLM) and training ## Project Structure ``` -deterministic_vllm_rl/ +rl/vllm_compat/ ├── README.md # Documentation ├── __init__.py # Package initialization ├── batch_invariant_backward.py # Backward passes for vLLM ops diff --git a/torchtitan/experiments/deterministic_vllm_rl/__init__.py b/torchtitan/experiments/rl/vllm_compat/__init__.py similarity index 53% rename from torchtitan/experiments/deterministic_vllm_rl/__init__.py rename to torchtitan/experiments/rl/vllm_compat/__init__.py index 067555251f..b86721fba5 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/__init__.py +++ b/torchtitan/experiments/rl/vllm_compat/__init__.py @@ -5,16 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -Deterministic RL training with vLLM experiment. +vLLM-Compatible approach for deterministic RL training. -This experiment provides tools for bitwise-deterministic reinforcement learning -training using vLLM for fast rollouts and TorchTitan for training. - -Key components: -- VLLMCompatibleFlashAttention: Flash attention with custom backward pass -- Qwen3VLLMCompatModel: vLLM-compatible model with merged projections -- batch_invariant_backward: Gradient support for vLLM's deterministic operations -- simple_rl: End-to-end RL training loop +This module provides models that match vLLM's weight format (e.g., merged gate_up_proj) +with custom backward passes for gradient computation during training. """ from .batch_invariant_backward import ( @@ -22,9 +16,10 @@ rms_norm_with_gradients, silu_and_mul_with_gradients, ) -from .models import VLLMCompatibleFlashAttention +from .models.attention import VLLMCompatibleFlashAttention from .models.qwen3 import Qwen3VLLMCompatModel + __all__ = [ "VLLMCompatibleFlashAttention", "Qwen3VLLMCompatModel", diff --git a/torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py b/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py rename to torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py b/torchtitan/experiments/rl/vllm_compat/models/__init__.py similarity index 74% rename from torchtitan/experiments/deterministic_vllm_rl/models/__init__.py rename to torchtitan/experiments/rl/vllm_compat/models/__init__.py index c8c11a170a..2e7a5fa6af 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py +++ b/torchtitan/experiments/rl/vllm_compat/models/__init__.py @@ -6,8 +6,13 @@ """ Models for deterministic vLLM RL training. + +This module provides vLLM-compatible model components. """ from .attention import VLLMCompatibleFlashAttention -__all__ = ["VLLMCompatibleFlashAttention"] + +__all__ = [ + "VLLMCompatibleFlashAttention", +] diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py b/torchtitan/experiments/rl/vllm_compat/models/attention.py similarity index 98% rename from torchtitan/experiments/deterministic_vllm_rl/models/attention.py rename to torchtitan/experiments/rl/vllm_compat/models/attention.py index 33dd5a140d..11e6d3af67 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py +++ b/torchtitan/experiments/rl/vllm_compat/models/attention.py @@ -4,12 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -vLLM-compatible Flash Attention implementation for deterministic RL training. -""" import torch -from vllm.vllm_flash_attn import flash_attn_varlen_func +from vllm.attention.utils.fa_utils import flash_attn_varlen_func class VLLMCompatibleFlashAttention(torch.nn.Module): diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py b/torchtitan/experiments/rl/vllm_compat/models/qwen3/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py rename to torchtitan/experiments/rl/vllm_compat/models/qwen3/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py b/torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py similarity index 99% rename from torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py rename to torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py index dd84665091..2c9742b1fa 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py +++ b/torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py @@ -13,7 +13,7 @@ from torchtitan.components.tokenizer import BaseTokenizer # Import gradient-enabled operations from experiment utilities -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.rl.vllm_compat.batch_invariant_backward import ( rms_norm_with_gradients, silu_and_mul_with_gradients, ) diff --git a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py b/torchtitan/experiments/rl/vllm_compat/simple_rl.py similarity index 99% rename from torchtitan/experiments/deterministic_vllm_rl/simple_rl.py rename to torchtitan/experiments/rl/vllm_compat/simple_rl.py index ffc7d52eb0..508868c0d4 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py +++ b/torchtitan/experiments/rl/vllm_compat/simple_rl.py @@ -25,20 +25,20 @@ from huggingface_hub import snapshot_download from safetensors.torch import load_file, save_file from torch.utils.tensorboard import SummaryWriter -from transformers import AutoConfig, AutoTokenizer - -from vllm import LLM, SamplingParams -from vllm.model_executor.layers.batch_invariant import init_batch_invariance -from torchtitan.experiments.deterministic_vllm_rl.weights.converter import ( +from torchtitan.experiments.rl.vllm_compat.weights.converter import ( torchtitan_to_vllm, vllm_to_torchtitan, ) -from torchtitan.experiments.deterministic_vllm_rl.weights_vllm_compat import ( +from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( torchtitan_to_vllm_compat, ) from torchtitan.models.qwen3.model.args import Qwen3ModelArgs +from transformers import AutoConfig, AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.model_executor.layers.batch_invariant import init_batch_invariance init_batch_invariance() @@ -340,7 +340,7 @@ def load_model(checkpoint_path: str, model_path: str, use_vllm_compat: bool = Tr if use_vllm_compat: # Create and load model (using vLLM-compat for bitwise determinism) - from torchtitan.experiments.deterministic_vllm_rl.models.qwen3 import ( + from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.models.qwen3 import ( Qwen3VLLMCompatModel, ) @@ -1058,7 +1058,7 @@ def main(): print("✓ Batch invariance detected - using vLLM-compatible model") # Add backward pass support to vLLM's batch_invariant mode print(" Adding gradient support to vLLM's batch_invariant mode...") - from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( + from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( enable_batch_invariant_backward_mode, ) diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/__init__.py b/torchtitan/experiments/rl/vllm_compat/tests/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/tests/__init__.py rename to torchtitan/experiments/rl/vllm_compat/tests/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py b/torchtitan/experiments/rl/vllm_compat/tests/test_batch_invariant_backward.py similarity index 97% rename from torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py rename to torchtitan/experiments/rl/vllm_compat/tests/test_batch_invariant_backward.py index 3ed9604d10..ddf8b01514 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py +++ b/torchtitan/experiments/rl/vllm_compat/tests/test_batch_invariant_backward.py @@ -8,9 +8,11 @@ Test batch_invariant_backward module to ensure it works correctly. """ +import sys + import torch -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( disable_batch_invariant_backward_mode, enable_batch_invariant_backward_mode, linear_batch_invariant_backward, diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py b/torchtitan/experiments/rl/vllm_compat/tests/test_exact_determinism.py similarity index 98% rename from torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py rename to torchtitan/experiments/rl/vllm_compat/tests/test_exact_determinism.py index 8d0ac3133e..2a9863ab2f 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py +++ b/torchtitan/experiments/rl/vllm_compat/tests/test_exact_determinism.py @@ -11,11 +11,11 @@ """ import torch -from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( enable_batch_invariant_backward_mode, ) +from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode print("Enabling batch_invariant_backward mode...") disable_batch_invariant_mode() diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/README.md b/torchtitan/experiments/rl/vllm_compat/weights/README.md similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/README.md rename to torchtitan/experiments/rl/vllm_compat/weights/README.md diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py b/torchtitan/experiments/rl/vllm_compat/weights/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py rename to torchtitan/experiments/rl/vllm_compat/weights/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/converter.py b/torchtitan/experiments/rl/vllm_compat/weights/converter.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/converter.py rename to torchtitan/experiments/rl/vllm_compat/weights/converter.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py b/torchtitan/experiments/rl/vllm_compat/weights_vllm_compat.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py rename to torchtitan/experiments/rl/vllm_compat/weights_vllm_compat.py From f64bbad4e46e704fd2647921e6cb0c6f33b82fd5 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 15 Dec 2025 22:19:34 -0800 Subject: [PATCH 072/127] [RELAND] Let CUDA and ROCm read different loss result (#2157) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * __->__ #2157 CUDA and ROCm have different loss results. So we need to read from different loss result files. The loss results of FSDP and HSDP start to diverge after 5th step when running with ROCm, we also need to adjust this. But this is more an unknown issue that AMD people may want to figure out the root cause or confirm that this is an expected behavior. **This PR is a reland PR of https://github.com/pytorch/torchtitan/pull/2156** due to some landing issue of the previous PR. --- .../integration_test_8gpu_features.yaml | 20 ++++++++++++++++++- .../losses/{llama3.txt => llama3_cuda.txt} | 0 tests/assets/losses/llama3_rocm.txt | 5 +++++ 3 files changed, 24 insertions(+), 1 deletion(-) rename tests/assets/losses/{llama3.txt => llama3_cuda.txt} (100%) create mode 100644 tests/assets/losses/llama3_rocm.txt diff --git a/.github/workflows/integration_test_8gpu_features.yaml b/.github/workflows/integration_test_8gpu_features.yaml index e8b2fe63ea..de708f3cd5 100644 --- a/.github/workflows/integration_test_8gpu_features.yaml +++ b/.github/workflows/integration_test_8gpu_features.yaml @@ -70,7 +70,25 @@ jobs: echo "Checking FSDP8 v.s. HSDP (4, 2) accuracy parity" export baseline_options="--parallelism.data_parallel_replicate_degree=1" export test_options="--parallelism.data_parallel_replicate_degree=4" - python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --steps=10 --import-result tests/assets/losses/llama3.txt + + # Set architecture-specific parameters + if [[ "${{ matrix.gpu-arch-type }}" == "cuda" ]]; then + LOSS_FILE="tests/assets/losses/llama3_cuda.txt" + STEPS=10 + elif [[ "${{ matrix.gpu-arch-type }}" == "rocm" ]]; then + # The loss results of FSDP and HSDP start to diverge after 5th + # step when running with ROCm, we also need to adjust this. + # But this is more an unknown issue that AMD people may want to + # figure out the root cause or confirm that this is an expected + # behavior. + LOSS_FILE="tests/assets/losses/llama3_rocm.txt" + STEPS=5 + else + echo "Error: Unknown GPU architecture type: ${{ matrix.gpu-arch-type }}" + exit 1 + fi + + python3 scripts/loss_compare.py . . --baseline-options="${baseline_options}" --test-options="${test_options}" --job-dump-folder="${RUNNER_TEMP}/artifacts-to-be-uploaded/accuracy_comparison_outputs" --assert-equal --steps=${STEPS} --import-result ${LOSS_FILE} rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/* python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite features $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 diff --git a/tests/assets/losses/llama3.txt b/tests/assets/losses/llama3_cuda.txt similarity index 100% rename from tests/assets/losses/llama3.txt rename to tests/assets/losses/llama3_cuda.txt diff --git a/tests/assets/losses/llama3_rocm.txt b/tests/assets/losses/llama3_rocm.txt new file mode 100644 index 0000000000..3aa7c24a1d --- /dev/null +++ b/tests/assets/losses/llama3_rocm.txt @@ -0,0 +1,5 @@ +1 8.1376 +2 7.8409 +3 7.1815 +4 6.3509 +5 5.7090 From 183a0d2e7241747537fc8f98af056519aff33eab Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 17 Dec 2025 14:42:39 -0800 Subject: [PATCH 073/127] Use new DeviceMesh unflatten to rewrite parallel_dims (#1660) **Summary** This PR utilizes the latest APIs provided by DeviceMesh to simplify the creation of all different meshes. The design philosophy is as follow: 1. Create one world mesh with the shape as [world_size,] 2. Create all 1-D submeshes by using 1) unflattening from the world mesh, or 2) slicing and flatten from other derived meshes. 3. ParallelDims now provides an API, get_mesh() and get_optional_mesh(). which accepts str or list[str]. When the argument is str, the API directly return the corresponding 1-D submesh. If the argument is list[str], the dim names will be used to concatenate to form a n-D device mesh. The main difference between the two APIs is that the former one will raise an ValueError if the result mesh is None the later one will just return None. --- scripts/generate/test_generate.py | 18 +- tests/unit_tests/test_parallel_dims.py | 569 ++++++++++++++++++ tests/unit_tests/test_set_determinism.py | 76 ++- torchtitan/components/optimizer.py | 8 +- torchtitan/components/validate.py | 13 +- torchtitan/config/job_config.py | 14 +- torchtitan/distributed/expert_parallel.py | 8 +- torchtitan/distributed/parallel_dims.py | 365 +++++++---- torchtitan/distributed/pipeline_parallel.py | 2 +- torchtitan/distributed/utils.py | 77 ++- .../deepseek_v3/parallelize_deepseekv3.py | 33 +- .../autoparallel/llama3/parallelize_llama.py | 22 +- .../compiler_toolkit/common_utils.py | 6 +- .../compiler_toolkit/graph_utils.py | 4 +- torchtitan/experiments/forge/engine.py | 7 +- torchtitan/experiments/forge/example_train.py | 12 +- .../experiments/gpt_oss/infra/parallelize.py | 58 +- .../simple_fsdp/deepseek_v3/parallelize.py | 49 +- .../simple_fsdp/llama3/parallelize.py | 10 +- .../simple_fsdp/tests/test_numerics.py | 13 +- .../infra/parallelize.py | 18 +- .../infra/pipeline.py | 2 +- torchtitan/experiments/vlm/infra/loss.py | 2 +- .../experiments/vlm/infra/parallelize.py | 15 +- .../models/deepseek_v3/infra/parallelize.py | 50 +- torchtitan/models/flux/infra/parallelize.py | 20 +- torchtitan/models/flux/train.py | 15 +- torchtitan/models/flux/validate.py | 13 +- torchtitan/models/llama3/infra/parallelize.py | 24 +- torchtitan/models/llama4/infra/parallelize.py | 79 ++- torchtitan/models/qwen3/infra/parallelize.py | 50 +- torchtitan/train.py | 63 +- 32 files changed, 1200 insertions(+), 515 deletions(-) create mode 100644 tests/unit_tests/test_parallel_dims.py diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index bff9c2aa7f..ef310b5996 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -118,7 +118,7 @@ def test_generate( logger.info(f"Init model on init_device: {init_device}") model = train_spec.model_cls(model_args) - world_mesh = None + parallel_dims = None # Init distributed env if world_size > 1: dist_utils.init_distributed(config.comm) @@ -132,16 +132,26 @@ def test_generate( etp=1, world_size=world_size, ) - world_mesh = parallel_dims.world_mesh # apply_tp (with Sequence Parallel) on unevenly sharded # sequences would require https://github.com/pytorch/torchtitan/pull/686 # pyrefly: ignore [bad-argument-type] - apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"]) + apply_tp_minus_sp(model, parallel_dims.get_mesh("tp")) + else: + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) debug_config = DebugConfig(seed=seed, deterministic=deterministic) dist_utils.set_determinism( - world_mesh=world_mesh, + parallel_dims=parallel_dims, device=device, debug_config=debug_config, distinct_seed_mesh_dims=["pp"], diff --git a/tests/unit_tests/test_parallel_dims.py b/tests/unit_tests/test_parallel_dims.py new file mode 100644 index 0000000000..86b860065e --- /dev/null +++ b/tests/unit_tests/test_parallel_dims.py @@ -0,0 +1,569 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from unittest.mock import patch + +import torch.distributed as dist +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) +from torchtitan.distributed import ParallelDims + + +class TestParallelDimsValidation(unittest.TestCase): + """Test ParallelDims validation logic without mesh building.""" + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_basic_initialization(self): + """Test basic initialization with valid parameters.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + self.assertEqual(parallel_dims.dp_replicate, 2) + self.assertEqual(parallel_dims.dp_shard, 2) + self.assertEqual(parallel_dims.cp, 1) + self.assertEqual(parallel_dims.tp, 2) + self.assertEqual(parallel_dims.pp, 1) + self.assertEqual(parallel_dims.ep, 1) + self.assertEqual(parallel_dims.etp, 1) + self.assertEqual(parallel_dims.world_size, 8) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_auto_calculate_dp_shard(self): + """Test automatic calculation of dp_shard when set to -1.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=-1, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + self.assertEqual(parallel_dims.dp_shard, 2) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_invalid_world_size(self): + """Test validation fails when parallelism degrees don't match world_size.""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=10, # Invalid: 2*2*1*2*1 = 8, not 10 + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_invalid_etp(self): + """Test validation fails when etp is not equal to tp or 1.""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=4, + pp=1, + ep=2, + etp=2, # Invalid: etp must be tp or 1 when ep > 1 + world_size=8, + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_zero_parallelism(self): + """Test validation fails when parallelism degree is 0.""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=0, # Invalid: must be >= 1 + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_invalid_dp_shard(self): + """Test validation fails when dp_shard is invalid (not -1 and not >=1).""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=1, + dp_shard=0, # Invalid: must be -1 or >= 1 + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_enabled_properties(self): + """Test all enabled properties.""" + # Test with DP enabled + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + self.assertTrue(parallel_dims.dp_enabled) + self.assertTrue(parallel_dims.dp_replicate_enabled) + self.assertTrue(parallel_dims.dp_shard_enabled) + self.assertFalse(parallel_dims.cp_enabled) + self.assertTrue(parallel_dims.tp_enabled) + self.assertFalse(parallel_dims.pp_enabled) + self.assertFalse(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + self.assertTrue(parallel_dims.fsdp_enabled) + + # Test with CP enabled + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=2, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=2, + ) + self.assertFalse(parallel_dims.dp_enabled) + self.assertTrue(parallel_dims.cp_enabled) + self.assertTrue(parallel_dims.dp_cp_enabled) + self.assertTrue(parallel_dims.fsdp_enabled) + + # Test with EP and ETP enabled (EP * ETP must not contribute to world_size) + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=2, + cp=1, + tp=1, + pp=1, + ep=2, + etp=1, + world_size=2, + ) + self.assertTrue(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + + # Test with PP enabled + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=2, + ep=1, + etp=1, + world_size=2, + ) + self.assertTrue(parallel_dims.pp_enabled) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_fsdp_gradient_divide_factor(self): + """Test fsdp_gradient_divide_factor calculation.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=3, + cp=2, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=12, + ) + # Should be dp_replicate * dp_shard * cp = 2 * 3 * 2 = 12 + self.assertEqual(parallel_dims.fsdp_gradient_divide_factor, 12) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_non_data_parallel_size(self): + """Test non_data_parallel_size calculation.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=2, + tp=3, + pp=2, + ep=1, + etp=1, + world_size=48, + ) + # Should be cp * tp * pp = 2 * 3 * 2 = 12 + self.assertEqual(parallel_dims.non_data_parallel_size, 12) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_seq_len_divisor(self): + """Test seq_len_divisor calculation.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=1, + cp=2, + tp=4, + pp=1, + ep=1, + etp=1, + world_size=16, + ) + # Should be tp * (cp * 2) = 4 * 4 = 16 + self.assertEqual(parallel_dims.seq_len_divisor, 16) + + +class TestParallelDimsMeshOperations(unittest.TestCase): + """Test ParallelDims mesh operations with single-rank distributed environment.""" + + def setUp(self): + """Initialize distributed environment for CPU testing.""" + if not dist.is_initialized(): + dist.init_process_group( + backend="gloo", + init_method="tcp://localhost:12356", + world_size=1, + rank=0, + ) + + def tearDown(self): + """Clean up distributed environment.""" + if dist.is_initialized(): + dist.destroy_process_group() + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_get_mesh_invalid_name(self): + """Test getting mesh with invalid name raises error.""" + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + parallel_dims.build_mesh() + + with self.assertRaises(ValueError) as context: + parallel_dims.get_mesh("invalid_mesh") + self.assertIn("Invalid mesh dim", str(context.exception)) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_get_mesh_lazy_initialization(self): + """Test that get_optional_mesh triggers build_mesh if not built yet.""" + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + # Don't call build_mesh explicitly + self.assertEqual(len(parallel_dims._meshes), 0) + + # get_optional_mesh should trigger build_mesh + result = parallel_dims.get_optional_mesh("tp") + # Result is None because tp has size 1, but build_mesh should have been called + self.assertGreater(len(parallel_dims._meshes), 0) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_single_rank_mesh_operations(self): + """Comprehensive test for all single-rank mesh operations. + + This test verifies mesh building, mesh retrieval, mesh sizes, and property + access when all parallelism dimensions are set to 1 (single rank). + """ + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + + # Test mesh building + world_mesh = parallel_dims.build_mesh() + self.assertIsNotNone(world_mesh) + self.assertEqual(world_mesh.size(), 1) + + # Verify all expected meshes are created + self.assertIsNotNone(parallel_dims._meshes) + self.assertIn("pp", parallel_dims._meshes) + self.assertIn("batch", parallel_dims._meshes) + self.assertIn("loss", parallel_dims._meshes) + self.assertIn("dp_replicate", parallel_dims._meshes) + self.assertIn("fsdp", parallel_dims._meshes) + self.assertIn("cp", parallel_dims._meshes) + self.assertIn("tp", parallel_dims._meshes) + + # Validate 1D mesh sizes - all should be 1 for single rank + self.assertEqual(parallel_dims._meshes["dp_replicate"].size(), 1) + self.assertEqual(parallel_dims._meshes["fsdp"].size(), 1) + self.assertEqual(parallel_dims._meshes["tp"].size(), 1) + self.assertEqual(parallel_dims._meshes["batch"].size(), 1) + self.assertEqual(parallel_dims._meshes["loss"].size(), 1) + self.assertEqual(parallel_dims._meshes["pp"].size(), 1) + self.assertEqual(parallel_dims._meshes["cp"].size(), 1) + self.assertEqual(parallel_dims._meshes["ep"].size(), 1) + self.assertEqual(parallel_dims._meshes["etp"].size(), 1) + self.assertEqual(parallel_dims._meshes["efsdp"].size(), 1) + + # Validate 2D mesh shapes + dp_replicate_fsdp_mesh = parallel_dims.get_optional_mesh( + ["dp_replicate", "fsdp"] + ) + self.assertIsNone(dp_replicate_fsdp_mesh) # Both dimensions have size 1 + dp_replicate_efsdp_mesh = parallel_dims.get_optional_mesh( + ["dp_replicate", "efsdp"] + ) + self.assertIsNone(dp_replicate_efsdp_mesh) # Both dimensions have size 1 + ep_etp_mesh = parallel_dims.get_optional_mesh(["ep", "etp"]) + self.assertIsNone(ep_etp_mesh) # Both dimensions have size 1 + + # Test get_optional_mesh returns None when all dimensions have size 1 + self.assertIsNone(parallel_dims.get_optional_mesh("tp")) + self.assertIsNone(parallel_dims.get_optional_mesh("dp_replicate")) + self.assertIsNone(parallel_dims.get_optional_mesh("pp")) + self.assertIsNone(parallel_dims.get_optional_mesh("cp")) + self.assertIsNone(parallel_dims.get_optional_mesh("fsdp")) + + # Test get_optional_mesh with list input + self.assertIsNone(parallel_dims.get_optional_mesh(["dp_replicate", "fsdp"])) + + # Test get_all_one_dimensional_meshes returns empty when all dimensions have size 1 + one_d_meshes = parallel_dims.get_all_one_dimensional_meshes() + self.assertEqual(len(one_d_meshes), 0) + + # Test world_mesh property + world_mesh_property = parallel_dims.world_mesh + self.assertIsNotNone(world_mesh_property) + self.assertEqual(world_mesh_property.size(), 1) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_get_mesh_with_list_input(self): + """Test get_optional_mesh accepts both string and list inputs.""" + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + parallel_dims.build_mesh() + + # Should accept list input + result = parallel_dims.get_optional_mesh(["dp_replicate", "fsdp"]) + # Returns None because both dimensions have size 1 + self.assertIsNone(result) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_expert_parallelism_validation(self): + """Test expert parallelism configurations.""" + # EP with ETP = 1 (valid) - world_size = dp_replicate * dp_shard * cp * tp * pp + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=2, + cp=1, + tp=1, + pp=1, + ep=2, + etp=1, + world_size=2, # 1 * 2 * 1 * 1 * 1 = 2 + ) + self.assertTrue(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + + # Test with larger configuration + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=1, + pp=1, + ep=3, + etp=1, + world_size=4, # 2 * 2 * 1 * 1 * 1 = 4 + ) + self.assertTrue(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + self.assertTrue(parallel_dims.dp_replicate_enabled) + self.assertTrue(parallel_dims.dp_shard_enabled) + + +class TestParallelDimsWorld8MeshOperations(DTensorTestBase): + """Test ParallelDims mesh operations with 8-rank distributed environment.""" + + @property + def world_size(self): + return 8 + + @with_comms + def test_world_size_8_mesh_operations(self): + """Comprehensive test for world_size=8 mesh operations. + + This test validates mesh building, mesh retrieval, mesh sizes, and properties + for a world_size=8 configuration with multiple parallelism dimensions enabled. + Configuration: dp_replicate=2, dp_shard=2, cp=1, tp=2, pp=1 (2*2*1*2*1 = 8) + """ + with patch( + "torchtitan.distributed.parallel_dims.device_type", self.device_type + ): + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + + # Test mesh building + world_mesh = parallel_dims.build_mesh() + self.assertIsNotNone(world_mesh) + self.assertEqual(world_mesh.size(), 8) + + # Verify all expected meshes are created + self.assertIsNotNone(parallel_dims._meshes) + self.assertIn("pp", parallel_dims._meshes) + self.assertIn("batch", parallel_dims._meshes) + self.assertIn("loss", parallel_dims._meshes) + self.assertIn("dp_replicate", parallel_dims._meshes) + self.assertIn("fsdp", parallel_dims._meshes) + self.assertIn("cp", parallel_dims._meshes) + self.assertIn("tp", parallel_dims._meshes) + self.assertIn("ep", parallel_dims._meshes) + self.assertIn("etp", parallel_dims._meshes) + self.assertIn("efsdp", parallel_dims._meshes) + + # Validate 1D mesh sizes match parallelism configuration + self.assertEqual(parallel_dims._meshes["pp"].size(), 1) + self.assertEqual( + parallel_dims._meshes["batch"].size(), 4 + ) # dp_replicate * dp_shard = 2 * 2 + self.assertEqual( + parallel_dims._meshes["loss"].size(), 4 + ) # dp_replicate * dp_shard * cp = 2 * 2 * 1 + self.assertEqual(parallel_dims._meshes["dp_replicate"].size(), 2) + self.assertEqual( + parallel_dims._meshes["fsdp"].size(), 2 + ) # dp_shard * cp = 2 * 1 + self.assertEqual(parallel_dims._meshes["cp"].size(), 1) + self.assertEqual(parallel_dims._meshes["tp"].size(), 2) + self.assertEqual(parallel_dims._meshes["ep"].size(), 1) + self.assertEqual(parallel_dims._meshes["etp"].size(), 1) + self.assertEqual( + parallel_dims._meshes["efsdp"].size(), 4 + ) # fsdp * tp / (etp * ep) = 2 * 2 / (1 * 1) = 4 + + # Validate 2D mesh shapes + dp_replicate_fsdp_mesh = parallel_dims.get_mesh(["dp_replicate", "fsdp"]) + self.assertIsNotNone(dp_replicate_fsdp_mesh) + self.assertEqual( + dp_replicate_fsdp_mesh.shape, (2, 2) + ) # (dp_replicate, fsdp) + # efsdp mesh only exists when ep > 1, so dp_replicate_efsdp should be None when ep=1 + dp_replicate_efsdp_mesh = parallel_dims.get_optional_mesh( + ["dp_replicate", "efsdp"] + ) + self.assertIsNone(dp_replicate_efsdp_mesh) # efsdp disabled when ep=1 + ep_etp_mesh = parallel_dims.get_optional_mesh(["ep", "etp"]) + self.assertIsNone(ep_etp_mesh) # Both dimensions have size 1 + + # Test get_mesh returns valid meshes for enabled dimensions (size > 1) + self.assertIsNotNone(parallel_dims.get_mesh("tp")) + self.assertIsNotNone(parallel_dims.get_mesh("dp_replicate")) + self.assertIsNotNone(parallel_dims.get_mesh("fsdp")) + self.assertIsNotNone(parallel_dims.get_mesh("batch")) + self.assertIsNotNone(parallel_dims.get_mesh("loss")) + + # Test get_optional_mesh returns None for disabled dimensions (size = 1) + self.assertIsNone(parallel_dims.get_optional_mesh("pp")) + self.assertIsNone(parallel_dims.get_optional_mesh("cp")) + self.assertIsNone(parallel_dims.get_optional_mesh("ep")) + + # Test get_mesh with 2D mesh names + self.assertIsNotNone(parallel_dims.get_mesh(["dp_replicate", "fsdp"])) + hsdp_mesh = parallel_dims.get_mesh(["dp_replicate", "fsdp"]) + self.assertEqual(hsdp_mesh.shape, (2, 2)) + + # Test get_all_one_dimensional_meshes returns only meshes with size > 1 + one_d_meshes = parallel_dims.get_all_one_dimensional_meshes() + self.assertGreater(len(one_d_meshes), 0) + # Should include: dp_replicate, fsdp, tp, batch, loss, efsdp (all with size > 1) + self.assertIn("dp_replicate", one_d_meshes) + self.assertIn("fsdp", one_d_meshes) + self.assertIn("tp", one_d_meshes) + self.assertIn("batch", one_d_meshes) + self.assertIn("loss", one_d_meshes) + self.assertIn("efsdp", one_d_meshes) + # Should not include: pp, cp, ep, etp (all with size = 1) + self.assertNotIn("pp", one_d_meshes) + self.assertNotIn("cp", one_d_meshes) + self.assertNotIn("ep", one_d_meshes) + self.assertNotIn("etp", one_d_meshes) + + # Test that we can get 2D meshes via get_mesh() instead + dp_replicate_fsdp = parallel_dims.get_mesh(["dp_replicate", "fsdp"]) + self.assertIsNotNone(dp_replicate_fsdp) + self.assertEqual(dp_replicate_fsdp.ndim, 2) + + # Test world_mesh property + world_mesh_property = parallel_dims.world_mesh + self.assertIsNotNone(world_mesh_property) + self.assertEqual(world_mesh_property.size(), 8) + + # Validate enabled properties + self.assertTrue(parallel_dims.dp_enabled) + self.assertTrue(parallel_dims.dp_replicate_enabled) + self.assertTrue(parallel_dims.dp_shard_enabled) + self.assertTrue(parallel_dims.fsdp_enabled) + self.assertTrue(parallel_dims.tp_enabled) + self.assertFalse(parallel_dims.cp_enabled) + self.assertFalse(parallel_dims.pp_enabled) + self.assertFalse(parallel_dims.ep_enabled) + + # Validate calculated properties + self.assertEqual( + parallel_dims.fsdp_gradient_divide_factor, 4 + ) # dp_replicate * dp_shard * cp = 2 * 2 * 1 + self.assertEqual( + parallel_dims.non_data_parallel_size, 2 + ) # cp * tp * pp = 1 * 2 * 1 + self.assertEqual( + parallel_dims.seq_len_divisor, 4 + ) # tp * (cp * 2) = 2 * (1 * 2) = 2 * 2 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/test_set_determinism.py b/tests/unit_tests/test_set_determinism.py index c8087731c5..2be196b7e1 100644 --- a/tests/unit_tests/test_set_determinism.py +++ b/tests/unit_tests/test_set_determinism.py @@ -13,8 +13,8 @@ from torchtitan.distributed.utils import set_determinism -class FakeDeviceMesh: - """Fake DeviceMesh for testing seed uniqueness. +class FakeParallelDims: + """Fake ParallelDims for testing seed uniqueness. Args: mesh_dim_names: List of dimension names (e.g., ["dp", "pp", "tp"]) @@ -26,25 +26,68 @@ def __init__(self, mesh_dim_names, mesh_sizes, rank_coords): self.mesh_dim_names = mesh_dim_names self.mesh_sizes = dict(zip(mesh_dim_names, mesh_sizes)) self.rank_coords = dict(zip(mesh_dim_names, rank_coords)) - - def __getitem__(self, key): - """Return a submesh for the given dimension(s).""" + # Calculate world_size as product of all mesh sizes + self.world_size = 1 + for size in mesh_sizes: + self.world_size *= size + + # Add individual parallelism degree attributes to match real ParallelDims interface + self.pp = self.mesh_sizes.get("pp", 1) + self.tp = self.mesh_sizes.get("tp", 1) + self.cp = self.mesh_sizes.get("cp", 1) + self.dp_replicate = self.mesh_sizes.get("dp_replicate", 1) + self.dp_shard = self.mesh_sizes.get("dp_shard", 1) + self.ep = self.mesh_sizes.get("ep", 1) + self.etp = self.mesh_sizes.get("etp", 1) + + # For backward compatibility with 'dp' dimension name + if "dp" in self.mesh_sizes: + self.dp_replicate = self.mesh_sizes["dp"] + + # Create a world_mesh mock + self.world_mesh = MagicMock() + self.world_mesh.device_type = "cpu" + + def get_mesh(self, key): + """Return a submesh for the given dimension.""" if isinstance(key, str): # Single dimension + if key not in self.mesh_dim_names: + return None submesh = MagicMock() submesh.get_local_rank.return_value = self.rank_coords[key] submesh.size.return_value = self.mesh_sizes[key] submesh.get_coordinate.return_value = self.rank_coords[key] + submesh.device_type = "cpu" return submesh elif isinstance(key, list): - # Multiple dimensions + # Multiple dimensions - check if all exist + if not all(dim in self.mesh_dim_names for dim in key): + return None submesh = MagicMock() # For multiple dimensions, get_coordinate should return None # since we're not testing this path submesh.get_coordinate.return_value = None + submesh.device_type = "cpu" return submesh else: - raise ValueError(f"Unsupported key type: {type(key)}") + return None + + def get_optional_mesh(self, key): + """Return a submesh for the given dimension, or None if not available. + + This is the same as get_mesh() for FakeParallelDims since get_mesh() + already returns None for unavailable meshes. + """ + return self.get_mesh(key) + + def get_all_meshes(self): + """Return a dict of all meshes.""" + return {dim: self.get_mesh(dim) for dim in self.mesh_dim_names} + + def __getitem__(self, key): + """Return a submesh for the given dimension(s) - for backward compatibility.""" + return self.get_mesh(key) def get_coordinate(self): """Return the coordinate tuple for this rank.""" @@ -85,12 +128,12 @@ def test_seed_uniqueness_2d_mesh(self, mock_get_rank, mock_get_world_size): # Create fake mesh for this rank rank_coords = (dp_rank, pp_rank) - fake_mesh = FakeDeviceMesh(mesh_dim_names, mesh_sizes, rank_coords) + fake_mesh = FakeParallelDims(mesh_dim_names, mesh_sizes, rank_coords) # Call set_determinism with distinct seeds only on PP dimension debug_config = DebugConfig(seed=base_seed, deterministic=False) set_determinism( - world_mesh=fake_mesh, + parallel_dims=fake_mesh, device=self.device, debug_config=debug_config, distinct_seed_mesh_dims=["pp"], @@ -154,12 +197,14 @@ def test_seed_uniqueness_3d_mesh(self, mock_get_rank, mock_get_world_size): # Create fake mesh for this rank rank_coords = (dp_shard_rank, dp_replicate_rank, tp_rank) - fake_mesh = FakeDeviceMesh(mesh_dim_names, mesh_sizes, rank_coords) + fake_mesh = FakeParallelDims( + mesh_dim_names, mesh_sizes, rank_coords + ) # Call set_determinism with distinct seeds on dp_shard and dp_replicate only debug_config = DebugConfig(seed=base_seed, deterministic=False) set_determinism( - world_mesh=fake_mesh, + parallel_dims=fake_mesh, device=self.device, debug_config=debug_config, distinct_seed_mesh_dims=["dp_shard", "dp_replicate"], @@ -218,12 +263,15 @@ def test_set_determinism_single_gpu(self, mock_get_rank, mock_get_world_size): base_seed = 42 fake_mesh = MagicMock() - fake_mesh.mesh_dim_names = None - fake_mesh.get_coordinate.return_value = None + fake_mesh.world_size = 1 + fake_mesh.world_mesh = MagicMock() + fake_mesh.get_mesh.return_value = None + fake_mesh.get_optional_mesh.return_value = None + fake_mesh.get_all_meshes.return_value = {} debug_config = DebugConfig(seed=base_seed, deterministic=False) set_determinism( - world_mesh=fake_mesh, + parallel_dims=fake_mesh, device=self.device, debug_config=debug_config, distinct_seed_mesh_dims=["pp"], diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 2b08142f97..5ddc80852a 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -364,9 +364,7 @@ def _update_expert_bias( model_parts: list[nn.Module], parallel_dims: ParallelDims, ): - dp_cp_mesh = ( - parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None - ) + loss_mesh = parallel_dims.get_optional_mesh("loss") # TODO: Currently this sync is blocking (thus exposed) and happens on the # default compute stream. Need to assess if this is OK performance-wise. tokens_per_expert_list = [] @@ -391,7 +389,7 @@ def _update_expert_bias( tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list) - if dp_cp_mesh is not None: + if loss_mesh is not None: if isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor): tokens_per_expert_by_layer = tokens_per_expert_by_layer.redistribute( placements=[Replicate()] @@ -399,7 +397,7 @@ def _update_expert_bias( ) else: # Perform single all-reduce to get global statistics across all processes - pg = dp_cp_mesh.get_group() + pg = loss_mesh.get_group() torch.distributed.all_reduce( tokens_per_expert_by_layer, group=pg, diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 4673807347..3beae2e216 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -120,17 +120,16 @@ def validate( inputs = input_dict["input"] labels = labels.to(device_type) - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], + optional_context_parallel_ctx = None + if parallel_dims.cp_enabled: + cp_mesh = parallel_dims.get_mesh("cp") + optional_context_parallel_ctx = dist_utils.create_context_parallel_ctx( + cp_mesh=cp_mesh, cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], cp_seq_dims=[1, 1] + [0 for _ in model_parts], cp_no_restore_buffers={inputs, labels}, cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, ) - if parallel_dims.cp_enabled - else None - ) if parallel_dims.pp_enabled: assert self.pp_schedule is not None @@ -176,7 +175,7 @@ def validate( loss /= num_steps if parallel_dims.dp_cp_enabled: global_avg_loss = dist_utils.dist_mean( - loss, parallel_dims.world_mesh["dp_cp"] + loss, parallel_dims.get_optional_mesh("loss") ) else: global_avg_loss = loss.item() diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 5ef99f6934..4c2333f30d 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -396,19 +396,7 @@ class Parallelism: """ Expert parallelism degree. 1 means disabled. No effect for non-MoE models. - Currently, it is supported with the following constraints: - - - when etp = tp: - - - cp <= ep <= dp_shard * cp - - ep % cp == 0 - - dp_shard * cp % ep == 0 - - - when etp = 1: - - - cp * tp <= ep <= dp_shard * cp * tp - - ep % (cp * tp) == 0 - - dp_shard * cp * tp % ep == 0 + Currently, etp is either 1 or is the same as tp. Note that this is still an experimental feature. Some constraints will be relaxed soon when we have more flexible DeviceMesh support. diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index 60de27b276..ca5cdd1d54 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -200,8 +200,14 @@ def _token_dispatch(self, mod, inputs, device_mesh): # NOTE: Currently in MoE TP, experts multiplication runs in plain Tensors. # The grad_placements on inputs is set to Partial so that necessary # reductions are performed during backward. + + # NOTE: The mesh used here should be dense_mesh["tp"] as routed_input is + # technically wrapped with the dense_mesh["tp"] but this complicates + # the interface of ExpertTensorParallel and it doesn't matter as etp + # is almost always the same as tp or is 1. To avoid the complexity, + # we use the etp mesh here. routed_input = DTensor.from_local( - routed_input, device_mesh["tp"], (Replicate(),) + routed_input, device_mesh["etp"], (Replicate(),) ).to_local(grad_placements=(Partial(),)) inputs = (routed_input, num_tokens_per_expert) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 187a363097..86173ba78a 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass +from dataclasses import dataclass, field from torch.distributed.device_mesh import DeviceMesh, init_device_mesh @@ -26,6 +26,7 @@ class ParallelDims: etp: int world_size: int + _meshes: dict[str, DeviceMesh] = field(default_factory=dict) _world_mesh: DeviceMesh | None = None def __post_init__(self): @@ -56,143 +57,253 @@ def _validate(self): if ep > 1: assert etp == tp or etp == 1, "Currently we only support ETP=TP or ETP=1" - if etp == tp: - # EP would borrow all cp and some dp_shard degree - assert ep % cp == 0 and (dp_shard * cp) % ep == 0 - elif etp == 1: - # EP would borrow all cp and tp and some dp_shard degree - assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0 + + def _mesh_exist(self, name: str, degree: int) -> bool: + if name == "efsdp": + # We always keep the efsdp if EP is larger than 1 because we need + # FSDP wrapping to help the MoE layers do mixed precision training. + return True if self.ep > 1 else False + return degree > 1 def build_mesh(self) -> DeviceMesh: - # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel - # is not very clean, due to the limited support from DeviceMesh - # for creating two staggered meshes. Will improve. - if self.ep > 1: - return self._build_mesh_with_ep() - else: - return self._build_mesh_without_ep() - - def _build_mesh_with_ep(self) -> DeviceMesh: - # With ep, dp_shard and ep are derived submeshes: - # dp_shard = dp_shard_mod_ep * dp_shard_in_ep - if self.etp == self.tp: - # ep = dp_shard_in_ep * cp - dp_shard_mod_ep = self.dp_shard * self.cp // self.ep - dp_shard_in_ep = self.ep // self.cp - else: - assert self.etp == 1 - # ep = dp_shard_in_ep * cp * tp - dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep - dp_shard_in_ep = self.ep // (self.cp * self.tp) - - dims = [] - names = [] - for d, name in zip( - [ - self.pp, - self.dp_replicate, - dp_shard_mod_ep, - dp_shard_in_ep, - self.cp, - self.tp, - ], - ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], + """ + Build the device mesh with the required mesh dimensions. + + The following mesh dimensions will be created: + + pp: Pipeline Parallelism (PP). + batch: Used by data loading to determine the global batch size and which + part of the data each rank should read. This dimension includes both + ``dp_replicate`` and ``dp_shard``. The backend is set to ``fake`` for + this dimension to avoid unnecessary process group creation. + loss: Used by all-reduce when computing the loss. Includes ``dp_replicate``, + ``dp_shard``, and ``cp`` degrees, as all of them parallelize the data, + essentially require the weight gradients reduction. + dp_replicate: For DDP or HSDP replicate dimension. + fsdp: For FSDP dimension. This includes ``dp_shard`` and ``cp``. Note that + we always assume that when ``cp`` is used, FSDP is also applied to + utilize its weight all-gather and gradients reduce_scatter even if + there may be no data parallelism (e.g., global batch size is 1). + cp: Context Parallelism (CP). + tp: Tensor Parallelism (TP). + ep: Expert Parallelism (EP). + efsdp: FSDP in the EP region. + etp: TP in the EP region. + + Note: Most dimensions above are created by unflattening the world mesh, except for loss, + which is created by flattening the batch and cp dimensions. + This API performs the following unflatten operations from the world mesh: + + ["pp", "batch", "cp", "tp"] # dataloading_mesh + ["pp", "dp_replicate", "fsdp", "tp"] # dense_mesh + ["pp", "dp_replicate", "efsdp", "ep", "etp"] # sparse_mesh + + Note: DeviceMesh currently recreates the process group for each dimension. + It should share the process group for the same dim group to avoid unnecessary + process group creation. We can also use Fake to achieve a similar goal. + However, using Fake to avoid redundancy messing up the code. We only use Fake + when it is necessary. For now, we just let DeviceMesh create redundant process + group and wait for DeviceMesh to fix the issue. + """ + + def unflatten_mesh( + world_mesh: DeviceMesh, + dim_names: tuple[str, ...], + dim_degrees: tuple[int, ...], ): - # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping - # helps the MoE layers do mixed precision training - if d > 1 or name == "dp_shard_mod_ep": - dims.append(d) - names.append(name) - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names)) - - # Create all the submesh here to ensure all required process groups are - # initialized: - # Mesh for data loading (no communication on this mesh) - dp_mesh_dim_names = [] - # Mesh for param sharding - dp_shard_cp_mesh_dim_names = [] - # Mesh for loss all-reduce - dp_cp_mesh_dim_names = [] - # Mesh for ep - ep_mesh_dim_names = [] - - if self.dp_replicate_enabled: - dp_mesh_dim_names.append("dp_replicate") - dp_cp_mesh_dim_names.append("dp_replicate") - # dp_shard_mod_ep is always needed, even if it's 1 - dp_mesh_dim_names.append("dp_shard_mod_ep") - dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep") - dp_cp_mesh_dim_names.append("dp_shard_mod_ep") - if "dp_shard_in_ep" in names: - dp_mesh_dim_names.append("dp_shard_in_ep") - dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep") - dp_cp_mesh_dim_names.append("dp_shard_in_ep") - ep_mesh_dim_names.append("dp_shard_in_ep") - if self.cp_enabled: - dp_shard_cp_mesh_dim_names.append("cp") - dp_cp_mesh_dim_names.append("cp") - ep_mesh_dim_names.append("cp") - if self.etp == 1 and self.tp_enabled: - ep_mesh_dim_names.append("tp") - - mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") - mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") - mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep") + """Unflatten the world mesh to create the required mesh dimensions. + + Uses fake backend for dimensions with degree 1 or for 'batch' dimension + to avoid unnecessary process group creation. + """ + backend_override = {} + for name, degree in zip(dim_names, dim_degrees, strict=True): + if (not self._mesh_exist(name, degree)) or name == "batch": + backend_override[name] = "fake" + + return world_mesh._unflatten( + 0, dim_degrees, dim_names, backend_override=backend_override + ) - return mesh + logger.info( + f"Building device mesh with parallelism: " + f"pp={self.pp}, dp_replicate={self.dp_replicate}, dp_shard={self.dp_shard}, " + f"cp={self.cp}, tp={self.tp}, ep={self.ep}, etp={self.etp}" + ) - def _build_mesh_without_ep(self) -> DeviceMesh: - dims = [] - names = [] - for d, name in zip( - [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], - ["pp", "dp_replicate", "dp_shard", "cp", "tp"], - ): - if d > 1: - dims.append(d) - names.append(name) - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names)) - - # Create all the submesh here to ensure all required process groups are - # initialized: - # Mesh for data loading (no communication on this mesh) - dp_mesh_dim_names = [] - # Mesh for param sharding - dp_shard_cp_mesh_dim_names = [] - # Mesh for loss all-reduce - dp_cp_mesh_dim_names = [] - - if self.dp_replicate_enabled: - dp_mesh_dim_names.append("dp_replicate") - dp_cp_mesh_dim_names.append("dp_replicate") - if self.dp_shard_enabled: - dp_mesh_dim_names.append("dp_shard") - dp_shard_cp_mesh_dim_names.append("dp_shard") - dp_cp_mesh_dim_names.append("dp_shard") - if self.cp_enabled: - dp_shard_cp_mesh_dim_names.append("cp") - dp_cp_mesh_dim_names.append("cp") - - if dp_mesh_dim_names != []: - mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - if dp_shard_cp_mesh_dim_names != []: - mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( - mesh_dim_name="dp_shard_cp" + batch = self.dp_replicate * self.dp_shard + fsdp = self.dp_shard * self.cp + efsdp = fsdp * self.tp // (self.etp * self.ep) + + self._world_mesh = init_device_mesh( + device_type, (self.world_size,), mesh_dim_names=("world",) + ) + dataloading_mesh = unflatten_mesh( + self._world_mesh, + ("pp", "batch", "cp", "tp"), + (self.pp, batch, self.cp, self.tp), + ) + loss_mesh = dataloading_mesh["batch", "cp"]._flatten("loss_mesh") + dense_mesh = unflatten_mesh( + self._world_mesh, + ("pp", "dp_replicate", "fsdp", "tp"), + (self.pp, self.dp_replicate, fsdp, self.tp), + ) + sparse_mesh = unflatten_mesh( + self._world_mesh, + ("pp", "dp_replicate", "efsdp", "ep", "etp"), + (self.pp, self.dp_replicate, efsdp, self.ep, self.etp), + ) + + self._global_meshes = { + "dataloading": dataloading_mesh, + "loss": loss_mesh, + "dense": dense_mesh, + "sparse": sparse_mesh, + } + + self._meshes = { + "pp": dataloading_mesh["pp"], + "batch": dataloading_mesh["batch"], + "loss": loss_mesh, + "dp_replicate": dense_mesh["dp_replicate"], + "fsdp": dense_mesh["fsdp"], + "cp": dataloading_mesh["cp"], + "tp": dataloading_mesh["tp"], + "ep": sparse_mesh["ep"], + "efsdp": sparse_mesh["efsdp"], + "etp": sparse_mesh["etp"], + } + + # Validate mesh sizes + self._validate_meshes() + + logger.info( + f"Successfully created meshes with active dimensions: " + f"{list(self.get_all_one_dimensional_meshes().keys())}" + ) + + return self._world_mesh + + def _validate_meshes(self): + """Validate that created meshes have the expected sizes.""" + expected_sizes = { + "pp": self.pp, + "batch": self.dp_replicate * self.dp_shard, + "loss": self.dp_replicate * self.dp_shard * self.cp, + "dp_replicate": self.dp_replicate, + "fsdp": self.dp_shard * self.cp, + "cp": self.cp, + "tp": self.tp, + "ep": self.ep, + "efsdp": self.dp_shard * self.cp * self.tp // (self.etp * self.ep), + "etp": self.etp, + } + + for mesh_name, expected_size in expected_sizes.items(): + actual_size = self._meshes[mesh_name].size() + assert actual_size == expected_size, ( + f"Mesh '{mesh_name}' has unexpected size: " + f"expected {expected_size}, got {actual_size}" ) - if dp_cp_mesh_dim_names != []: - mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") + def get_optional_mesh(self, dims: str | list[str]) -> DeviceMesh | None: + """Get a device mesh by dimension name(s), returning None if not enabled. + + Args: + dims: Names of the mesh dimension. Valid options include: + 'pp', 'batch', 'loss', 'dp_replicate', 'fsdp', + 'cp', 'tp', 'ep', 'etp', 'efsdp'. + + Returns: + DeviceMesh for the requested dimension(s), or None if: + - The dimension size is 1 (parallelism not enabled) + - The dimension doesn't exist (except efsdp which can exist even if size is 1 when ep > 1) + + Raises: + ValueError: If the requested dimension name(s) is not valid. + """ + if not self._meshes: + self.build_mesh() + + if isinstance(dims, str): + dims = [dims] + + for mesh_name in dims: + if mesh_name not in self._meshes: + raise ValueError( + f"Invalid mesh dim: '{mesh_name}'. " + f"Valid dimensions are: {list(self._meshes.keys())}" + ) + + if any(not self._mesh_exist(dim, self._meshes[dim].size()) for dim in dims): + return None + + if len(dims) == 1: + return self._meshes[dims[0]] + else: + for global_mesh in self._global_meshes.values(): + assert global_mesh.mesh_dim_names is not None + if not set(dims).issubset(set(global_mesh.mesh_dim_names)): + continue + return global_mesh[tuple(dims)] + raise ValueError(f"Invalid mesh name combinations {dims}.") + + def get_mesh(self, dims: str | list[str]) -> DeviceMesh: + """Get a device mesh by dimension name(s), raising if not available. + + Args: + dims: Names of the mesh dimension. Valid options include: + 'pp', 'batch', 'loss', 'dp_replicate', 'fsdp', + 'cp', 'tp', 'ep', 'etp', 'efsdp'. + + Returns: + DeviceMesh for the requested dimension(s). + + Raises: + ValueError: If the mesh is not available (dimension size = 1 or not enabled), + or if the requested dimension name(s) is not valid. + """ + mesh = self.get_optional_mesh(dims) + if mesh is None: + enabled_str = ( + "enabled (size > 1)" if isinstance(dims, str) else "all enabled" + ) + raise ValueError( + f"Mesh '{dims}' is not available. " + f"Ensure the corresponding parallelism dimension is {enabled_str}." + ) return mesh + def get_all_one_dimensional_meshes(self) -> dict[str, DeviceMesh]: + """Get all enabled one-dimensional device meshes. + + Returns a dictionary of enabled one-dimensional device meshes, allowing you to + access their process groups. + + Note: + Device meshes created with the Fake backend are still included in the results. + + Returns: + dict[str, DeviceMesh]: A dictionary mapping mesh dimension names to their + corresponding DeviceMesh objects. Only includes meshes where: + - ndim == 1 (one-dimensional) + - parallelism is enabled (size > 1) + + Example: + >>> parallel_dims = ParallelDims( + ... dp_replicate=2, dp_shard=2, cp=1, tp=2, pp=1, ep=1, etp=1, world_size=8 + ... ) + >>> meshes = parallel_dims.get_all_one_dimensional_meshes() + >>> print(meshes.keys()) + dict_keys(['dp_replicate', 'fsdp', 'tp', 'batch', 'loss', 'efsdp']) + """ + if not self._meshes: + self.build_mesh() + return {k: v for k, v in self._meshes.items() if v.ndim == 1 and v.size() > 1} + @property def world_mesh(self) -> DeviceMesh: - # doing late init so ParallelDims can still be used as a lightweight - # dataclass without having to initialize the world mesh if self._world_mesh is None: self._world_mesh = self.build_mesh() return self._world_mesh diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py index bef597be24..d9b6d29a09 100644 --- a/torchtitan/distributed/pipeline_parallel.py +++ b/torchtitan/distributed/pipeline_parallel.py @@ -49,7 +49,7 @@ def pipeline_llm( parallelize_fn: ParallelizeFunction, loss_fn: LossFunction, ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: - pp_mesh = parallel_dims.world_mesh["pp"] + pp_mesh = parallel_dims.get_mesh("pp") # Determine the number of virtual stages based on schedule type schedule_class = get_schedule_class( diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 811e062958..7790ab6683 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -30,7 +30,7 @@ def _dist_reduce( x: torch.Tensor, reduceOp: str, - mesh: DeviceMesh, + mesh: DeviceMesh | None, extra_pg: dist.ProcessGroup | None, ) -> float: """Perform distributed reduction on a tensor. @@ -38,7 +38,8 @@ def _dist_reduce( Args: x (torch.Tensor): Input tensor. reduceOp (str): Reduce operation to perform. - mesh (DeviceMesh): Device mesh to use for reduction. + mesh (DeviceMesh | None): Device mesh to use for reduction. + If None, no reduction is performed but simply convert the tensor to a float. extra_pg (dist.ProcessGroup, optional): Extra process group to use for reduction. Defaults to None. If provided, this all_reduce will be called for the extra process group, and then the result will be all_reduced for the mesh. @@ -50,13 +51,17 @@ def _dist_reduce( if extra_pg is not None: x = funcol.all_reduce(x, reduceOp=reduceOp, group=extra_pg) + if mesh is None: + return x.item() + assert x.numel() == 1 # required by `.item()` return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item() +# TODO: rename this to maybe_dist_max def dist_max( x: torch.Tensor, - mesh: DeviceMesh, + mesh: DeviceMesh | None = None, extra_pg: dist.ProcessGroup | None = None, ) -> float: return _dist_reduce( @@ -66,7 +71,7 @@ def dist_max( def dist_sum( x: torch.Tensor, - mesh: DeviceMesh, + mesh: DeviceMesh | None = None, extra_pg: dist.ProcessGroup | None = None, ) -> float: return _dist_reduce( @@ -76,7 +81,7 @@ def dist_sum( def dist_mean( x: torch.Tensor, - mesh: DeviceMesh, + mesh: DeviceMesh | None = None, extra_pg: dist.ProcessGroup | None = None, ) -> float: return _dist_reduce( @@ -85,7 +90,7 @@ def dist_mean( def set_determinism( - world_mesh: DeviceMesh | None, + parallel_dims: ParallelDims, device: torch.device, debug_config: DebugConfig, distinct_seed_mesh_dims: list[str], @@ -103,9 +108,8 @@ def set_determinism( Args: world_mesh: Device mesh for distributed training device: Device to use + debug_config: Debug config to use distinct_seed_mesh_dims: List of mesh dimension names to have distinct seeds across. - seed: Base seed value (if None, will be determined automatically) - deterministic: Whether to enable deterministic algorithms """ if debug_config.deterministic: logger.info("Deterministic algorithm enabled (expect perf degradation).") @@ -128,7 +132,7 @@ def set_determinism( FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) seed = debug_config.seed - if not world_mesh: + if parallel_dims.world_size == 1: if seed is not None: torch.manual_seed(seed) os.environ["PYTHONHASHSEED"] = str(seed % 2**32) @@ -143,25 +147,25 @@ def set_determinism( seed_tensor = torch.get_rng_state()[:8].to(device) torch.distributed.broadcast(seed_tensor, src=0) seed = seed_tensor.to("cpu").view(torch.uint64).item() + assert isinstance(seed, int) # Set distinct seed for each rank in mesh dimensions, with dimension names provided by `distinct_seed_mesh_dims` # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh, # and choose a unique seed for each rank on the PP mesh. # We support multiple distinct dimensions by adding each distinct dimension's local rank to the seed. - distinct_dims_in_mesh = [ - dim - for dim in distinct_seed_mesh_dims - if world_mesh.mesh_dim_names and dim in world_mesh.mesh_dim_names + distinct_seed_meshes = [ + parallel_dims.get_optional_mesh(dim) for dim in distinct_seed_mesh_dims ] + distinct_seed_meshes = [mesh for mesh in distinct_seed_meshes if mesh is not None] + assert all(mesh is not None for mesh in distinct_seed_meshes) - if c10d.get_world_size() > 1 and distinct_dims_in_mesh: + if distinct_seed_meshes: # Each dimension contributes: local_rank * (product of all previous dimension sizes) # This guarantees uniqueness like multi-dimensional array indexing seed_offset = 0 cumulative_size = 1 - for dim in distinct_dims_in_mesh: - distinct_mesh = world_mesh[dim] + for distinct_mesh in distinct_seed_meshes: local_rank = distinct_mesh.get_local_rank() # Add contribution from this dimension seed_offset += local_rank * cumulative_size @@ -172,24 +176,10 @@ def set_determinism( seed %= 2**64 logger.debug( - f"Distinct dims {distinct_dims_in_mesh}, Global rank {c10d.get_rank()} using seed: {seed}" + f"Distinct dims {distinct_seed_mesh_dims}, Global rank {c10d.get_rank()} using seed: {seed}" ) - # Filter out all distinct dimensions to get duplicate_seed_mesh - duplicate_seed_mesh_dims = [ - name - # pyrefly: ignore [not-iterable] - for name in world_mesh.mesh_dim_names - if name not in distinct_dims_in_mesh - ] - duplicate_seed_mesh = ( - # pyrefly: ignore [bad-index] - world_mesh[duplicate_seed_mesh_dims] - if duplicate_seed_mesh_dims - else None - ) else: - duplicate_seed_mesh = world_mesh logger.debug(f"Global Rank {c10d.get_rank()} using seed: {seed}") # The native RNGs and python RNG may not be important, except for the 1-D PP case, but we seed them for consistency. @@ -197,11 +187,14 @@ def set_determinism( # PYTHONHASHSEED can be a decimal number in the range [0, 2**32 - 1] os.environ["PYTHONHASHSEED"] = str(seed % 2**32) - # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh. - # IF PP is also used, this seed is unique per PP rank. - if duplicate_seed_mesh and duplicate_seed_mesh.get_coordinate() is not None: - # pyrefly: ignore [bad-argument-type] - torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_mesh) + # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for + # all ranks of the SPMD mesh. If PP is also used, this seed is unique per PP rank. + # TODO: remove the need of passing in a mesh once + # torch.distributed.tensor._random.manual_seed doesn't require a mesh input. + if parallel_dims.world_size > parallel_dims.pp: + # We just need to pass the world_mesh as the device_id is the only information + # this API uses. + torch.distributed.tensor._random.manual_seed(seed, parallel_dims.world_mesh) def create_context_parallel_ctx( @@ -370,7 +363,10 @@ def _get_distributed_backend(enable_cpu_backend): return torch.distributed.get_world_size() -def set_pg_timeouts(timeout, world_mesh): +def set_pg_timeouts( + timeout: timedelta, + parallel_dims: ParallelDims, +): """ Sets the timeout for all PGs in the provided mesh, and the default (world) group. @@ -391,10 +387,11 @@ def set_pg_timeouts(timeout, world_mesh): # pyrefly: ignore [missing-attribute] device_module.synchronize() - groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)] - # None represents the 'default' PG, not part of the mesh - groups.append(None) + groups: list[torch.distributed.ProcessGroup | None] = [ + mesh.get_group() + for mesh in parallel_dims.get_all_one_dimensional_meshes().values() + ] + [None] for group in groups: torch.distributed.distributed_c10d._set_pg_timeout(timeout, group) diff --git a/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py index 80dfcac9a3..68adb3c038 100644 --- a/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py +++ b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py @@ -206,7 +206,7 @@ def monkey_patch_checks(moe): assert not list(moe.reorderer.buffers()) -def monkey_patch_local_map_moe(model, world_mesh): +def monkey_patch_local_map_moe(model, sparse_mesh): """ TODO: fix HOPs not restoring the original signature. TODO: fix tracing with local shapes so that we can use Shard placements @@ -239,7 +239,7 @@ def monkey_patch_local_map_moe(model, world_mesh): ), redistribute_inputs=True, in_grad_placements=None, - device_mesh=world_mesh, + device_mesh=sparse_mesh, ) for block in model.layers.children(): @@ -280,7 +280,13 @@ def parallelize_deepseekv3( job_config.experimental.comms_bucket_reorder_strategy ) - world_mesh = parallel_dims.world_mesh + sparse_names = ["dp_replicate", "efsdp", "ep", "etp"] + sparse_names = [ + name + for name in sparse_names + if parallel_dims.get_optional_mesh(name) is not None + ] + sparse_mesh = parallel_dims.get_mesh(sparse_names) def input_fn(): global_batch_size = job_config.training.global_batch_size @@ -304,7 +310,7 @@ def input_fn(): assert parallel_dims.pp_enabled is False, "PP not supported yet" # apply local_map to MoE - monkey_patch_local_map_moe(model, world_mesh) + monkey_patch_local_map_moe(model, sparse_mesh) # torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( # lambda bucket_idx: 500 / parallel_dims.tp @@ -324,7 +330,7 @@ def input_fn(): with AutoParallel( model, input_fn, - world_mesh, + sparse_mesh, mp_policy=mp_policy, compile=job_config.compile, ) as autop: @@ -333,20 +339,21 @@ def input_fn(): possible_input_shardings = { # maps relative to mesh dim names used in torchtitan "dp_replicate": Shard(0), - "dp_shard": Shard(0), - "tp": Replicate(), + "efsdp": Shard(0), + "ep": Shard(0), + "etp": Replicate(), } # only used if loss parallel is enabled possible_output_shardings = { # maps relative to mesh dim names used in torchtitan - "dp_shard": Shard(0), - "tp": Shard(2), + "efsdp": Shard(0), + "etp": Shard(2), } assert all( - name in possible_input_shardings for name in world_mesh.mesh_dim_names + name in possible_input_shardings for name in sparse_mesh.mesh_dim_names ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" x_sharding = tuple( - possible_input_shardings[name] for name in world_mesh.mesh_dim_names + possible_input_shardings[name] for name in sparse_mesh.mesh_dim_names ) out_sharding = x_sharding loss_parallel_enabled = ( @@ -356,7 +363,7 @@ def input_fn(): if loss_parallel_enabled: out_sharding = tuple( possible_output_shardings[name] - for name in world_mesh.mesh_dim_names + for name in sparse_mesh.mesh_dim_names if name != "dp_replicate" ) autop.add_input_constraints([x_sharding]) @@ -379,7 +386,7 @@ def input_fn(): # it would require putting the loss inside the model as well def _return_as_dtensor_for_loss_parallel(module, args, output): return torch.distributed.tensor.DTensor.from_local( - output, world_mesh["tp"], (Shard(2),) + output, sparse_mesh["etp"], (Shard(2),) ) # not keeping a reference to the hook, don't plan on diff --git a/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py b/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py index d7fbae2622..27149f67f0 100644 --- a/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py +++ b/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py @@ -44,7 +44,13 @@ def parallelize_llama( job_config.experimental.comms_bucket_reorder_strategy ) - world_mesh = parallel_dims.world_mesh + dense_names = ["dp_replicate", "fsdp", "tp"] + dense_names = [ + name + for name in dense_names + if parallel_dims.get_optional_mesh(name) is not None + ] + dense_mesh = parallel_dims.get_mesh(dense_names) def input_fn(): global_batch_size = job_config.training.global_batch_size @@ -88,7 +94,7 @@ def input_fn(): with AutoParallel( model, input_fn, - world_mesh, + dense_mesh, mp_policy=mp_policy, compile=job_config.compile, ) as autop: @@ -97,20 +103,20 @@ def input_fn(): possible_input_shardings = { # maps relative to mesh dim names used in torchtitan "dp_replicate": Shard(0), - "dp_shard": Shard(0), + "fsdp": Shard(0), "tp": Replicate(), } # only used if loss parallel is enabled possible_output_shardings = { # maps relative to mesh dim names used in torchtitan - "dp_shard": Shard(0), + "fsdp": Shard(0), "tp": Shard(2), } assert all( - name in possible_input_shardings for name in world_mesh.mesh_dim_names + name in possible_input_shardings for name in dense_mesh.mesh_dim_names ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" x_sharding = tuple( - possible_input_shardings[name] for name in world_mesh.mesh_dim_names + possible_input_shardings[name] for name in dense_mesh.mesh_dim_names ) out_sharding = x_sharding loss_parallel_enabled = ( @@ -120,7 +126,7 @@ def input_fn(): if loss_parallel_enabled: out_sharding = tuple( possible_output_shardings[name] - for name in world_mesh.mesh_dim_names + for name in dense_mesh.mesh_dim_names if name != "dp_replicate" ) autop.add_input_constraints([x_sharding]) @@ -141,7 +147,7 @@ def input_fn(): # it would require putting the loss inside the model as well def _return_as_dtensor_for_loss_parallel(module, args, output): return torch.distributed.tensor.DTensor.from_local( - output, world_mesh["tp"], (Shard(2),) + output, dense_mesh["tp"], (Shard(2),) ) # not keeping a reference to the hook, don't plan on diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py index 997af9a2c4..2b2a1f5244 100644 --- a/torchtitan/experiments/compiler_toolkit/common_utils.py +++ b/torchtitan/experiments/compiler_toolkit/common_utils.py @@ -25,10 +25,12 @@ def disable_compile(job_config: JobConfig): job_config.compile.enable = original_value -def parallelize_inputs(world_mesh, args, kwargs): +def parallelize_inputs(parallel_dims, args, kwargs): def to_dtensor(tensor): if isinstance(tensor, torch.Tensor): - return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()]) + return DTensor.from_local( + tensor, parallel_dims.get_mesh("tp"), [Replicate()] + ) return tensor dt_args = tree_map(to_dtensor, args) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index e097579cc0..551bf695c5 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -194,9 +194,7 @@ def parameters(self, *args, **kwargs) -> Any: def forward(self, *args, **kwargs): assert "forward" not in self._overrides, "forward cannot be overridden" - dt_args, dt_kwargs = self.parallelize_inputs( - self.parallel_dims.world_mesh, args, kwargs - ) + dt_args, dt_kwargs = self.parallelize_inputs(self.parallel_dims, args, kwargs) if self.joint_graph_module is None: self.joint_graph_module = self.joint_graph_builder( diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index 5035129008..e9818ba19f 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -86,10 +86,9 @@ def __init__(self, job_config: ForgeJobConfig): world_size=world_size, ) - world_mesh = parallel_dims.world_mesh if parallel_dims.dp_enabled: - dp_mesh = world_mesh["dp"] - dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() + batch_mesh = parallel_dims.get_mesh("batch") + dp_degree, dp_rank = batch_mesh.size(), batch_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 self.dp_degree, self.dp_rank = dp_degree, dp_rank @@ -102,7 +101,7 @@ def __init__(self, job_config: ForgeJobConfig): # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( - world_mesh, + parallel_dims, self.device, job_config.debug, distinct_seed_mesh_dims=["pp"], # same as `torchtitan/train.py` diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index 66ad151dd0..6530eb18bb 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -169,7 +169,7 @@ def forward_backward_step( optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], + cp_mesh=parallel_dims.get_mesh("cp"), cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], cp_seq_dims=[1, 1] + [0 for _ in model_parts], cp_no_restore_buffers={inputs, labels}, @@ -243,9 +243,7 @@ def train_step( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, foreach=True, - pp_mesh=( - parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None - ), + pp_mesh=parallel_dims.get_optional_mesh("pp"), ep_enabled=parallel_dims.ep_enabled, ) self.checkpointer.maybe_wait_for_staging() @@ -262,8 +260,8 @@ def train_step( if parallel_dims.dp_cp_enabled: loss = loss.detach() global_avg_loss, global_max_loss = ( - dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"]), - dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"]), + dist_utils.dist_mean(loss, parallel_dims.get_optional_mesh("loss")), + dist_utils.dist_max(loss, parallel_dims.get_optional_mesh("loss")), ) else: global_avg_loss = global_max_loss = loss.detach().item() @@ -329,7 +327,7 @@ def train(self): timeout=timedelta( seconds=job_config.comm.train_timeout_seconds ), - world_mesh=self.parallel_dims.world_mesh, + parallel_dims=self.parallel_dims, ) if torch.distributed.get_rank() == 0: diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 9b2e75ac4f..4768dab659 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -62,8 +62,6 @@ def parallelize_gptoss( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh - assert ( job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 ), f""" @@ -71,6 +69,10 @@ def parallelize_gptoss( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel @@ -91,7 +93,7 @@ def parallelize_gptoss( apply_non_moe_tp( model, - world_mesh["tp"], + parallel_dims.get_mesh("tp"), loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, enable_async_tp=False, @@ -102,23 +104,13 @@ def parallelize_gptoss( apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), + tp_mesh=parallel_dims.get_optional_mesh("tp"), + ep_mesh=parallel_dims.get_optional_mesh("ep"), + ep_etp_mesh=parallel_dims.get_optional_mesh("ep_etp"), etp_enabled=parallel_dims.etp_enabled, dual_pipe_v=dual_pipe_v, ) - model_compile_enabled = ( - job_config.compile.enable and "model" in job_config.compile.components - ) - if job_config.activation_checkpoint.mode != "none": apply_ac( model, @@ -130,18 +122,18 @@ def parallelize_gptoss( dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh_names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) apply_fsdp( model, @@ -152,11 +144,7 @@ def parallelize_gptoss( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if parallel_dims.ep_enabled - else None - ), + edp_mesh=edp_mesh, ) if parallel_dims.dp_replicate_enabled: @@ -170,9 +158,9 @@ def parallelize_gptoss( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh is not None and dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") - dp_mesh = world_mesh apply_ddp( model, dp_mesh, @@ -263,7 +251,7 @@ def apply_moe_ep_tp( model: nn.Module, tp_mesh: DeviceMesh | None, ep_mesh: DeviceMesh | None, - ep_tp_mesh: DeviceMesh | None, + ep_etp_mesh: DeviceMesh | None, etp_enabled: bool, dual_pipe_v: bool = False, ): @@ -309,7 +297,7 @@ def apply_moe_ep_tp( # input / output sharding on the batch / tokens dim experts_plan = ExpertParallel() else: - experts_mesh = ep_tp_mesh + experts_mesh = ep_etp_mesh experts_plan = GptossExpertTensorParallel() if dual_pipe_v and isinstance(experts_plan, BaseExpertParallel): diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 83e24d7dc1..3e2209dd6a 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -54,7 +54,6 @@ def parallelize_deepseekv3( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -87,25 +86,19 @@ def parallelize_deepseekv3( apply_non_moe_tp( model, - world_mesh["tp"], + parallel_dims.get_mesh("tp"), loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, parallel_dims.get_mesh("tp")) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), - etp_enabled=parallel_dims.etp_enabled, + tp_mesh=parallel_dims.get_optional_mesh("tp"), + ep_mesh=parallel_dims.get_optional_mesh("ep"), + etp_mesh=parallel_dims.get_optional_mesh("etp"), + ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), ) if job_config.activation_checkpoint.mode != "none": @@ -125,38 +118,38 @@ def parallelize_deepseekv3( ): if parallel_dims.dp_replicate_enabled: if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh_dim_names = ["dp_replicate", "fsdp"] dp_mode = "hybrid_shard" else: - dp_mesh_dim_names = ("dp_replicate",) + dp_mesh_dim_names = ["dp_replicate"] dp_mode = "replicate" else: - dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh_dim_names = ["fsdp"] dp_mode = "fully_shard" - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] - # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] + dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") - dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) for _, transformer_block in model.layers.items(): if transformer_block.moe_enabled and parallel_dims.ep_enabled: experts_shard_dim = 0 - assert dp_mod_ep_mesh is not None + assert edp_mesh is not None assert hasattr(transformer_block, "moe") if ( - dp_mod_ep_mesh.size() * parallel_dims.ep + edp_mesh["efsdp"].size() * parallel_dims.ep > transformer_block.moe.experts.num_experts ): experts_shard_dim = 1 # when EP is enable, the routed experts' gradient reduction is done over - # dp_mod_ep_mesh instead of whole dp_mesh. + # edp_mesh instead of whole dp_mesh. # we add a `fsdp_gradient_divide_factor` to scale gradient over dp_mesh # to be consistent with data. # TODO (ruisizhang123): update the logic following the link below instead @@ -164,7 +157,7 @@ def parallelize_deepseekv3( # https://github.com/pytorch/torchtitan/pull/1803#discussion_r2415190883 transformer_block.moe.experts = data_parallel( transformer_block.moe.experts, - dp_mod_ep_mesh, + edp_mesh, dp_mode, mp_policy=mp_policy, shard_dim=experts_shard_dim, diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index 484d3d4747..d64a8b79fc 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -97,7 +97,7 @@ def parallelize_llama( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise - tp_mesh = parallel_dims.world_mesh["tp"] + tp_mesh = parallel_dims.get_mesh("tp") apply_tp( model, tp_mesh, @@ -126,13 +126,13 @@ def parallelize_llama( ): if parallel_dims.dp_replicate_enabled: if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh_dim_names = ["dp_replicate", "fsdp"] dp_mode = "hybrid_shard" else: - dp_mesh_dim_names = ("dp_replicate",) + dp_mesh_dim_names = ["dp_replicate"] dp_mode = "replicate" else: - dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh_dim_names = ["fsdp"] dp_mode = "fully_shard" mp_policy = MixedPrecisionPolicy( @@ -142,7 +142,7 @@ def parallelize_llama( model = data_parallel( model, - parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], + parallel_dims.get_mesh(dp_mesh_dim_names), mode=dp_mode, mp_policy=mp_policy, ) diff --git a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py index 76233aeb87..aaf94a5023 100644 --- a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py +++ b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py @@ -20,13 +20,13 @@ def init_test(self): self.loss_fn = cross_entropy_loss data_parallel_shard_degree = -1 if self.mode == "replicate": - self.dp_mesh_dim_names = ("dp_replicate",) + self.dp_mesh_dim_names = ["dp_replicate"] data_parallel_replicate_degree = self.world_size elif self.mode == "fully_shard": - self.dp_mesh_dim_names = ("dp_shard_cp",) + self.dp_mesh_dim_names = ["fsdp"] data_parallel_replicate_degree = 1 elif self.mode == "hybrid_shard": - self.dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + self.dp_mesh_dim_names = ["dp_replicate", "fsdp"] data_parallel_replicate_degree = self.world_size // 2 else: raise ValueError(f"Unsupported mode {self.mode}") @@ -41,7 +41,6 @@ def init_test(self): etp=1, world_size=self.world_size, ) - self.device_mesh = self.parallel_dims.world_mesh def get_input(self): inputs = torch.randn(8, 8).cuda() @@ -50,7 +49,7 @@ def get_input(self): return model, inputs, labels def run_fsdp2(self, model, inputs, labels, epoch=20): - fully_shard(model, mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)]) + fully_shard(model, mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names)) optim = self.optimizer(model.parameters(), lr=1e-4) losses = [] for _ in range(epoch): @@ -65,7 +64,7 @@ def run_fsdp2(self, model, inputs, labels, epoch=20): def run_simple_fsdp(self, model, inputs, labels, epoch=20): model = data_parallel( model, - device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)], + device_mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names), mode=self.mode, ) optim = self.optimizer(model.parameters(), lr=1e-4) @@ -82,7 +81,7 @@ def run_simple_fsdp(self, model, inputs, labels, epoch=20): def run_simple_fsdp_compiled_aot_eager(self, model, inputs, labels, epoch=20): model = data_parallel( model, - device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)], + device_mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names), mode=self.mode, ) # TODO: Add "inductor" backend when it's numerical issues are fixed diff --git a/torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py b/torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py index a049d88d76..fcdb31f27d 100644 --- a/torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py +++ b/torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py @@ -39,7 +39,6 @@ def parallelize_hf_transformers( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -64,11 +63,11 @@ def parallelize_hf_transformers( apply_non_moe_tp( model, - world_mesh["tp"], + parallel_dims.get_mesh("tp"), loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, parallel_dims.get_mesh("tp")) model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components @@ -84,13 +83,13 @@ def parallelize_hf_transformers( if parallel_dims.fsdp_enabled: # apply FSDP or HSDP, potentially with Context Parallel if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh_dim_names = ("dp_replicate", "fsdp") else: - dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh_dim_names = ("fsdp",) apply_fsdp( model, - world_mesh[tuple(dp_mesh_dim_names)], + parallel_dims.get_mesh(list(dp_mesh_dim_names)), param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], pp_enabled=parallel_dims.pp_enabled, @@ -104,17 +103,18 @@ def parallelize_hf_transformers( logger.info("Applied FSDP to the model") if parallel_dims.cp_enabled: - model.set_cp_mesh(world_mesh["cp"]) + model.set_cp_mesh(parallel_dims.get_mesh("cp")) logger.info("Applied Context Parallel to the model") if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_replicate_mesh = parallel_dims.get_mesh("dp_replicate") + if parallel_dims.world_size != dp_replicate_mesh.size(): raise RuntimeError("DDP has not supported > 1D parallelism") apply_ddp( model, - world_mesh, + dp_replicate_mesh, enable_compile=model_compile_enabled, ) diff --git a/torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py b/torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py index f05caf9abf..f27f884014 100644 --- a/torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py +++ b/torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py @@ -287,7 +287,7 @@ def pipeline_hf_transformers( parallelize_fn: ParallelizeFunction, loss_fn: LossFunction, ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: - pp_mesh = parallel_dims.world_mesh["pp"] + pp_mesh = parallel_dims.get_mesh("pp") # Determine the number of virtual stages based on schedule type schedule_class = get_schedule_class( diff --git a/torchtitan/experiments/vlm/infra/loss.py b/torchtitan/experiments/vlm/infra/loss.py index bba51f2819..7a3a490fb7 100644 --- a/torchtitan/experiments/vlm/infra/loss.py +++ b/torchtitan/experiments/vlm/infra/loss.py @@ -104,7 +104,7 @@ def build_token_imbalance_ce_loss( # NOTE: The device mesh where the input tokens w/ shape BSD can be sliced: # DP split the batch dim B # CP split the sequence dim S - token_mesh = parallel_dims.world_mesh["dp_cp"] + token_mesh = parallel_dims.get_mesh("loss") ft_pg = ft_manager.loss_sync_pg loss_fn = partial(token_imbalance_ce_loss, token_mesh=token_mesh, ft_pg=ft_pg) if job_config.compile.enable and "loss" in job_config.compile.components: diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py index b6ada94d00..d87070bee6 100644 --- a/torchtitan/experiments/vlm/infra/parallelize.py +++ b/torchtitan/experiments/vlm/infra/parallelize.py @@ -38,7 +38,6 @@ def parallelize_vlm( the model must fit on GPU or CPU memory. """ assert isinstance(model.encoder, nn.Module) - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -74,14 +73,13 @@ def parallelize_vlm( if parallel_dims.fsdp_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) apply_fsdp( model, - world_mesh[tuple(dp_mesh_dim_names)], + parallel_dims.get_mesh(names), param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], pp_enabled=parallel_dims.pp_enabled, @@ -100,11 +98,12 @@ def parallelize_vlm( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh is not None and dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") apply_ddp( model, - world_mesh, + dp_mesh, enable_compile=job_config.compile.enable, ) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 63fb910376..1295ec9bdb 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -55,7 +55,6 @@ def parallelize_deepseekv3( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -84,29 +83,24 @@ def parallelize_deepseekv3( "Currently, float8 tensorwise TP is not tested for deepseekv3" ) + tp_mesh = parallel_dims.get_mesh("tp") apply_non_moe_tp( model, - world_mesh["tp"], + tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, tp_mesh) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), - etp_enabled=parallel_dims.etp_enabled, + tp_mesh=parallel_dims.get_optional_mesh("tp"), + ep_mesh=parallel_dims.get_optional_mesh("ep"), + etp_mesh=parallel_dims.get_optional_mesh("etp"), + ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), dual_pipe_v=dual_pipe_v, ) @@ -130,18 +124,18 @@ def parallelize_deepseekv3( dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh_names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) apply_fsdp( model, @@ -152,11 +146,7 @@ def parallelize_deepseekv3( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if parallel_dims.ep_enabled - else None - ), + edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) @@ -171,9 +161,9 @@ def parallelize_deepseekv3( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") - dp_mesh = world_mesh apply_ddp( model, dp_mesh, diff --git a/torchtitan/models/flux/infra/parallelize.py b/torchtitan/models/flux/infra/parallelize.py index b27fa93a31..c12e6b0c78 100644 --- a/torchtitan/models/flux/infra/parallelize.py +++ b/torchtitan/models/flux/infra/parallelize.py @@ -29,14 +29,14 @@ def parallelize_flux( apply_ac(model, job_config.activation_checkpoint) if parallel_dims.fsdp_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(names) apply_fsdp( model, - parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], + dp_mesh, param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], cpu_offload=job_config.training.enable_cpu_offload, @@ -141,17 +141,17 @@ def parallelize_encoders( job_config: JobConfig, ): if parallel_dims.dp_shard_enabled: # apply FSDP or HSDP - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard") - else: - dp_mesh_dim_names = ("dp_shard",) + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) mp_policy = MixedPrecisionPolicy( param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) + dp_mesh = parallel_dims.get_mesh(names) fsdp_config: dict[str, Any] = { - "mesh": parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], + "mesh": dp_mesh, "mp_policy": mp_policy, } if job_config.training.enable_cpu_offload: diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py index 3e008fba59..b6fea6dbe7 100644 --- a/torchtitan/models/flux/train.py +++ b/torchtitan/models/flux/train.py @@ -28,10 +28,10 @@ def __init__(self, job_config: JobConfig): # (mainly for debugging, expect perf loss). # For Flux model, we need distinct seed across FSDP ranks to ensure we randomly dropout prompts info in dataloader dist_utils.set_determinism( - self.parallel_dims.world_mesh, + self.parallel_dims, self.device, job_config.debug, - distinct_seed_mesh_dims=["dp_shard", "dp_replicate"], + distinct_seed_mesh_dims=["fsdp", "dp_replicate"], ) # NOTE: self._dtype is the data type used for encoders (image encoder, T5 text encoder, CLIP text encoder). @@ -136,9 +136,11 @@ def forward_backward_step( latents = pack_latents(latents) target = pack_latents(noise - labels) - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=self.parallel_dims.world_mesh["cp"], + optional_context_parallel_ctx = None + if self.parallel_dims.cp_enabled: + cp_mesh = self.parallel_dims.get_mesh("cp") + optional_context_parallel_ctx = dist_utils.create_context_parallel_ctx( + cp_mesh=cp_mesh, cp_buffers=[ latents, latent_pos_enc, @@ -156,9 +158,6 @@ def forward_backward_step( }, cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, ) - if self.parallel_dims.cp_enabled - else None - ) with self.train_context(optional_context_parallel_ctx): with self.maybe_enable_amp: latent_noise_pred = model( diff --git a/torchtitan/models/flux/validate.py b/torchtitan/models/flux/validate.py index 32fa7b9f55..3bfa204e74 100644 --- a/torchtitan/models/flux/validate.py +++ b/torchtitan/models/flux/validate.py @@ -220,9 +220,11 @@ def validate( latents = pack_latents(latents) target = pack_latents(noise - labels) - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], + optional_context_parallel_ctx = None + if parallel_dims.cp_enabled: + cp_mesh = parallel_dims.get_mesh("cp") + optional_context_parallel_ctx = dist_utils.create_context_parallel_ctx( + cp_mesh=cp_mesh, cp_buffers=[ latents, latent_pos_enc, @@ -240,9 +242,6 @@ def validate( }, cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, ) - if parallel_dims.cp_enabled - else None - ) with self.validation_context(optional_context_parallel_ctx): with self.maybe_enable_amp: @@ -268,7 +267,7 @@ def validate( loss /= num_steps if parallel_dims.dp_cp_enabled: global_avg_loss = dist_utils.dist_mean( - loss, parallel_dims.world_mesh["dp_cp"] + loss, parallel_dims.get_optional_mesh("loss") ) else: global_avg_loss = loss.item() diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 63bbc19ff6..a77a87a921 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -61,7 +61,6 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -84,13 +83,14 @@ def parallelize_llama( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + tp_mesh = parallel_dims.get_mesh("tp") apply_tp( model, - world_mesh["tp"], + tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, tp_mesh) model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components @@ -111,15 +111,14 @@ def parallelize_llama( apply_compile(model, job_config.compile) if parallel_dims.fsdp_enabled: - # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) - + # dp_mesh is the mesh for FSDP/HSDP + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(names) apply_fsdp( model, - world_mesh[tuple(dp_mesh_dim_names)], + dp_mesh, param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], pp_enabled=parallel_dims.pp_enabled, @@ -138,11 +137,12 @@ def parallelize_llama( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_replicate_mesh = parallel_dims.get_mesh("dp_replicate") + if parallel_dims.world_size != dp_replicate_mesh.size(): raise RuntimeError("DDP has not supported > 1D parallelism") apply_ddp( model, - world_mesh, + dp_replicate_mesh, enable_compile=model_compile_enabled, ) diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 112153390f..c87fac1a89 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -75,7 +75,6 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -86,6 +85,7 @@ def parallelize_llama( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ + tp_mesh = None if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( @@ -98,29 +98,27 @@ def parallelize_llama( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + tp_mesh = parallel_dims.get_mesh("tp") apply_non_moe_tp( model, - world_mesh["tp"], + tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, tp_mesh) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) + # tp_mesh might have been set above if tp_enabled, otherwise get it here + if tp_mesh is None: + tp_mesh = parallel_dims.get_mesh("tp") apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), - etp_enabled=parallel_dims.etp_enabled, + tp_mesh=tp_mesh, + ep_mesh=parallel_dims.get_optional_mesh("ep"), + etp_mesh=parallel_dims.get_optional_mesh("etp"), + ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), dual_pipe_v=dual_pipe_v, ) @@ -141,21 +139,20 @@ def parallelize_llama( if model_compile_enabled: apply_compile(model, job_config.compile, parallel_dims.ep_enabled) - dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: - # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + # dp_mesh is the mesh for FSDP/HSDP + dp_mesh_names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) apply_fsdp( model, @@ -166,11 +163,7 @@ def parallelize_llama( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if parallel_dims.ep_enabled - else None - ), + edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) @@ -185,9 +178,9 @@ def parallelize_llama( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if parallel_dims.world_size != dp_mesh.size(): raise RuntimeError("DDP has not supported > 1D parallelism") - dp_mesh = world_mesh apply_ddp( model, dp_mesh, @@ -301,7 +294,7 @@ def apply_fsdp( cpu_offload: bool = False, reshard_after_forward_policy: str = "default", ep_degree: int = 1, - dp_mod_ep_mesh: DeviceMesh | None = None, + edp_mesh: DeviceMesh | None = None, gradient_divide_factor: int | None = None, ): """ @@ -352,10 +345,10 @@ def apply_fsdp( for layer_id, transformer_block in model.layers.items(): # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping # - the router and the shared experts are sharded together with the TransformerBlock - # - the routed experts are sharded with the remaining dp_mod_ep_mesh + # - the routed experts are sharded with the remaining edp_mesh if transformer_block.moe_enabled and ep_degree > 1: fsdp_mod_ep_config = fsdp_config.copy() - fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh + fsdp_mod_ep_config["mesh"] = edp_mesh # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts). # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding @@ -364,10 +357,10 @@ def apply_fsdp( # on non-0 dim. For now it may not be worth the complexity to support # shard_placement_fn on the outer TransformerBlock-level FSDP. _experts_shard_placement_fn = None - assert dp_mod_ep_mesh is not None + assert edp_mesh is not None assert hasattr(transformer_block, "moe") if ( - dp_mod_ep_mesh.size() * ep_degree + edp_mesh["efsdp"].size() * ep_degree > transformer_block.moe.experts.num_experts ): _experts_shard_placement_fn = lambda param: Shard(1) @@ -476,8 +469,8 @@ def apply_moe_ep_tp( model: nn.Module, tp_mesh: DeviceMesh | None, ep_mesh: DeviceMesh | None, - ep_tp_mesh: DeviceMesh | None, - etp_enabled: bool, + etp_mesh: DeviceMesh | None, + ep_etp_mesh: DeviceMesh | None, dual_pipe_v: bool = False, ): assert ep_mesh is not None or tp_mesh is not None @@ -502,7 +495,7 @@ def apply_moe_ep_tp( # replicate computation for the router "moe.router.gate": NoParallel(), } - if ep_mesh is not None and not etp_enabled: + if ep_mesh is not None and etp_mesh is None: # If TP is borrowed for EP, then split the tokens across TP ranks so that # the reorderer, the all-to-all comms, and routed experts computation # are effectively running Sequence Parallel (split along the folded bs*slen dim) @@ -531,15 +524,17 @@ def apply_moe_ep_tp( experts_mesh, experts_plan = None, None if ep_mesh is None: + assert ep_etp_mesh is None experts_mesh = tp_mesh # input Replicate, output Partial experts_plan = TensorParallel() - elif tp_mesh is None or not etp_enabled: + elif tp_mesh is None or etp_mesh is None: + assert ep_etp_mesh is None experts_mesh = ep_mesh # input / output sharding on the batch / tokens dim experts_plan = ExpertParallel() else: - experts_mesh = ep_tp_mesh + experts_mesh = ep_etp_mesh experts_plan = ExpertTensorParallel() if dual_pipe_v and isinstance(experts_plan, BaseExpertParallel): diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index c2eaed8de6..4c7f43f426 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -58,7 +58,6 @@ def parallelize_qwen3( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh assert ( job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 ), f""" @@ -91,9 +90,10 @@ def parallelize_qwen3( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + tp_mesh = parallel_dims.get_mesh("tp") apply_non_moe_tp( model, - world_mesh["tp"], + tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, @@ -102,18 +102,13 @@ def parallelize_qwen3( if parallel_dims.tp_enabled or parallel_dims.ep_enabled: dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) + tp_mesh = parallel_dims.get_mesh("tp") apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), - etp_enabled=parallel_dims.etp_enabled, + tp_mesh=tp_mesh, + ep_mesh=parallel_dims.get_optional_mesh("ep"), + etp_mesh=parallel_dims.get_optional_mesh("etp"), + ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), dual_pipe_v=dual_pipe_v, ) @@ -133,18 +128,18 @@ def parallelize_qwen3( if parallel_dims.fsdp_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh_names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) apply_fsdp( model, @@ -155,11 +150,7 @@ def parallelize_qwen3( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if parallel_dims.ep_enabled - else None - ), + edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) @@ -174,11 +165,12 @@ def parallelize_qwen3( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") apply_ddp( model, - world_mesh, + dp_mesh, enable_compile=model_compile_enabled, ) diff --git a/torchtitan/train.py b/torchtitan/train.py index 8c597cd608..8455e54eb9 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -13,7 +13,6 @@ from typing import Any, Iterable import torch - import torch.distributed.checkpoint.stateful from torch.distributed.elastic.multiprocessing.errors import record @@ -92,19 +91,15 @@ def __init__(self, job_config: JobConfig): # init distributed and build meshes self.parallel_dims = parallel_dims = self.init_distributed() - # Logging needs to happen after distributed initialized - job_config.maybe_log() - - world_mesh = parallel_dims.world_mesh if parallel_dims.dp_enabled: - dp_mesh = world_mesh["dp"] - dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() + batch_mesh = parallel_dims.get_mesh("batch") + batch_degree, batch_rank = batch_mesh.size(), batch_mesh.get_local_rank() else: - dp_degree, dp_rank = 1, 0 + batch_degree, batch_rank = 1, 0 # pyrefly: ignore [bad-argument-type] self.ft_manager = FTManager(job_config.fault_tolerance) - dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) + batch_degree, batch_rank = self.ft_manager.get_dp_info(batch_degree, batch_rank) # take control of garbage collection to avoid stragglers self.gc_handler = utils.GarbageCollection( @@ -114,7 +109,7 @@ def __init__(self, job_config: JobConfig): # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( - world_mesh, + parallel_dims, self.device, job_config.debug, distinct_seed_mesh_dims=["pp"], @@ -129,8 +124,8 @@ def __init__(self, job_config: JobConfig): ) self.dataloader = self.train_spec.build_dataloader_fn( - dp_world_size=dp_degree, - dp_rank=dp_rank, + dp_world_size=batch_degree, + dp_rank=batch_rank, tokenizer=self.tokenizer, job_config=job_config, ) @@ -199,19 +194,20 @@ def __init__(self, job_config: JobConfig): if global_batch_size < 0: # This global batch size results in 1 gradient accumulation # step. - global_batch_size = job_config.training.local_batch_size * dp_degree + global_batch_size = job_config.training.local_batch_size * batch_degree assert global_batch_size > 0 assert ( - global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0 + global_batch_size % (job_config.training.local_batch_size * batch_degree) + == 0 ), ( f"global batch size must be multiple of local batch size times " f"data-parallel degree ({global_batch_size} " - f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" + f"% ({job_config.training.local_batch_size} * {batch_degree}) != 0)" ) # calculate gradient accumulation steps self.gradient_accumulation_steps = global_batch_size // ( - job_config.training.local_batch_size * dp_degree + job_config.training.local_batch_size * batch_degree ) assert self.gradient_accumulation_steps > 0 self.loss_fn = rescale_accumulated_loss( @@ -348,8 +344,8 @@ def __init__(self, job_config: JobConfig): self.validator = self.train_spec.build_validator_fn( job_config=job_config, - dp_world_size=dp_degree, - dp_rank=dp_rank, + dp_world_size=batch_degree, + dp_rank=batch_rank, tokenizer=self.tokenizer, parallel_dims=parallel_dims, loss_fn=self.loss_fn, @@ -487,24 +483,24 @@ def forward_backward_step( ) # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage - cp_buffers = [inputs, labels] + cp_buffers: list[torch.Tensor] = [inputs, labels] cp_seq_dims = [1, 1] if hasattr(model_parts[0], "freqs_cis"): - cp_buffers += [m.freqs_cis for m in model_parts] + for m in model_parts: + assert isinstance(m.freqs_cis, torch.Tensor) + cp_buffers.append(m.freqs_cis) cp_seq_dims += [0 for _ in model_parts] - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], - # pyrefly: ignore [bad-argument-type] + optional_context_parallel_ctx = None + if parallel_dims.cp_enabled: + cp_mesh = parallel_dims.get_mesh("cp") + optional_context_parallel_ctx = dist_utils.create_context_parallel_ctx( + cp_mesh=cp_mesh, cp_buffers=cp_buffers, cp_seq_dims=cp_seq_dims, cp_no_restore_buffers={inputs, labels}, cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, ) - if parallel_dims.cp_enabled - else None - ) if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call @@ -576,9 +572,7 @@ def train_step( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, foreach=True, - pp_mesh=( - parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None - ), + pp_mesh=parallel_dims.get_optional_mesh("pp"), ep_enabled=parallel_dims.ep_enabled, ) self.checkpointer.maybe_wait_for_staging() @@ -595,14 +589,15 @@ def train_step( if parallel_dims.dp_cp_enabled: loss = loss.detach() ft_pg = self.ft_manager.loss_sync_pg + loss_mesh = parallel_dims.get_optional_mesh("loss") global_avg_loss, global_max_loss, global_ntokens_seen = ( - dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), - dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), + dist_utils.dist_mean(loss, loss_mesh, ft_pg), + dist_utils.dist_max(loss, loss_mesh, ft_pg), dist_utils.dist_sum( torch.tensor( self.ntokens_seen, dtype=torch.int64, device=self.device ), - parallel_dims.world_mesh["dp_cp"], + loss_mesh, ft_pg, ), ) @@ -703,7 +698,7 @@ def train(self): timeout=timedelta( seconds=job_config.comm.train_timeout_seconds ), - world_mesh=self.parallel_dims.world_mesh, + parallel_dims=self.parallel_dims, ) if torch.distributed.get_rank() == 0: From 36a4b69426f4682fcfadf6d5d2cf2768bae0fbfa Mon Sep 17 00:00:00 2001 From: Elfie Guo <164945471+elfiegg@users.noreply.github.com> Date: Wed, 17 Dec 2025 22:51:49 -0800 Subject: [PATCH 074/127] Integrate DeepEP to torchtitan (#2107) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This initial version integrates DeepEP into TorchTitan, focusing on correctness and compatibility rather than maximal performance tuning. - Functional DeepEP-backed MoE + Expert Parallelism - User-controlled configuration - Compatible with torch.compile and SAC - Intended as a first unblocker for benchmarking and iteration ## Perf: DeepSeek-V3 671B on 64 nodes × H100 (512 GPUs total)
    Training config (click to expand) ``` config_path="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml", command_args=[ "--training.dataset_path=/lustre/fsw/portfolios/sw/users/elfieg/hf_datasets/c4", "--training.seq_len=4096", "--training.steps=120", "--metrics.log_freq=10", "--profiling.no-enable-profiling", "--comm.init_timeout_seconds=2000", "--comm.train_timeout_seconds=300", "--metrics.disable_color_printing", # Parallelism "--parallelism.data_parallel_replicate_degree=1", "--parallelism.data_parallel_shard_degree=64", "--parallelism.fsdp_reshard_after_forward=default", "--parallelism.tensor_parallel_degree=1", "--parallelism.expert_parallel_degree=32", "--parallelism.expert_tensor_parallel_degree=1", "--parallelism.pipeline_parallel_degree=8", "--parallelism.pipeline_parallel_schedule=Interleaved1F1B", # Training "--training.local_batch_size=16", "--activation_checkpoint.mode=full", # Compilation "--compile.enable", "--compile.components=model", "--compile.components=loss", # MoE / DeepEP "--debug.moe_force_load_balance", "--parallelism.expert_parallel_comm_backend=deepep", ], ```
    After: ``` memory: 56.75GiB(71.74%) tps: 579 tflops: 162.82 mfu: 16.46% ``` Before: ``` memory: 60.18GiB(76.07%) tps: 346 tflops: 97.24 mfu: 9.83% ``` ## Loss Curve: Screenshot 2025-12-16 at 11 30 02 PM Shout out to my colleagues @gekurian @syed-ahmed @aazzolini for internal supports! --- torchtitan/config/job_config.py | 12 +- torchtitan/distributed/__init__.py | 6 +- torchtitan/distributed/deepep/__init__.py | 15 + torchtitan/distributed/deepep/deepep.py | 462 ++++++++++++++++++ torchtitan/distributed/expert_parallel.py | 67 +++ .../models/deepseek_v3/infra/parallelize.py | 27 +- torchtitan/models/deepseek_v3/model/args.py | 7 +- torchtitan/models/deepseek_v3/model/model.py | 10 +- torchtitan/models/llama4/infra/parallelize.py | 47 +- torchtitan/models/moe/__init__.py | 4 +- torchtitan/models/moe/moe.py | 19 + torchtitan/models/moe/moe_deepep.py | 58 +++ 12 files changed, 717 insertions(+), 17 deletions(-) create mode 100644 torchtitan/distributed/deepep/__init__.py create mode 100644 torchtitan/distributed/deepep/deepep.py create mode 100644 torchtitan/models/moe/moe_deepep.py diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 4c2333f30d..504f00419f 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import json - import os from dataclasses import asdict, dataclass, field from typing import Any, Literal @@ -412,6 +411,17 @@ class Parallelism: Note that this is still an experimental feature. """ + expert_parallel_comm_backend: Literal["standard", "deepep"] = "standard" + """ + Expert-parallel communication backend. No effect for non-MoE models or when ep = 1. + + - "standard": Uses PyTorch all-to-all collectives (default) + - "deepep": Uses DeepEP custom kernels for more efficient communication + + DeepEP requires installation: + https://github.com/deepseek-ai/DeepEP. + """ + @dataclass class Checkpoint: diff --git a/torchtitan/distributed/__init__.py b/torchtitan/distributed/__init__.py index f335916595..72d1298648 100644 --- a/torchtitan/distributed/__init__.py +++ b/torchtitan/distributed/__init__.py @@ -14,8 +14,10 @@ from torchtitan.distributed.parallel_dims import ParallelDims - -__all__ = ["ParallelDims", "NoParallel"] +__all__ = [ + "ParallelDims", + "NoParallel", +] # NOTE: This is to achieve replicate computation on the gate module in the MoE router. diff --git a/torchtitan/distributed/deepep/__init__.py b/torchtitan/distributed/deepep/__init__.py new file mode 100644 index 0000000000..53001938a8 --- /dev/null +++ b/torchtitan/distributed/deepep/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""DeepEP distributed communication primitives for MoE.""" + +from .deepep import combine_tokens, dispatch_tokens, DispatchState + +__all__ = [ + "dispatch_tokens", + "combine_tokens", + "DispatchState", +] diff --git a/torchtitan/distributed/deepep/deepep.py b/torchtitan/distributed/deepep/deepep.py new file mode 100644 index 0000000000..9389fac5c5 --- /dev/null +++ b/torchtitan/distributed/deepep/deepep.py @@ -0,0 +1,462 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +DeepEP primitives for MoE Expert Parallel. + +Provides low-level functions and autograd wrappers for DeepEP communication. +Used by DeepEPExpertParallel in expert_parallel.py. +""" + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch.distributed import ProcessGroup + +try: + from deep_ep import Buffer # pyrefly: ignore [missing-import] + from deep_ep.utils import ( # pyrefly: ignore [missing-import] + EventHandle, + EventOverlap, + ) +except ImportError as e: + raise ImportError( + "DeepEP is required for this module. " + "Install from: https://github.com/deepseek-ai/deepep" + ) from e + + +# Global buffer (single buffer per process, recreated if group changes) +_buffer: Buffer = None + +# Global cache for dispatch handles, keyed by cache_id +# SAC saves the cache_id tensor; we use it to retrieve the non-tensor handle +_handle_cache: dict = {} +_cache_counter: int = 0 + + +def _get_next_cache_id() -> torch.Tensor: + """Generate a unique cache_id tensor on CPU to avoid GPU-CPU sync.""" + global _cache_counter + _cache_counter += 1 + return torch.tensor([_cache_counter], dtype=torch.int64, device="cpu") + + +# ============================================================================ +# Custom Op Registration for SAC Integration +# ============================================================================ + +_lib = torch.library.Library("deepep", "DEF") + +# dispatch returns: (recv_x, recv_indices, recv_scores, num_recv_per_expert, cache_id) +_lib.define( + "dispatch(Tensor x, Tensor topk_idx, Tensor topk_weights, " + "Tensor num_tokens_per_rank, Tensor num_tokens_per_rdma_rank, " + "Tensor is_token_in_rank, Tensor num_tokens_per_expert) " + "-> (Tensor, Tensor, Tensor, Tensor, Tensor)" +) + +# combine returns: combined_x +_lib.define("combine(Tensor x, Tensor cache_id) -> Tensor") + + +@torch.library.impl(_lib, "dispatch", "CUDA") +def _dispatch_op_impl( + x: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_tokens_per_rank: torch.Tensor, + num_tokens_per_rdma_rank: torch.Tensor, + is_token_in_rank: torch.Tensor, + num_tokens_per_expert: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Execute DeepEP dispatch.""" + global _buffer + + buffer = _buffer + assert buffer is not None, "Buffer must be initialized before dispatch" + + previous_event = _create_event_if_async(True) + + ( + recv_x, + recv_indices, + recv_scores, + num_recv_list, + handle, + after_event, + ) = buffer.dispatch( + x=x, + topk_idx=topk_idx, + topk_weights=topk_weights.to(torch.float32), + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=previous_event, + async_finish=True, + allocate_on_comm_stream=True, + ) + + _sync_stream_if_async(True, after_event) + + cache_id = _get_next_cache_id() + _handle_cache[cache_id.item()] = handle + + num_recv_tensor = torch.tensor(num_recv_list, dtype=torch.int32, device="cpu") + return recv_x, recv_indices, recv_scores, num_recv_tensor, cache_id + + +@torch.library.impl(_lib, "combine", "CUDA") +def _combine_op_impl(x: torch.Tensor, cache_id: torch.Tensor) -> torch.Tensor: + """Execute DeepEP combine.""" + global _buffer + + buffer = _buffer + assert buffer is not None, "Buffer must be initialized before combine" + + handle = _handle_cache.get(cache_id.item()) + assert handle is not None, f"Handle not found for cache_id={cache_id.item()}" + + previous_event = _create_event_if_async(True) + + combined, _, after_event = buffer.combine( + x=x, + handle=handle, + previous_event=previous_event, + async_finish=True, + allocate_on_comm_stream=True, + ) + + _sync_stream_if_async(True, after_event) + + return combined + + +def _dispatch_backward( + ctx, grad_recv_x, grad_recv_indices, grad_recv_scores, grad_num_recv, grad_cache_id +): + """Backward for dispatch: performs combine on gradients.""" + global _buffer + + if grad_recv_x is None: + return None, None, None, None, None, None, None + + handle = _handle_cache.get(ctx.cache_id_int) + assert handle is not None, f"Handle not found for cache_id={ctx.cache_id_int}" + + previous_event = _create_event_if_async(True) + + grad_x, grad_scores, after_event = _buffer.combine( + x=grad_recv_x, + handle=handle, + topk_weights=grad_recv_scores.float() if grad_recv_scores is not None else None, + previous_event=previous_event, + async_finish=True, + allocate_on_comm_stream=True, + ) + + _sync_stream_if_async(True, after_event) + _handle_cache.pop(ctx.cache_id_int, None) + + grad_x = grad_x.to(ctx.input_dtype) + grad_topk_weights = ( + grad_scores.to(ctx.input_dtype) if grad_scores is not None else None + ) + + return grad_x, None, grad_topk_weights, None, None, None, None + + +def _dispatch_setup_context(ctx, inputs, output): + x, topk_idx, topk_weights, *_ = inputs + recv_x, recv_indices, recv_scores, num_recv, cache_id = output + ctx.cache_id_int = cache_id.item() + ctx.input_dtype = x.dtype + + +def _combine_backward(ctx, grad_combined): + """Backward for combine: performs dispatch on gradients.""" + global _buffer + + handle = ctx.saved_handle + previous_event = _create_event_if_async(True) + + grad_x, _, _, _, _, after_event = _buffer.dispatch( + x=grad_combined, + topk_idx=None, + topk_weights=None, + num_tokens_per_rank=None, + num_tokens_per_rdma_rank=None, + is_token_in_rank=None, + num_tokens_per_expert=None, + handle=handle, + previous_event=previous_event, + async_finish=True, + allocate_on_comm_stream=True, + ) + + _sync_stream_if_async(True, after_event) + + return grad_x, None + + +def _combine_setup_context(ctx, inputs, output): + x, cache_id = inputs + ctx.cache_id_int = cache_id.item() + ctx.saved_handle = _handle_cache.get(ctx.cache_id_int) + + +torch.library.register_autograd( + "deepep::dispatch", _dispatch_backward, setup_context=_dispatch_setup_context +) +torch.library.register_autograd( + "deepep::combine", _combine_backward, setup_context=_combine_setup_context +) + + +def _create_event_if_async(async_finish: bool): + """Create EventOverlap handle if async mode is enabled.""" + return EventOverlap(EventHandle()) if async_finish else None + + +def _sync_stream_if_async(async_finish: bool, after_event): + """Synchronize current stream with communication stream if async mode is enabled.""" + if async_finish and after_event is not None: + after_event.current_stream_wait() + + +def get_hidden_bytes(x: torch.Tensor) -> int: + """Calculate the number of hidden bytes for a tensor.""" + return x.size(1) * max(x.element_size(), 2) + + +def get_buffer(group: ProcessGroup, hidden_bytes: int) -> Buffer: + """Get or create a buffer for all-to-all communication.""" + global _buffer + num_nvl_bytes, num_rdma_bytes = 0, 0 + for config in ( + Buffer.get_dispatch_config(group.size()), + Buffer.get_combine_config(group.size()), + ): + num_nvl_bytes = max( + config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes + ) + num_rdma_bytes = max( + config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes + ) + + if ( + _buffer is None + or _buffer.group != group + or _buffer.num_nvl_bytes < num_nvl_bytes + or _buffer.num_rdma_bytes < num_rdma_bytes + ): + _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes) + + return _buffer + + +def _indices_to_multihot( + indices: torch.Tensor, scores: torch.Tensor, num_local_experts: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert topk indices to multihot format for permutation.""" + batch_size = indices.shape[0] + multihot_routing_map = torch.zeros( + (batch_size, num_local_experts), dtype=torch.long, device=indices.device + ) + multihot_scores = torch.zeros( + (batch_size, num_local_experts), dtype=scores.dtype, device=indices.device + ) + + mask = indices != -1 + valid_indices = indices[mask] + row_indices = torch.arange(batch_size, device=indices.device).repeat_interleave( + mask.sum(dim=1) + ) + multihot_routing_map[row_indices, valid_indices] = 1 + multihot_scores[row_indices, valid_indices] = scores[mask] + + return multihot_routing_map.bool(), multihot_scores + + +def _permute_tokens( + tokens: torch.Tensor, + routing_map: torch.Tensor, + scores: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + """Permute tokens by expert for grouped_mm. + + Returns: + (permuted_tokens, permuted_scores, sorted_indices) + """ + num_tokens = tokens.shape[0] + num_experts = routing_map.shape[1] + + routing_map_t = routing_map.bool().T.contiguous() + token_indices = torch.arange(num_tokens, device=routing_map.device) + token_indices = token_indices.unsqueeze(0).expand(num_experts, -1) + sorted_indices = token_indices.masked_select(routing_map_t) + sorted_tokens = tokens.index_select(0, sorted_indices) + + if scores is not None: + sorted_scores = scores.T.contiguous().masked_select(routing_map_t) + else: + sorted_scores = None + + return sorted_tokens, sorted_scores, sorted_indices + + +def _unpermute_tokens( + permuted_tokens: torch.Tensor, + sorted_indices: torch.Tensor, + num_tokens: int, +) -> torch.Tensor: + """Reverse permutation applied by _permute_tokens.""" + hidden = permuted_tokens.shape[1] + output_tokens = torch.zeros( + (num_tokens, hidden), dtype=permuted_tokens.dtype, device=permuted_tokens.device + ) + output_tokens.scatter_add_( + 0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens + ) + return output_tokens + + +@dataclass +class DispatchState: + """State from dispatch needed for combine.""" + + cache_id: torch.Tensor # CPU tensor used to retrieve cached handle + sorted_indices: torch.Tensor + num_recv_tokens: int + permuted_scores: Optional[torch.Tensor] = None + + +def dispatch_tokens( + hidden_states: torch.Tensor, + selected_experts_indices: torch.Tensor, + top_scores: torch.Tensor, + num_local_experts: int, + num_experts: int, + group: ProcessGroup, + score_before_experts: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, DispatchState]: + """Dispatch tokens to experts via DeepEP. + + Args: + hidden_states: Input tokens [num_tokens, hidden_dim] + selected_experts_indices: Expert indices for each token [num_tokens, top_k] + top_scores: Routing scores for each token [num_tokens, top_k] + num_local_experts: Number of experts on this rank + num_experts: Total number of experts across all ranks + group: EP process group + score_before_experts: If True, apply routing scores before expert computation. + + Returns: + (permuted_tokens, tokens_per_expert, state_for_combine) + """ + # Ensure contiguous and proper shape + router_topk = ( + selected_experts_indices.shape[1] if selected_experts_indices.dim() == 2 else 1 + ) + if selected_experts_indices.dim() != 2: + selected_experts_indices = selected_experts_indices.view( + -1, router_topk + ).contiguous() + top_scores = top_scores.view(-1, router_topk).contiguous() + else: + selected_experts_indices = selected_experts_indices.contiguous() + top_scores = top_scores.contiguous() + + # Mask out zero-score tokens + selected_experts_indices = selected_experts_indices.masked_fill(top_scores == 0, -1) + + # Ensure float32 scores (DeepEP requirement) + if top_scores.dtype != torch.float32: + top_scores = top_scores.float() + + buffer = get_buffer(group, get_hidden_bytes(hidden_states)) + + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert_dispatch, + is_token_in_rank, + _, + ) = buffer.get_dispatch_layout( + topk_idx=selected_experts_indices, num_experts=num_experts + ) + + ( + hidden_states, + dispatched_indices, + dispatched_expert_scores, + tokens_per_expert, + cache_id, + ) = torch.ops.deepep.dispatch( + hidden_states, + selected_experts_indices, + top_scores, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + is_token_in_rank, + num_tokens_per_expert_dispatch, + ) + + dispatched_routing_map, dispatched_expert_scores_multihot = _indices_to_multihot( + dispatched_indices, dispatched_expert_scores, num_local_experts + ) + + num_recv_tokens = hidden_states.shape[0] + + # Sort tokens by expert for grouped_mm + hidden_states, permuted_scores, sorted_indices = _permute_tokens( + hidden_states, dispatched_routing_map, scores=dispatched_expert_scores_multihot + ) + + # Compute tokens_per_expert from routing_map (matches the sorted tokens) + tokens_per_expert = ( + dispatched_routing_map.sum(dim=0).to(torch.int32).to(hidden_states.device) + ) + + if score_before_experts and permuted_scores is not None: + # Avoid float32 conversion to save memory + hidden_states = hidden_states * permuted_scores.to(hidden_states.dtype).reshape( + -1, 1 + ) + permuted_scores_for_state = None + else: + permuted_scores_for_state = permuted_scores + + state = DispatchState( + cache_id=cache_id, + sorted_indices=sorted_indices, + num_recv_tokens=num_recv_tokens, + permuted_scores=permuted_scores_for_state, + ) + + return hidden_states, tokens_per_expert, state + + +def combine_tokens( + hidden_states: torch.Tensor, + state: DispatchState, +) -> torch.Tensor: + """Combine tokens from experts via DeepEP.""" + if state.permuted_scores is not None: + # In-place multiplication to save memory + hidden_states = hidden_states * state.permuted_scores.to( + hidden_states.dtype + ).reshape(-1, 1) + + hidden_states = _unpermute_tokens( + hidden_states, state.sorted_indices, state.num_recv_tokens + ) + + hidden_states = torch.ops.deepep.combine(hidden_states, state.cache_id) + + return hidden_states diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index ca5cdd1d54..8ee53e754e 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -315,3 +315,70 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: # pyrefly: ignore [bad-argument-type] output_fn=self._prepare_output_fn, ) + + +class DeepEPExpertParallel(BaseExpertParallel): + """Expert Parallel using DeepEP for efficient token dispatch/combine. + + Expects inputs as: + (hidden_states, num_tokens_per_expert, selected_experts_indices, top_scores, num_experts) + + Args: + score_before_experts: If True, apply routing scores before expert computation. + """ + + def __init__(self, score_before_experts: bool = True): + super().__init__() + self._state = None # State preserved between dispatch and combine + self.score_before_experts = score_before_experts + + def _token_dispatch(self, mod, inputs, device_mesh): + """Dispatch tokens via DeepEP.""" + from torchtitan.distributed.deepep import dispatch_tokens + + hidden_states, _, selected_experts_indices, top_scores, num_experts = inputs + if isinstance(mod.w1, DTensor): + num_local_experts = mod.w1.to_local().shape[0] + else: + num_local_experts = mod.w1.shape[0] + ep_group = device_mesh.get_group() + + hidden_states, tokens_per_expert, self._state = dispatch_tokens( + hidden_states, + selected_experts_indices, + top_scores, + num_local_experts, + num_experts, + ep_group, + score_before_experts=self.score_before_experts, + ) + + return hidden_states, tokens_per_expert + + @staticmethod + def _partition_fn(name, mod, device_mesh): + """Shard expert weights on expert dimension.""" + for param_name, param in mod.named_parameters(recurse=False): + mod.register_parameter( + param_name, + nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])), + ) + + def _token_combine(self, mod, routed_output, device_mesh): + """Combine tokens via DeepEP.""" + from torchtitan.distributed.deepep import combine_tokens + + # pyrefly: ignore [bad-argument-type] + routed_output = combine_tokens(routed_output, self._state) + self._state = None + return routed_output + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + """Apply DeepEP parallelization.""" + return distribute_module( + module, + device_mesh, + partition_fn=DeepEPExpertParallel._partition_fn, + input_fn=self._token_dispatch, # pyrefly: ignore [bad-argument-type] + output_fn=self._token_combine, # pyrefly: ignore [bad-argument-type] + ) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 1295ec9bdb..374df44157 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -15,7 +15,6 @@ RowwiseParallel, SequenceParallel, ) - from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import NoParallel, ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac @@ -29,7 +28,6 @@ ) from torchtitan.tools.logging import logger - # for selective op activation checkpointing _op_sac_save_list = { torch.ops.aten.mm.default, @@ -92,6 +90,30 @@ def parallelize_deepseekv3( ) maybe_enable_async_tp(job_config, tp_mesh) + # Check if using DeepEP for MoE communication + if job_config.parallelism.expert_parallel_comm_backend == "deepep": + if not parallel_dims.ep_enabled: + raise ValueError( + "DeepEP requires expert parallelism (ep_degree > 1). " + "The DeepEP MoE model code does not support EP=1. " + "Please set expert_parallel_degree > 1 or use standard communication backend." + ) + if parallel_dims.etp_enabled: + raise NotImplementedError( + "DeepEP with Expert Tensor Parallelism (ETP) is not supported yet. " + "Please set expert_tensor_parallel_degree=1 or use standard communication backend." + ) + + use_deepep = True + + # Import deepep module to register custom ops before accessing them + import torchtitan.distributed.deepep # noqa: F401 - registers torch.ops.deepep + + _op_sac_save_list.add(torch.ops.deepep.dispatch.default) + _op_sac_save_list.add(torch.ops.deepep.combine.default) + else: + use_deepep = False + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) @@ -102,6 +124,7 @@ def parallelize_deepseekv3( etp_mesh=parallel_dims.get_optional_mesh("etp"), ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), dual_pipe_v=dual_pipe_v, + use_deepep=use_deepep, ) model_compile_enabled = ( diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 6609e6fa4e..f880a53384 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -10,7 +10,6 @@ from dataclasses import dataclass, field from torch import nn - from torchtitan.config import JobConfig from torchtitan.models.moe import MoEArgs from torchtitan.models.utils import get_moe_model_nparams_and_flops @@ -65,6 +64,9 @@ class DeepSeekV3ModelArgs(BaseModelArgs): # MoE moe_args: MoEArgs = field(default_factory=MoEArgs) + # Expert parallel communication backend (set from config) + expert_parallel_comm_backend: str = "standard" # "standard" or "deepep" + # Multi-Head Latent Attention (MLA) q_lora_rank: int = 0 kv_lora_rank: int = 512 @@ -106,6 +108,9 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: job_config.debug.moe_force_load_balance ) + # Configure expert parallel communication backend from config (defaults to "standard") + self.moe_impl = job_config.parallelism.expert_parallel_comm_backend + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: return get_moe_model_nparams_and_flops( self, diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 26e0cff2f3..fdc6ef56a7 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -8,9 +8,7 @@ import torch from torch import nn - from torch.nn.attention.flex_attention import and_masks, BlockMask - from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.attention import ( create_attention_mask, @@ -19,7 +17,7 @@ get_document_mask_mod, ScaledDotProductAttentionWrapper, ) -from torchtitan.models.moe import FeedForward, MoE +from torchtitan.models.moe import build_moe, FeedForward from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol @@ -351,10 +349,11 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.moe_enabled = layer_id >= model_args.n_dense_layers if self.moe_enabled: - self.moe = MoE( - model_args.moe_args, + self.moe = build_moe( + args=model_args.moe_args, dim=model_args.dim, hidden_dim=model_args.moe_inter_dim, + moe_impl=model_args.moe_impl, ) else: self.feed_forward = FeedForward(model_args.dim, model_args.inter_dim) @@ -395,6 +394,7 @@ def init_weights(self, buffer_device: torch.device): norm.reset_parameters() self.attention.init_weights(self.weight_init_std) if self.moe_enabled: + # pyrefly: ignore [not-callable, missing-attribute] self.moe.init_weights(self.weight_init_std, buffer_device) else: self.feed_forward.init_weights(self.weight_init_std) diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index c87fac1a89..454679ff55 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -30,9 +30,9 @@ DualPipeExpertParallel, get_dual_pipe_v_flag, ) - from torchtitan.distributed.expert_parallel import ( BaseExpertParallel, + DeepEPExpertParallel, ExpertParallel, ExpertTensorParallel, ReordererSequenceParallel, @@ -43,7 +43,6 @@ from torchtitan.models.moe import moe as moe_module from torchtitan.tools.logging import logger - # for selective op activation checkpointing _op_sac_save_list = { torch.ops.aten.mm.default, @@ -107,6 +106,30 @@ def parallelize_llama( ) maybe_enable_async_tp(job_config, tp_mesh) + # Check if using DeepEP for MoE communication + if job_config.parallelism.expert_parallel_comm_backend == "deepep": + if not parallel_dims.ep_enabled: + raise ValueError( + "DeepEP requires expert parallelism (ep_degree > 1). " + "The DeepEP MoE model code does not support EP=1. " + "Please set expert_parallel_degree > 1 or use standard communication backend." + ) + if parallel_dims.etp_enabled: + raise NotImplementedError( + "DeepEP with Expert Tensor Parallelism (ETP) is not supported yet. " + "Please set expert_tensor_parallel_degree=1 or use standard communication backend." + ) + + use_deepep = True + + # Import deepep module to register custom ops before accessing them + import torchtitan.distributed.deepep # noqa: F401 - registers torch.ops.deepep + + _op_sac_save_list.add(torch.ops.deepep.dispatch.default) + _op_sac_save_list.add(torch.ops.deepep.combine.default) + else: + use_deepep = False + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) @@ -120,12 +143,18 @@ def parallelize_llama( etp_mesh=parallel_dims.get_optional_mesh("etp"), ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), dual_pipe_v=dual_pipe_v, + use_deepep=use_deepep, ) model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) if job_config.activation_checkpoint.mode != "none": + if job_config.activation_checkpoint.selective_ac_option == "op": + logger.info( + f"SAC save list contains {len(_op_sac_save_list)} ops: " + f"{sorted([str(op) for op in _op_sac_save_list])}" + ) apply_ac( model, job_config.activation_checkpoint, @@ -472,6 +501,7 @@ def apply_moe_ep_tp( etp_mesh: DeviceMesh | None, ep_etp_mesh: DeviceMesh | None, dual_pipe_v: bool = False, + use_deepep: bool = False, ): assert ep_mesh is not None or tp_mesh is not None @@ -531,8 +561,17 @@ def apply_moe_ep_tp( elif tp_mesh is None or etp_mesh is None: assert ep_etp_mesh is None experts_mesh = ep_mesh - # input / output sharding on the batch / tokens dim - experts_plan = ExpertParallel() + if use_deepep: + # pyrefly: ignore [missing-attribute] + score_before_experts = transformer_block.moe.score_before_experts + + experts_plan = DeepEPExpertParallel( + score_before_experts=score_before_experts, + ) + logger.info("Applying DeepEP to MoE layer") + else: + # input / output sharding on the batch / tokens dim + experts_plan = ExpertParallel() else: experts_mesh = ep_etp_mesh experts_plan = ExpertTensorParallel() diff --git a/torchtitan/models/moe/__init__.py b/torchtitan/models/moe/__init__.py index c8247ec7fb..5ae6250e17 100644 --- a/torchtitan/models/moe/__init__.py +++ b/torchtitan/models/moe/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .moe import FeedForward, MoE, MoEArgs +from .moe import build_moe, FeedForward, MoE, MoEArgs -__all__ = ["FeedForward", "MoE", "MoEArgs"] +__all__ = ["FeedForward", "MoE", "MoEArgs", "build_moe"] diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index c5dd59ab29..90e6418972 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -12,6 +12,7 @@ from torch import nn from torch.distributed.tensor import DTensor +from torchtitan.tools.logging import logger from .utils import indices_padding_wrapper @@ -565,3 +566,21 @@ def init_weights( self.expert_bias = torch.zeros( self.experts.num_experts, dtype=torch.float32 ) + + +def build_moe( + args: MoEArgs, dim: int, hidden_dim: int, moe_impl: str = "standard" +) -> nn.Module: + """Factory for MoE with different backends: 'standard' (all-to-all) or 'deepep' (DeepEP).""" + if moe_impl == "deepep": + from .moe_deepep import DeepEPMoE + + logger.info( + f"DeepEP MoE: num_experts={args.num_experts}, top_k={args.top_k}, dim={dim}, hidden_dim={hidden_dim}" + ) + return DeepEPMoE(moe_args=args, dim=dim, hidden_dim=hidden_dim) + + logger.info( + f"Standard MoE: num_experts={args.num_experts}, top_k={args.top_k}, dim={dim}, hidden_dim={hidden_dim}" + ) + return MoE(args, dim=dim, hidden_dim=hidden_dim) diff --git a/torchtitan/models/moe/moe_deepep.py b/torchtitan/models/moe/moe_deepep.py new file mode 100644 index 0000000000..54e3f0f2a3 --- /dev/null +++ b/torchtitan/models/moe/moe_deepep.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""MoE with DeepEP backend for efficient expert-parallel communication.""" + +import torch + +from .moe import MoE, MoEArgs + + +class DeepEPMoE(MoE): + """ + Mixture of Experts with DeepEP communication. + + Inherits from MoE but overrides forward() to pass routing info to experts, + letting DeepEPExpertParallel hooks handle dispatch/combine. + """ + + def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): + super().__init__(moe_args, dim, hidden_dim) + # DeepEP doesn't use reorderer - routing handled by DeepEPExpertParallel + self.reorderer = None # pyrefly: ignore [bad-assignment] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with DeepEP communication. + + DeepEPExpertParallel hooks intercept experts() call and handle + dispatch/combine via deepep functions. + """ + bs, slen, dim = x.shape + x = x.view(-1, dim) + + top_scores, selected_experts_indices, num_tokens_per_expert = self.router( + x, self.expert_bias + ) + + if self.load_balance_coeff is not None: + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) + + # Call experts with routing info - hooks handle DeepEP dispatch/combine + routed_output = self.experts( + x, + num_tokens_per_expert, + selected_experts_indices, + top_scores, + self.experts.num_experts, + ) + + out = self.shared_experts(x) if self.shared_experts is not None else None + + if out is None: + return routed_output.reshape(bs, slen, dim) + return (out + routed_output).reshape(bs, slen, dim) From 443876426087273ca14a1945b556dc16563db4e1 Mon Sep 17 00:00:00 2001 From: Salman Chishti <13schishti@gmail.com> Date: Fri, 19 Dec 2025 01:42:54 +0000 Subject: [PATCH 075/127] Fix pypa/gh-action-pypi-publish version to use SHA pinning (#2161) ## Summary Fix incorrect version reference for `pypa/gh-action-pypi-publish`. ## Problem A previous PR incorrectly changed the action reference from `release/v1` (valid branch) to `v1` (non-existent tag). The `v1` tag doesn't exist in the pypa/gh-action-pypi-publish repository. ## Solution Updated to use SHA pinning for release/v1.13: ```yaml uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # release/v1.13 ``` This follows [GitHub's security best practices](https://docs.github.com/en/actions/reference/security/secure-use#using-third-party-actions) for third-party actions by pinning to an immutable SHA. ## Files Changed - `.github/workflows/release.yml` --------- Signed-off-by: Salman Muin Kayser Chishti <13schishti@gmail.com> --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index cdde1f92bc..1664acdc58 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -46,4 +46,4 @@ jobs: path: dist/ - name: Publish release distributions to PyPI - uses: pypa/gh-action-pypi-publish@v1 + uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # release/v1.13 From fd49b4b4e7f44575317fd489c33cbc802545273b Mon Sep 17 00:00:00 2001 From: Salman Chishti <13schishti@gmail.com> Date: Fri, 19 Dec 2025 01:43:31 +0000 Subject: [PATCH 076/127] Upgrade GitHub Actions for Node 24 compatibility (#2164) ## Summary Upgrade GitHub Actions to their latest versions to ensure compatibility with Node 24, as Node 20 will reach end-of-life in April 2026. ## Changes | Action | Old Version(s) | New Version | Release | Files | |--------|---------------|-------------|---------|-------| | `actions/checkout` | [`v3`](https://github.com/actions/checkout/releases/tag/v3) | [`v6`](https://github.com/actions/checkout/releases/tag/v6) | [Release](https://github.com/actions/checkout/releases/tag/v6) | lint.yaml | | `actions/setup-python` | [`v4`](https://github.com/actions/setup-python/releases/tag/v4) | [`v6`](https://github.com/actions/setup-python/releases/tag/v6) | [Release](https://github.com/actions/setup-python/releases/tag/v6) | lint.yaml | ## Context Per [GitHub's announcement](https://github.blog/changelog/2025-09-19-deprecation-of-node-20-on-github-actions-runners/), Node 20 is being deprecated and runners will begin using Node 24 by default starting March 4th, 2026. ### Why this matters - **Node 20 EOL**: April 2026 - **Node 24 default**: March 4th, 2026 - **Action**: Update to latest action versions that support Node 24 ### Security Note Actions that were previously pinned to commit SHAs remain pinned to SHAs (updated to the latest release SHA) to maintain the security benefits of immutable references. ### Testing These changes only affect CI/CD workflow configurations and should not impact application functionality. The workflows should be tested by running them on a branch before merging. Signed-off-by: Salman Muin Kayser Chishti <13schishti@gmail.com> --- .github/workflows/lint.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 327b0bec23..f23820d699 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -19,9 +19,9 @@ jobs: python-version: ['3.10'] steps: - name: Check out repo - uses: actions/checkout@v3 + uses: actions/checkout@v6 - name: Setup python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Update pip From 658f94cb9b27963784268185ed75c524f3b9b2be Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Thu, 18 Dec 2025 17:45:49 -0800 Subject: [PATCH 077/127] Expose common dataloader args (#2097) This diff introduces common dataloader args which are supported by statefuldataloader (and torch.utils.data dataloader). Users should be able to use them in their config files. I was thinking about introducing a catch all kwargs to make it easier to specify args but that can easily complicate things (validation checks, duplication, existing defined named args in function definitions etc). --- tests/integration_tests/features.py | 15 ++ tests/unit_tests/test_dataloader.py | 153 ++++++++++++++++++ torchtitan/components/dataloader.py | 52 +++++- torchtitan/config/job_config.py | 40 +++++ .../experiments/vlm/datasets/mm_datasets.py | 22 ++- torchtitan/hf_datasets/text_datasets.py | 39 ++++- torchtitan/models/flux/flux_datasets.py | 38 ++++- 7 files changed, 332 insertions(+), 27 deletions(-) create mode 100644 tests/unit_tests/test_dataloader.py diff --git a/tests/integration_tests/features.py b/tests/integration_tests/features.py index fe51ab7cf7..8e16ecb4fb 100755 --- a/tests/integration_tests/features.py +++ b/tests/integration_tests/features.py @@ -557,6 +557,21 @@ def build_features_test_list() -> list[OverrideDefinitions]: "validation_tp_cp_pp", ngpu=8, ), + OverrideDefinitions( + [ + [ + "--training.dataloader.num_workers", + "2", + "--training.dataloader.pin_memory", + "--training.dataloader.persistent_workers", + "--training.dataloader.prefetch_factor", + "4", + ], + ], + "Dataloader kwargs (via CLI args)", + "dataloader_kwargs", + ngpu=2, + ), ] return integration_tests_flavors diff --git a/tests/unit_tests/test_dataloader.py b/tests/unit_tests/test_dataloader.py new file mode 100644 index 0000000000..82625e5e06 --- /dev/null +++ b/tests/unit_tests/test_dataloader.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from torch.utils.data import IterableDataset + +from torchtitan.components.dataloader import ParallelAwareDataloader +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.config import ConfigManager + + +class DummyDataset(IterableDataset): + """A simple dummy dataset for testing.""" + + def __iter__(self): + for i in range(100): + yield {"input": i}, i + + +class DummyTokenizer(BaseTokenizer): + """A dummy tokenizer for testing that implements BaseTokenizer interface.""" + + def __init__(self): + super().__init__() + self.eos_id = 2 + + def encode( + self, text: str, add_bos: bool = False, add_eos: bool = False + ) -> list[int]: + # Simple encoding: convert each character to its ASCII value + tokens = [ord(c) for c in text] + if add_bos: + tokens.insert(0, 1) # BOS token + if add_eos: + tokens.append(self.eos_id) + return tokens + + def decode(self, token_ids: list[int]) -> str: + # Simple decoding: convert ASCII values back to characters + return "".join(chr(t) for t in token_ids if t > 2) + + def get_vocab_size(self) -> int: + return 256 # ASCII range + + +class TestParallelAwareDataloader(unittest.TestCase): + def test_dataloader_yields_correct_batches(self): + """Test that the dataloader correctly yields batched data from the dataset.""" + dataset = DummyDataset() + batch_size = 4 + + dataloader = ParallelAwareDataloader( + dataset, + dp_rank=0, + dp_world_size=1, + batch_size=batch_size, + ) + + batches = list(dataloader) + + # DummyDataset yields 100 items, so we expect 25 batches of size 4 + self.assertEqual(len(batches), 25) + + # Check first batch structure and values + first_batch_input, first_batch_label = batches[0] + self.assertEqual(len(first_batch_input["input"]), batch_size) + self.assertEqual(len(first_batch_label), batch_size) + + # Verify first batch contains expected values (0, 1, 2, 3) + self.assertEqual(first_batch_input["input"].tolist(), [0, 1, 2, 3]) + self.assertEqual(first_batch_label.tolist(), [0, 1, 2, 3]) + + # Check last batch + last_batch_input, last_batch_label = batches[-1] + self.assertEqual(last_batch_input["input"].tolist(), [96, 97, 98, 99]) + self.assertEqual(last_batch_label.tolist(), [96, 97, 98, 99]) + + def test_validate_kwargs_rejects_invalid_kwargs(self): + """Test that passing invalid kwargs raises ValueError.""" + dataset = DummyDataset() + + with self.assertRaises(ValueError) as context: + ParallelAwareDataloader( + dataset, + dp_rank=0, + dp_world_size=1, + invalid_arg=42, + ) + + self.assertIn("Invalid dataloader kwargs", str(context.exception)) + self.assertIn("invalid_arg", str(context.exception)) + + def test_config_batch_size_overwritten_by_explicit_batch_size(self): + """Test that batch_size in config kwargs is overwritten by explicit batch_size.""" + dataset = DummyDataset() + + config_kwargs = {"batch_size": 2, "num_workers": 0} + + explicit_batch_size = 8 + + # Merge kwargs with explicit args taking precedence (same pattern as in dataset files) + dataloader_kwargs = { + **config_kwargs, + "batch_size": explicit_batch_size, + } + + dataloader = ParallelAwareDataloader( + dataset, + dp_rank=0, + dp_world_size=1, + **dataloader_kwargs, + ) + + # Verify that batch_size is the explicit one, not the config one + self.assertEqual(dataloader.batch_size, explicit_batch_size) + + def test_build_dataloader_with_job_config(self): + """Verify batch_size from job_config.training.local_batch_size is correctly used.""" + from torchtitan.hf_datasets.text_datasets import build_text_dataloader + + tokenizer = DummyTokenizer() + + config_manager = ConfigManager() + config = config_manager.parse_args( + [ + "--training.dataset", + "c4_test", + "--training.local_batch_size", + "8", + "--training.seq_len", + "512", + "--training.dataloader.num_workers", + "2", + ] + ) + + dataloader = build_text_dataloader( + tokenizer=tokenizer, + dp_world_size=1, + dp_rank=0, + job_config=config, + ) + + self.assertEqual(dataloader.batch_size, 8) + self.assertEqual(dataloader.num_workers, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchtitan/components/dataloader.py b/torchtitan/components/dataloader.py index 7a1c1fcad6..a1fc08e39f 100644 --- a/torchtitan/components/dataloader.py +++ b/torchtitan/components/dataloader.py @@ -6,9 +6,9 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. +import inspect import pickle from abc import ABC, abstractmethod -from collections.abc import Callable from typing import Any from torch.distributed.checkpoint.stateful import Stateful @@ -16,6 +16,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from torchtitan.tools.logging import logger + # NOTE: This class deliberately inherits from `Exception` and not `StopIteration`. # According to PEP 479, raising a `StopIteration` or its subclass from within a # generator will wrap it in a `RuntimeError`. Since this exception is designed @@ -53,28 +54,63 @@ class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader): dataset (IterableDataset): The dataset to iterate over. dp_rank: Data parallelism rank for this dataloader. dp_world_size: The world size of the data parallelism. - batch_size: The batch size to use for each iteration. - collate_fn: Optional function to collate samples in a batch. + **kwargs: Additional keyword arguments passed to StatefulDataLoader (e.g., + batch_size, collate_fn, num_workers, persistent_workers, prefetch_factor, + pin_memory). """ dp_rank: int dp_world_size: int - batch_size: int | None def __init__( self, dataset: IterableDataset, dp_rank: int, dp_world_size: int, - batch_size: int, - collate_fn: Callable | None = None, + **kwargs, ): + self._validate_kwargs(kwargs) + self.dp_world_size = dp_world_size self.dp_rank = dp_rank - self.batch_size = batch_size - super().__init__(dataset, batch_size, collate_fn=collate_fn) self._rank_id = f"dp_rank_{dp_rank}" + super().__init__(dataset, **kwargs) + + @staticmethod + def _validate_kwargs(kwargs: dict[str, Any]) -> None: + """Validate and sanitize kwargs passed to the dataloader. + + Args: + kwargs: Dictionary of keyword arguments to validate. This dict is + modified in-place to remove invalid combinations. + + Raises: + ValueError: If 'dataset' is in kwargs or if any invalid kwargs are passed. + """ + if "dataset" in kwargs: + raise ValueError( + "'dataset' should not be passed in kwargs; " + "it must be provided as the first positional argument." + ) + + sig = inspect.signature(StatefulDataLoader.__init__) + valid_kwargs = frozenset( + name for name in sig.parameters.keys() if name not in ("self", "dataset") + ) + invalid_kwargs = set(kwargs.keys()) - valid_kwargs + if invalid_kwargs: + raise ValueError( + f"Invalid dataloader kwargs: {invalid_kwargs}. " + f"Valid kwargs are: {sorted(valid_kwargs)}" + ) + + # persistent_workers and prefetch_factor are only valid when num_workers > 0. + # Removing them here if num_workers is 0 to avoid StatefulDataLoader errors + if kwargs.get("num_workers", 0) == 0: + kwargs.pop("persistent_workers", None) + kwargs.pop("prefetch_factor", None) + def state_dict(self) -> dict[str, Any]: # Store state only for dp rank to avoid replicating the same state across other dimensions. return { diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 504f00419f..108c38efba 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -197,6 +197,40 @@ class LRScheduler: """ +@dataclass +class DataLoader: + """ + Configuration for PyTorch DataLoader settings. + + These settings are passed directly to StatefulDataLoader. + + Note: + persistent_workers and prefetch_factor are only valid if num_workers > 0. + + Example (TOML config file): + [training.dataloader] + num_workers = 4 + pin_memory = true + persistent_workers = true + prefetch_factor = 2 + """ + + num_workers: int = 0 + """Number of worker processes for data loading.""" + + persistent_workers: bool = False + """Keep workers alive between epochs. Only valid when num_workers > 0.""" + + pin_memory: bool = False + """Copy tensors to CUDA pinned memory before returning them.""" + + prefetch_factor: int | None = None + """ + Number of batches loaded in advance by each worker. Only valid when num_workers > 0. + Default is 2 when num_workers > 0, otherwise None. + """ + + @dataclass class Training: dataset: str = "c4_test" @@ -262,6 +296,9 @@ class Training: many temporary files. """ + dataloader: DataLoader = field(default_factory=DataLoader) + """DataLoader configuration""" + @dataclass class Parallelism: @@ -912,6 +949,9 @@ class Validation: WARNING: When setting to -1 there could be hangs due to mismatch among ranks """ + dataloader: DataLoader = field(default_factory=DataLoader) + """DataLoader configuration""" + def __post_init__(self): assert ( self.steps > 0 or self.steps == -1 diff --git a/torchtitan/experiments/vlm/datasets/mm_datasets.py b/torchtitan/experiments/vlm/datasets/mm_datasets.py index 2a6a2000b9..2234496983 100644 --- a/torchtitan/experiments/vlm/datasets/mm_datasets.py +++ b/torchtitan/experiments/vlm/datasets/mm_datasets.py @@ -11,6 +11,7 @@ It supports both streaming and non-streaming datasets from HuggingFace. """ +from dataclasses import asdict from typing import Any, Callable import torch @@ -381,14 +382,14 @@ def build_mm_dataloader( """Build a data loader for multimodal datasets. Args: - dp_world_size: Data parallel world size - dp_rank: Data parallel rank - tokenizer: Tokenizer for text processing - job_config: Job configuration - infinite: Whether to loop infinitely + dp_world_size: Data parallel world size. + dp_rank: Data parallel rank. + tokenizer: Tokenizer for text processing. + job_config: Job configuration containing dataset and DataLoader settings. + infinite: Whether to loop infinitely. Returns: - DataLoader with appropriate parallelism handling + DataLoader with appropriate parallelism handling. """ dataset_path = job_config.training.dataset_path batch_size = job_config.training.local_batch_size @@ -429,12 +430,17 @@ def build_mm_dataloader( special_tokens=special_tokens, ) + dataloader_kwargs = { + **asdict(job_config.training.dataloader), + "batch_size": batch_size, + "collate_fn": collate_fn, + } + base_dataloader = ParallelAwareDataloader( dataset=dataset, dp_rank=dp_rank, dp_world_size=dp_world_size, - batch_size=batch_size, - collate_fn=collate_fn, + **dataloader_kwargs, ) return base_dataloader diff --git a/torchtitan/hf_datasets/text_datasets.py b/torchtitan/hf_datasets/text_datasets.py index 63790b8862..586abf2bce 100644 --- a/torchtitan/hf_datasets/text_datasets.py +++ b/torchtitan/hf_datasets/text_datasets.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import asdict from functools import partial from typing import Any, Callable @@ -172,7 +173,15 @@ def build_text_dataloader( job_config: JobConfig, infinite: bool = True, ) -> ParallelAwareDataloader: - """Build a data loader for HuggingFace datasets.""" + """Build a data loader for HuggingFace datasets. + + Args: + dp_world_size: Data parallelism world size. + dp_rank: Data parallelism rank. + tokenizer: Tokenizer to use for encoding text. + job_config: Job configuration containing dataset and DataLoader settings. + infinite: Whether to loop the dataset infinitely. + """ dataset_name = job_config.training.dataset dataset_path = job_config.training.dataset_path batch_size = job_config.training.local_batch_size @@ -188,11 +197,16 @@ def build_text_dataloader( infinite=infinite, ) + dataloader_kwargs = { + **asdict(job_config.training.dataloader), + "batch_size": batch_size, + } + return ParallelAwareDataloader( - dataset=hf_ds, + hf_ds, dp_rank=dp_rank, dp_world_size=dp_world_size, - batch_size=batch_size, + **dataloader_kwargs, ) @@ -203,7 +217,15 @@ def build_text_validation_dataloader( job_config: JobConfig, infinite: bool = False, ) -> ParallelAwareDataloader: - """Build a validation data loader for HuggingFace datasets.""" + """Build a validation data loader for HuggingFace datasets. + + Args: + dp_world_size: Data parallelism world size. + dp_rank: Data parallelism rank. + tokenizer: Tokenizer to use for encoding text. + job_config: Job configuration containing dataset and DataLoader settings. + infinite: Whether to loop the dataset infinitely. + """ dataset_name = job_config.validation.dataset dataset_path = job_config.validation.dataset_path batch_size = job_config.validation.local_batch_size @@ -219,9 +241,14 @@ def build_text_validation_dataloader( infinite=infinite, ) + dataloader_kwargs = { + **asdict(job_config.validation.dataloader), + "batch_size": batch_size, + } + return ParallelAwareDataloader( - dataset=hf_ds, + hf_ds, dp_rank=dp_rank, dp_world_size=dp_world_size, - batch_size=batch_size, + **dataloader_kwargs, ) diff --git a/torchtitan/models/flux/flux_datasets.py b/torchtitan/models/flux/flux_datasets.py index 906b669001..5b6492dff1 100644 --- a/torchtitan/models/flux/flux_datasets.py +++ b/torchtitan/models/flux/flux_datasets.py @@ -6,6 +6,7 @@ import itertools import math +from dataclasses import asdict from typing import Any, Callable, Optional import numpy as np @@ -316,7 +317,15 @@ def build_flux_dataloader( tokenizer: FluxTokenizer | None, infinite: bool = True, ) -> ParallelAwareDataloader: - """Build a data loader for HuggingFace datasets.""" + """Build a data loader for HuggingFace datasets. + + Args: + dp_world_size: Data parallelism world size. + dp_rank: Data parallelism rank. + job_config: Job configuration containing dataset and DataLoader settings. + tokenizer: Tokenizer (kept for compatibility, not used). + infinite: Whether to loop the dataset infinitely. + """ dataset_name = job_config.training.dataset dataset_path = job_config.training.dataset_path batch_size = job_config.training.local_batch_size @@ -334,11 +343,16 @@ def build_flux_dataloader( infinite=infinite, ) + dataloader_kwargs = { + **asdict(job_config.training.dataloader), + "batch_size": batch_size, + } + return ParallelAwareDataloader( dataset=ds, dp_rank=dp_rank, dp_world_size=dp_world_size, - batch_size=batch_size, + **dataloader_kwargs, ) @@ -402,7 +416,16 @@ def build_flux_validation_dataloader( generate_timestamps: bool = True, infinite: bool = False, ) -> ParallelAwareDataloader: - """Build a data loader for HuggingFace datasets.""" + """Build a validation data loader for HuggingFace datasets. + + Args: + dp_world_size: Data parallelism world size. + dp_rank: Data parallelism rank. + job_config: Job configuration containing dataset and DataLoader settings. + tokenizer: Tokenizer (kept for compatibility, not used). + generate_timestamps: Whether to generate timesteps for validation. + infinite: Whether to loop the dataset infinitely. + """ dataset_name = job_config.validation.dataset dataset_path = job_config.validation.dataset_path batch_size = job_config.validation.local_batch_size @@ -421,9 +444,14 @@ def build_flux_validation_dataloader( infinite=infinite, ) + dataloader_kwargs = { + **asdict(job_config.validation.dataloader), + "batch_size": batch_size, + } + return ParallelAwareDataloader( - dataset=ds, + ds, dp_rank=dp_rank, dp_world_size=dp_world_size, - batch_size=batch_size, + **dataloader_kwargs, ) From b786a3d9c1a011193cbee5b03d9c0882c2f30450 Mon Sep 17 00:00:00 2001 From: Walker <33346657+EquationWalker@users.noreply.github.com> Date: Sat, 20 Dec 2025 03:09:00 +0800 Subject: [PATCH 078/127] Replace `logger.warn()` to `logger.warning()` , allow `log_validation` to log `extra_metrics` and expose common wandb args (#2166) 1. Replace `logger.warn()` to `logger.warning()` 2. allow `log_validation` to log `extra_metrics` 3. expose common wandb init args, it is userful when resume training. --- torchtitan/components/checkpoint.py | 2 +- torchtitan/components/metrics.py | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 7928f514ba..7a7149a061 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -215,7 +215,7 @@ def __init__( if self.ft_manager and not self.enable_ft_dataloader_checkpoints: # pyrefly: ignore [deprecated] - logger.warn( + logger.warning( "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. " "This means replicas can retrain over the same data multiple times, which can result in overfitting." ) diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 6f50337473..93c1e3b10f 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -155,6 +155,13 @@ def __init__(self, log_dir: str, job_config: JobConfig, tag: str | None = None): entity=os.getenv("WANDB_TEAM", None), project=os.getenv("WANDB_PROJECT", "torchtitan"), name=os.getenv("WANDB_RUN_NAME", None), + id=os.getenv("WANDB_RUN_ID", None), + notes=os.getenv("WANDB_RUN_NOTES", None), + tags=os.getenv("WANDB_RUN_TAGS", None), + group=os.getenv("WANDB_RUN_GROUP", None), + job_type=os.getenv("WANDB_RUN_JOB_TYPE", None), + resume_from=os.getenv("WANDB_RESUME_FROM", None), + fork_from=os.getenv("WANDB_FORK_FROM", None), dir=log_dir, config=job_config.to_dict(), ) @@ -461,7 +468,9 @@ def log( self.time_last_log = time.perf_counter() self.device_memory_monitor.reset_peak_stats() - def log_validation(self, loss: float, step: int): + def log_validation( + self, loss: float, step: int, extra_metrics: dict[str, Any] | None = None + ): time_delta = time.perf_counter() - self.time_last_log device_mem_stats = self.device_memory_monitor.get_peak_stats() @@ -479,6 +488,10 @@ def log_validation(self, loss: float, step: int): "validation_metrics/memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, "validation_metrics/memory/max_reserved(%)": device_mem_stats.max_reserved_pct, } + + if extra_metrics: + metrics.update(extra_metrics) + self.logger.log(metrics, step) color = self.color From b21555f3d2a3c4def9916f7a0dbde3b9b3b73e67 Mon Sep 17 00:00:00 2001 From: Salman Chishti <13schishti@gmail.com> Date: Fri, 19 Dec 2025 19:32:47 +0000 Subject: [PATCH 079/127] Add Dependabot for GitHub Actions updates (#2163) ## Summary Add Dependabot configuration to automatically keep GitHub Actions up to date. Here's some more information about Dependabot: https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot ## Changes - Added `.github/dependabot.yml` with weekly checks for GitHub Actions updates ## Context As discussed in #2161 ([comment](https://github.com/pytorch/torchtitan/pull/2161#issuecomment-3667526716)), adding Dependabot to automatically manage GitHub Actions updates going forward. ## Why Dependabot will automatically create PRs when new versions of GitHub Actions are available, helping to: - Keep CI/CD workflows secure with the latest patches - Get new features and improvements - Maintain compatibility with GitHub's infrastructure Each action update will be proposed as a separate PR for individual review and testing. --------- Signed-off-by: Salman Muin Kayser Chishti <13schishti@gmail.com> --- .github/dependabot.yml | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..f6faee6938 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + groups: + github-actions: + patterns: + - "*" From 1bd2548b14da014b1ec560830f8bdefb6ca568f4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 19 Dec 2025 12:36:55 -0800 Subject: [PATCH 080/127] Bump tj-actions/changed-files from d6e91a2266cdb9d62096cebf1e8546899c6aa18f to e0021407031f5be11a464abee9a0776171c79891 in the github-actions group (#2167) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps the github-actions group with 1 update: [tj-actions/changed-files](https://github.com/tj-actions/changed-files). Updates `tj-actions/changed-files` from d6e91a2266cdb9d62096cebf1e8546899c6aa18f to e0021407031f5be11a464abee9a0776171c79891
    Changelog

    Sourced from tj-actions/changed-files's changelog.

    Changelog

    47.0.0 - (2025-09-13)

    🚀 Features

    ➖ Remove

    • Commit and push step from build job (#2538) (be393a9) - (Tonye Jack)

    🔄 Update

    • Updated README.md (#2592)

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@​users.noreply.github.com> (3dbc1e1) - (github-actions[bot])

    • Updated README.md (#2591)

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@​users.noreply.github.com> (b1ccff8) - (github-actions[bot])

    • Updated README.md (#2574)

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@​users.noreply.github.com> (050a3d3) - (github-actions[bot])

    📚 Documentation

    • Update link to glob patterns (#2590) (a892f50) - (Tonye Jack)
    • Add Jellyfrog as a contributor for code, and doc (#2573) (f000a9b) - (allcontributors[bot])

    🧪 Testing

    • Manual triggered workflows (#2637) (c2ca249) - (Tonye Jack)

    ⚙️ Miscellaneous Tasks

    • deps-dev: Bump jest from 30.0.5 to 30.1.3 (#2655) (9a67555) - (dependabot[bot])
    • deps: Bump tj-actions/git-cliff from 2.1.0 to 2.2.0 (#2660) (b67e30d) - (dependabot[bot])
    • deps: Bump github/codeql-action from 3.30.2 to 3.30.3 (#2661) (62aef42) - (dependabot[bot])
    • deps: Bump github/codeql-action from 3.29.11 to 3.30.2 (#2659) (e874f3c) - (dependabot[bot])
    • deps: Bump actions/setup-node from 4.4.0 to 5.0.0 (#2656) (8c14441) - (dependabot[bot])
    • deps-dev: Bump @​types/node from 24.3.0 to 24.3.1 (#2657) (e995ac4) - (dependabot[bot])
    • deps-dev: Bump @​types/node from 24.2.1 to 24.3.0 (#2649) (3b04099) - (dependabot[bot])
    • deps: Bump github/codeql-action from 3.29.9 to 3.29.11 (#2651) (e7b6c97) - (dependabot[bot])
    • deps: Bump tj-actions/git-cliff from 2.0.2 to 2.1.0 (#2648) (765d62b) - (dependabot[bot])
    • deps: Bump github/codeql-action from 3.29.8 to 3.29.9 (#2647) (2036da1) - (dependabot[bot])
    • deps: Bump github/codeql-action from 3.29.7 to 3.29.8 (#2644) (239aef8) - (dependabot[bot])
    • deps-dev: Bump @​types/node from 24.2.0 to 24.2.1 (#2645) (a7d5f5f) - (dependabot[bot])
    • deps: Bump actions/checkout from 4.2.2 to 5.0.0 (#2646) (5107f3a) - (dependabot[bot])
    • deps-dev: Bump @​types/node from 24.1.0 to 24.2.0 (#2640) (f963b3f) - (dependabot[bot])
    • deps: Bump actions/download-artifact from 4.3.0 to 5.0.0 (#2641) (f956744) - (dependabot[bot])

    ... (truncated)

    Commits
    • e002140 chore(deps): bump actions/checkout from 6.0.0 to 6.0.1 (#2729)
    • 01ddfae chore(deps): bump @​actions/core from 1.11.1 to 2.0.0 (#2736)
    • a364493 chore(deps-dev): bump prettier from 3.7.1 to 3.7.4 (#2731)
    • 45a2aae chore(deps): bump actions/setup-node from 6.0.0 to 6.1.0 (#2730)
    • a4f6de3 chore(deps): bump github/codeql-action from 4.31.5 to 4.31.7 (#2732)
    • 95fbe9b chore(deps): bump peter-evans/create-pull-request from 7.0.9 to 8.0.0 (#2735)
    • b3b9724 chore(deps-dev): bump ts-jest from 29.4.5 to 29.4.6 (#2727)
    • 503bc3e chore(deps): bump @​actions/exec from 1.1.1 to 2.0.0 (#2737)
    • 3e9e5a2 chore(deps-dev): bump @​types/node from 24.10.1 to 25.0.0 (#2738)
    • 2b6c719 chore(deps): bump yaml from 2.8.1 to 2.8.2 (#2724)
    • Additional commits viewable in compare view

    Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
    Dependabot commands and options
    You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore major version` will close this group update PR and stop Dependabot creating any more for the specific dependency's major version (unless you unignore this specific dependency's major version or upgrade to it yourself) - `@dependabot ignore minor version` will close this group update PR and stop Dependabot creating any more for the specific dependency's minor version (unless you unignore this specific dependency's minor version or upgrade to it yourself) - `@dependabot ignore ` will close this group update PR and stop Dependabot creating any more for the specific dependency (unless you unignore this specific dependency or upgrade to it yourself) - `@dependabot unignore ` will remove all of the ignore conditions of the specified dependency - `@dependabot unignore ` will remove the ignore condition of the specified dependency and ignore conditions
    Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/lint.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index f23820d699..dece6f0804 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -33,6 +33,6 @@ jobs: pre-commit install-hooks - name: Get changed files id: changed-files - uses: tj-actions/changed-files@d6e91a2266cdb9d62096cebf1e8546899c6aa18f # v45.0.6 + uses: tj-actions/changed-files@e0021407031f5be11a464abee9a0776171c79891 # v45.0.6 - name: Lint modified files run: pre-commit run --files ${{ steps.changed-files.outputs.all_changed_files }} From 4b3d25a0ce2a60276010d23fd0047676767648a3 Mon Sep 17 00:00:00 2001 From: acisseJZhong <40467976+acisseJZhong@users.noreply.github.com> Date: Mon, 22 Dec 2025 14:30:09 -0800 Subject: [PATCH 081/127] Multiprocess simple RL loop (#2158) Borrowed part of the implementations from deterministic RL https://github.com/pytorch/torchtitan/pull/1975 by @bwasti, set up simple RL running on multiple process. Features added: - set up train and generator as Actors; potentially using an unified model definition from Torchtitan. - integrate with Monarch to run Torchtitan trainer on multiple processes using DDP, and vLLM generator using TP. Added TODOs here: https://github.com/pytorch/torchtitan/blob/eb601cfc924b379b589af95bffee2a3a56f0a67f/torchtitan/experiments/rl/unified/README.md#todo Command to run: ``` VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py ``` --- torchtitan/experiments/rl/unified/README.md | 36 +- torchtitan/experiments/rl/unified/__init__.py | 4 +- .../rl/unified/actors/generator.py | 448 ++++++++++++++++++ .../experiments/rl/unified/actors/trainer.py | 136 ++++++ torchtitan/experiments/rl/unified/infer.py | 4 +- .../rl/unified/{ => models}/attention.py | 0 .../{utils.py => models/parallelism_utils.py} | 31 ++ .../experiments/rl/unified/models/utils.py | 147 ++++++ .../rl/unified/{ => models}/vllm_wrapper.py | 39 +- .../rl/unified/simple_rl_multiprocess.py | 184 +++++++ .../experiments/rl/vllm_compat/simple_rl.py | 9 +- 11 files changed, 982 insertions(+), 56 deletions(-) create mode 100644 torchtitan/experiments/rl/unified/actors/generator.py create mode 100644 torchtitan/experiments/rl/unified/actors/trainer.py rename torchtitan/experiments/rl/unified/{ => models}/attention.py (100%) rename torchtitan/experiments/rl/unified/{utils.py => models/parallelism_utils.py} (73%) create mode 100644 torchtitan/experiments/rl/unified/models/utils.py rename torchtitan/experiments/rl/unified/{ => models}/vllm_wrapper.py (87%) create mode 100644 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index 5cea3918ae..fa54a936da 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -44,25 +44,45 @@ rm -rf build dist *.egg-info uv pip uninstall -y vllm # Rebuild vLLM from source with CUDA 12.4 -pip install -e . +uv pip install -e . ``` -3. Download Qwen3/Qwen3-0.6b checkpoint from HuggingFace and put into `example_checkpoint` folder. - +3. Download Qwen/Qwen3-0.6B checkpoint from HuggingFace and put into `torchtitan/experiments/rl/example_checkpoint` folder. +``` +python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=... +``` 4. Run inference: ``` -python torchtitan/experiments/rl/unified/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B +python torchtitan/experiments/rl/unified/infer.py --model torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B ``` Run with TP: (work in progress) ``` -python torchtitan/experiments/rl/unified/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B --tensor-parallel-size 2 +python torchtitan/experiments/rl/unified/infer.py --model torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B --tensor-parallel-size 2 ``` +5. Run simple rl loop +``` +VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +``` +Right now we only support VLLM_COMPAT mode, which could achieve trainer and generator bitwise identical. We are working on support UNIFIED mode, +which uses a unified model definition for trainer and generator. + ## TODO -1. Rewrite attention part to use vllm.Attention() with backward as the only attention path. -2. Integrate with simple_rl.py to run end-to-end RL with one canonical model definition. -3. Leverage batch-invariant kernels into model definition. +Work on batch invariance: +1. Integrate with simple_rl_multiprocess.py to run end-to-end RL with one canonical model definition(UNIFIED mode). +2. Rewrite attention part to use vllm.Attention() with backward as the only attention path. +3. Leverage batch-invariant kernels into model definition. + +Work on the RL loop: +1. Design trainer API and integrate with [train.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L475) +2. Remove hardcoded configs and dependency on Qwen3 Model. Use torchtitan's config/TrainSpec instead, to work with any model. +3. Need to load the gsm8k dataset using TorchTitan dataset. +4. Need to properly implement weight saving and loading using TorchTitan's checkpoint mechanism, or use TorchStore. Also need to + replace `vllm_to_torchtitan` and `torchtitan_to_vllm` calls to TorchTitan [state dict adaptor](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/qwen3/model/state_dict_adapter.py). +5. Right now we only support trainer run on multiple processes using DDP, and generator using TP, need to onboard more parallelism. +6. Right now we only support VLLM_COMPAT mode to achieve batch invariance and bitwise determinism, need to support UNIFIED mode. +7. In the longer term, need to add trajectory queue to achieve async, right now trainer and generator are running synchronously. diff --git a/torchtitan/experiments/rl/unified/__init__.py b/torchtitan/experiments/rl/unified/__init__.py index 6c34556112..d5cfa6047d 100644 --- a/torchtitan/experiments/rl/unified/__init__.py +++ b/torchtitan/experiments/rl/unified/__init__.py @@ -14,8 +14,8 @@ from torchtitan.protocols.train_spec import get_train_spec, TrainSpec from vllm.logger import init_logger -from .utils import create_parallel_dims_from_vllm_config -from .vllm_wrapper import TorchTitanVLLMModelWrapper +from .models.parallelism_utils import create_parallel_dims_from_vllm_config +from .models.vllm_wrapper import TorchTitanVLLMModelWrapper logger = init_logger(__name__) diff --git a/torchtitan/experiments/rl/unified/actors/generator.py b/torchtitan/experiments/rl/unified/actors/generator.py new file mode 100644 index 0000000000..d0ee5cf38f --- /dev/null +++ b/torchtitan/experiments/rl/unified/actors/generator.py @@ -0,0 +1,448 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import asyncio +import logging +import os + +from dataclasses import dataclass +from typing import List + +import torch +from monarch.actor import Actor, endpoint +from safetensors.torch import save_file + +from torchtitan.experiments.rl.vllm_compat.simple_rl import ( + compute_grpo_advantages, + compute_grpo_advantages_stable, + math_reward_function, + trivial_reward_function, +) +from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm +from transformers import AutoTokenizer +from vllm import LLM, SamplingParams + +logger = logging.getLogger(__name__) + + +@dataclass +class TrajectoryData: + """ + Data from one generation batch. + + Attributes: + policy_version: Version of policy that produced this batch + completions: List of completion strings + vllm_token_ids: List of token ID lists for each completion + vllm_token_log_probs: List of per-token log prob lists + prompt_token_ids: List of prompt token ID lists + rewards: Computed rewards for each completion + advantages: Computed advantages for each completion + """ + + policy_version: int + completions: List[str] + vllm_token_ids: List[List[int]] + vllm_token_log_probs: List[List[float]] + prompt_token_ids: List[List[int]] + rewards: torch.Tensor + advantages: torch.Tensor + + +class VLLMRolloutEngine: + """ + vLLM engine for fast rollouts with weight updates. + + Note: vLLM loads from model_config.model path, so we create a temporary + directory with updated weights and restart the engine. This is faster than + recreating temp dirs repeatedly and handles config/tokenizer files properly. + + Args: + model_path: Path to HuggingFace model (for config/tokenizer) + temp_checkpoint_dir: Directory to save temporary weight checkpoints + """ + + def __init__( + self, + model_path: str, + temp_checkpoint_dir: str = "./converted", + tp_size: int = 1, + ): + self.base_model_path = model_path + self.temp_model_dir = os.path.abspath( + os.path.join(temp_checkpoint_dir, "vllm_temp_model") + ) + os.makedirs(self.temp_model_dir, exist_ok=True) + + import glob + + # Copy config/tokenizer files from base model to temp dir + import shutil + + for file in [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "merges.txt", + "vocab.json", + ]: + src = os.path.join(model_path, file) + if os.path.exists(src): + shutil.copy2(src, self.temp_model_dir) + + # Copy the original model shard files if they exist + # We'll overwrite these with our single model.safetensors later + for shard_file in glob.glob(os.path.join(model_path, "model-*.safetensors")): + dst = os.path.join(self.temp_model_dir, os.path.basename(shard_file)) + shutil.copy2(shard_file, dst) + + # Copy index file if it exists + index_file = os.path.join(model_path, "model.safetensors.index.json") + if os.path.exists(index_file): + shutil.copy2(index_file, self.temp_model_dir) + + self.llm = None + self.tp_size = tp_size + logger.info("vLLM rollout engine initialized (will load on first use)") + + def update_weights(self, vllm_compat_state: dict) -> None: + """ + Update vLLM model weights from vLLM-compat state dict. + + This converts weights to vLLM format, saves them, and reloads using + vLLM's reload_weights() API after updating the model path config. + + Args: + vllm_compat_state: vLLM-compat model state dict (with gate_up_proj/down_proj) + """ + # Convert vLLM-compat -> vLLM (torchtitan_to_vllm handles both formats) + vllm_state = torchtitan_to_vllm(vllm_compat_state) + + # Save to temp model directory + import os + + checkpoint_path = os.path.join(self.temp_model_dir, "model.safetensors") + + # Update the shard files that vLLM will actually load + # We need to split our weights to match the original 2-shard structure + import glob + import json + + shard_files = sorted( + glob.glob(os.path.join(self.temp_model_dir, "model-*.safetensors")) + ) + index_file = os.path.join(self.temp_model_dir, "model.safetensors.index.json") + + # TODO: need to replace this with Torchtitan's checkpoint save and load + # right now we hardcoded to work with 2 safe tensor files which we only + # tested on Qwen3 1.7B model. In the longer term, need to use TorchStore + # to achieve the weight communication. + if len(shard_files) == 2 and os.path.exists(index_file): + # Load the index to see which weights go in which shard + with open(index_file, "r") as f: + index_data = json.load(f) + + weight_map = index_data["weight_map"] + + # Split weights according to the index + shard1_weights = {} + shard2_weights = {} + + for key, value in vllm_state.items(): + shard_file = weight_map.get(key, shard_files[0]) + if "model-00001-of-00002" in shard_file: + shard1_weights[key] = value + else: + shard2_weights[key] = value + + # Ensure weights stay in bfloat16 + shard1_weights = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in shard1_weights.items() + } + shard2_weights = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in shard2_weights.items() + } + + # Save to the shard files + save_file(shard1_weights, shard_files[0]) + save_file(shard2_weights, shard_files[1]) + else: + # Ensure weights stay in bfloat16 + vllm_state = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in vllm_state.items() + } + # Fallback: save as single file + save_file(vllm_state, checkpoint_path) + + # First time: create the engine + if self.llm is None: + # Disable distributed execution to avoid NCCL conflicts in Monarch actors + # Use single GPU mode + import os + + os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + self.llm = LLM( + model=self.temp_model_dir, + trust_remote_code=True, + max_model_len=2048, + dtype="bfloat16", + gpu_memory_utilization=0.1, # Reduced from 0.5 + seed=42, # Fixed seed for determinism + enforce_eager=True, + tensor_parallel_size=self.tp_size, # Explicitly single GPU + ) + logger.info("Created new vLLM engine") + else: + # Use collective_rpc to call reload_weights on all workers + # This reloads weights from temp_model_dir without recreating the engine + self.llm.collective_rpc("reload_weights") + + @torch.no_grad() + def generate( + self, + prompt_texts: list[str], + max_new_tokens: int = 20, + temperature: float = 1.0, + n_samples_per_prompt: int = 4, + ) -> tuple[ + list[str], torch.Tensor, list[list[int]], list[list[float]], list[list[int]] + ]: + """ + Generate samples using vLLM. + + Args: + prompt_texts: List of prompt strings + max_new_tokens: Max tokens to generate + temperature: Sampling temperature + n_samples_per_prompt: Number of samples per prompt + + Returns: + completions: List of completion strings + log_probs: [batch] - Sum of log probs for each completion + token_ids: List of token ID lists for each completion (generated tokens only) + token_log_probs: List of per-token log prob lists for each completion + prompt_token_ids: List of prompt token ID lists for each completion + """ + sampling_params = SamplingParams( + temperature=temperature, + max_tokens=max_new_tokens, + n=n_samples_per_prompt, + seed=42, + logprobs=1, + prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs + ) + + outputs = self.llm.generate(prompt_texts, sampling_params) + + # Extract completions and log probs + completions = [] + log_probs_list = [] + token_ids_list = [] + token_log_probs_list = [] + prompt_token_ids_list = [] + + for output in outputs: + # Extract prompt token IDs from the output + prompt_token_ids = output.prompt_token_ids + + for sample in output.outputs: + completions.append(sample.text) + + # Store prompt tokens for this sample + prompt_token_ids_list.append(prompt_token_ids) + + # Extract token IDs (generated tokens only) + token_ids = sample.token_ids + token_ids_list.append(token_ids) + + # Extract per-token log probs + per_token_log_probs = [ + list(logprob_dict.values())[0].logprob + for logprob_dict in sample.logprobs + ] + token_log_probs_list.append(per_token_log_probs) + + # Sum log probs across generated tokens + total_log_prob = sum(per_token_log_probs) + log_probs_list.append(total_log_prob) + + log_probs = torch.tensor(log_probs_list, dtype=torch.float32) + + return ( + completions, + log_probs, + token_ids_list, + token_log_probs_list, + prompt_token_ids_list, + ) + + def __del__(self): + """Cleanup vLLM engine.""" + if hasattr(self, "llm"): + del self.llm + torch.cuda.empty_cache() + + +class GeneratorState: + """States for the Generator's state machine.""" + + READY_TO_GENERATE = "READY_TO_GENERATE" + READY_TO_UPDATE = "READY_TO_UPDATE" + + +class Generator(Actor): + """ + Generates rollouts using vLLM engine. + + Maintains a vLLM engine that is synchronized with the Trainer + via weight sync. Generates completions for given prompts and + computes rewards/advantages. + + Args: + model_path: Path to HuggingFace model + prompt_texts: List of prompt strings + expected_answers: List of expected answers + group_size: Number of samples per prompt + max_new_tokens: Max tokens to generate + temperature: Sampling temperature + use_real_dataset: Whether using real dataset (GSM8K) + grpo_beta: Beta for GRPO advantages + use_stable_grpo: Whether to use stable GRPO + tp_size: Tensor Parallel size + """ + + def __init__( + self, + model_path: str, + prompt_texts: List[str], + expected_answers: List[str], + group_size: int = 8, + max_new_tokens: int = 20, + temperature: float = 1.0, + use_real_dataset: bool = False, + grpo_beta: float = 0.1, + use_stable_grpo: bool = False, + tp_size: int = 1, + ): + self.model_path = model_path + self.prompt_texts = prompt_texts + self.expected_answers = expected_answers + self.group_size = group_size + self.max_new_tokens = max_new_tokens + self.temperature = temperature + self.use_real_dataset = use_real_dataset + self.grpo_beta = grpo_beta + self.use_stable_grpo = use_stable_grpo + self.tp_size = tp_size + + # Initialize vLLM engine + self.vllm_engine = VLLMRolloutEngine(model_path, tp_size=self.tp_size) + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True + ) + + # State machine + self.state = GeneratorState.READY_TO_UPDATE + self.cond = asyncio.Condition() + self.policy_version = 0 + + # Reward function + self.reward_fn = ( + math_reward_function if use_real_dataset else trivial_reward_function + ) + + logger.info("Generator initialized with vLLM engine") + + @endpoint + async def generate(self) -> None: + """Generate trajectories and compute rewards/advantages.""" + logger.info( + f"{os.getpid()=} Generating start generate (policy v{self.policy_version})..." + ) + async with self.cond: + # Wait until ready to generate (weights have been updated) + await self.cond.wait_for( + lambda: self.state == GeneratorState.READY_TO_GENERATE + ) + + # Generate samples using vLLM + ( + completions, + vllm_log_probs, + vllm_token_ids, + vllm_token_log_probs, + prompt_token_ids, + ) = self.vllm_engine.generate( + self.prompt_texts, + self.max_new_tokens, + self.temperature, + n_samples_per_prompt=self.group_size, + ) + + # Compute rewards + rewards = self.reward_fn( + completions, self.expected_answers, self.group_size + ) + + # Normalize rewards + reward_mean = rewards.mean() + reward_std = rewards.std() + if reward_std > 1e-8: + rewards_normalized = (rewards - reward_mean) / reward_std + else: + rewards_normalized = rewards - reward_mean + + # Compute advantages using GRPO + if self.use_stable_grpo: + advantages = compute_grpo_advantages_stable( + rewards_normalized, self.group_size + ) + else: + advantages = compute_grpo_advantages( + rewards_normalized, self.group_size, beta=self.grpo_beta + ) + + # Create trajectory data + trajectory = TrajectoryData( + policy_version=self.policy_version, + completions=completions, + vllm_token_ids=vllm_token_ids, + vllm_token_log_probs=vllm_token_log_probs, + prompt_token_ids=prompt_token_ids, + rewards=rewards, + advantages=advantages, + ) + + # Signal ready for update + self.state = GeneratorState.READY_TO_UPDATE + self.cond.notify_all() + + logger.info( + f"{os.getpid()=} Generating finish generate (policy v{self.policy_version})..." + ) + return trajectory + + @endpoint + async def update(self, version: int, vllm_compat_state: dict) -> None: + """Update generate weights. + + Args: + version: New policy version number + vllm_compat_state: vLLM-compatible state dict + """ + async with self.cond: + self.vllm_engine.update_weights(vllm_compat_state) + # Update version and state + self.policy_version = version + self.state = GeneratorState.READY_TO_GENERATE + self.cond.notify_all() + logger.info( + f"{os.getpid()=} Generator updating weights to policy v{version}..." + ) diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py new file mode 100644 index 0000000000..9ffb9f0f0a --- /dev/null +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from typing import Any, Optional + +import torch +from monarch.actor import Actor, endpoint +from torchtitan.experiments.rl.unified.actors.generator import TrajectoryData +from torchtitan.experiments.rl.unified.models.parallelism_utils import ( + create_trainer_parallel_dims, +) +from torchtitan.experiments.rl.unified.models.utils import load_model, ModelMode +from torchtitan.experiments.rl.vllm_compat.simple_rl import ( + compute_policy_gradient_loss_vllm, +) +from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( + torchtitan_to_vllm_compat, +) + +logger = logging.getLogger(__name__) + + +class Trainer(Actor): + """ + Updates policy based on collected trajectories. + + Run model forward on trajectories, computes loss, and run backward. + + Args: + titan_checkpoint_path: Path to TorchTitan checkpoint + model_path: Path to HuggingFace model + learning_rate: Learning rate for optimizer + model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, + or plain Torchtitan model + """ + + def __init__( + self, + titan_checkpoint_path: str, + model_path: str, + learning_rate: float = 1e-5, + model_mode: str = ModelMode.VLLM_COMPAT, + ddp_size: int = 1, + tp_size: int = 1, + ): + # Explicitly set cuda device for each trainer, otherwise different processes will use the same CUDA device + local_rank = int(os.environ["LOCAL_RANK"]) + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(local_rank) + + self.model = load_model( + titan_checkpoint_path, model_path, model_mode=model_mode + ) + self.ddp_size = ddp_size + self.tp_size = tp_size + self.parallel_dims = create_trainer_parallel_dims(self.ddp_size, self.tp_size) + + # apply PT-D Parallelism + # TODO: right now it only works for qwen3 model, need to formalize this to use parallize_fn from train_spec + from torchtitan.models.llama3.infra.parallelize import apply_ddp + + apply_ddp( + self.model, + self.parallel_dims.get_mesh("dp_replicate"), + enable_compile=False, + ) + + self.model = self.model.to(device) + self.model.train() + + # Optimizer + self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate) + self.policy_version = 0 + self.generator: Optional[Any] = None + + logger.info("Trainer initialized with TorchTitan model") + + @endpoint + async def get_weights(self) -> dict: + """Get vLLM-compatible weights for generator. + + Returns: + vLLM-compatible state dict + """ + titan_state = self.model.state_dict() + vllm_compat_state = torchtitan_to_vllm_compat(titan_state) + return vllm_compat_state + + @endpoint + async def step(self, trajectory: TrajectoryData) -> dict: + """Perform one training step. + + Returns: + Training metrics + """ + logger.info( + f"{os.getpid()=} Trainer starts to train {self.policy_version} on traj:" + ) + # Compute loss + loss, loss_metrics = compute_policy_gradient_loss_vllm( + self.model, + trajectory.vllm_token_ids, + trajectory.vllm_token_log_probs, + trajectory.prompt_token_ids, + trajectory.advantages, + kl_coef=0.1, + ) + + # Update weights + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + self.optimizer.step() + + self.policy_version += 1 + + # TODO: save dcp checkpoint to file here instead of sending weight dicts + + # Return metrics + metrics = { + "loss": loss.item(), + "reward_mean": trajectory.rewards.mean().item(), + "reward_std": trajectory.rewards.std().item(), + "advantage_mean": trajectory.advantages.mean().item(), + "advantage_std": trajectory.advantages.std().item(), + "sample_completion": trajectory.completions[0][:80], + "policy_version": self.policy_version, + **loss_metrics, + } + logger.info(f"{os.getpid()=} Trainer finish step {self.policy_version}") + return metrics diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py index 19770ecc22..8c0b82edc0 100755 --- a/torchtitan/experiments/rl/unified/infer.py +++ b/torchtitan/experiments/rl/unified/infer.py @@ -8,7 +8,7 @@ import argparse # Import unified module - this automatically registers TorchTitan models with vLLM -from torchtitan.experiments.deterministic_vllm_rl import unified # noqa: F401 +from torchtitan.experiments.rl import unified # noqa: F401 from vllm import LLM, SamplingParams from vllm.logger import init_logger @@ -25,7 +25,7 @@ def parse_args(): parser.add_argument( "--model_ckpt_path", type=str, - default="torchtitan/experiments/deterministic_vllm_rl/example_checkpoint", + default="torchtitan/experiments/rl/example_checkpoint", help="Path to TorchTitan checkpoint directory", ) parser.add_argument( diff --git a/torchtitan/experiments/rl/unified/attention.py b/torchtitan/experiments/rl/unified/models/attention.py similarity index 100% rename from torchtitan/experiments/rl/unified/attention.py rename to torchtitan/experiments/rl/unified/models/attention.py diff --git a/torchtitan/experiments/rl/unified/utils.py b/torchtitan/experiments/rl/unified/models/parallelism_utils.py similarity index 73% rename from torchtitan/experiments/rl/unified/utils.py rename to torchtitan/experiments/rl/unified/models/parallelism_utils.py index e997c387d9..bac266149d 100644 --- a/torchtitan/experiments/rl/unified/utils.py +++ b/torchtitan/experiments/rl/unified/models/parallelism_utils.py @@ -12,6 +12,8 @@ """ import torch.distributed as dist +from torchtitan.config.job_config import Comm +from torchtitan.distributed import utils as dist_utils from torchtitan.distributed.parallel_dims import ParallelDims from vllm.config import VllmConfig @@ -61,3 +63,32 @@ def create_parallel_dims_from_vllm_config(vllm_config: VllmConfig) -> ParallelDi ) return parallel_dims + + +def create_trainer_parallel_dims(ddp_size, tp_size) -> ParallelDims: + """ + Create ParallelDims for trainer with specified DDP and TP sizes. + + This function initializes the distributed process group and creates a ParallelDims + object configured for for trainer SPMD workers. + + Args: + ddp_size: Data parallel (DDP) replicate size + tp_size: Tensor parallel size + + Returns: + ParallelDims object with trainer parallelism settings + """ + world_size = dist_utils.init_distributed( + Comm(), + ) + return ParallelDims( + dp_replicate=ddp_size, + dp_shard=1, + tp=tp_size, + cp=1, + pp=1, + ep=1, + etp=1, + world_size=world_size, + ) diff --git a/torchtitan/experiments/rl/unified/models/utils.py b/torchtitan/experiments/rl/unified/models/utils.py new file mode 100644 index 0000000000..f954cf16e5 --- /dev/null +++ b/torchtitan/experiments/rl/unified/models/utils.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from enum import Enum + +import torch +from safetensors.torch import load_file + +from torchtitan.experiments.rl.unified.models.attention import VLLMAttention + +from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( + torchtitan_to_vllm_compat, +) +from torchtitan.models.qwen3.model.args import Qwen3ModelArgs +from transformers import AutoConfig + +logger = logging.getLogger(__name__) + + +class ModelMode(str, Enum): + """ + Enum defining which TorchTitan model to use. + + Attributes: + UNIFIED: Standard TorchTitan model replaced with vLLM attention for unified + training and inference. + VLLM_COMPAT: vLLM-compatible TorchTitan model using vLLM's batch invariant kernels, + ensuring bitwise determinism between training and inference. + STANDARD: Plain TorchTitan model without any modifications. + """ + + UNIFIED = "unified" + VLLM_COMPAT = "vllm_compat" + STANDARD = "standard" + + +def replace_with_vllm_attention(model): + """ + Replace TorchTitan attention with vLLM paged attention. + + Assumes model has .layers dict with .attention.inner_attention structure. + """ + if not hasattr(model, "layers"): + raise AttributeError( + f"Model {type(model).__name__} must have .layers attribute" + ) + + model_args = model.model_args + for layer_name, layer in model.layers.items(): + if not hasattr(layer, "attention"): + raise ValueError(f"Layer {layer_name} must have .attention attribute") + + vllm_attn = VLLMAttention( + hidden_size=model_args.dim, + num_heads=model_args.n_heads, + num_kv_heads=model_args.n_heads, # Use n_heads (already replicated) + head_dim=model_args.head_dim, + layer_name=layer_name, + scale=model_args.head_dim**-0.5, + ) + + layer.attention.inner_attention = vllm_attn + + logger.info( + f"Successfully replaced TorchTitan attention with VLLMAttention " + f"({len(model.layers)} layers)" + ) + + +def load_model( + checkpoint_path: str, model_path: str, model_mode: str = ModelMode.VLLM_COMPAT +): + """ + Load TorchTitan model from checkpoint. + + Args: + checkpoint_path: Path to TorchTitan checkpoint + model_path: Path to HuggingFace model (for config) + model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, + or plain Torchtitan model + + Returns: + model: Loaded TorchTitan model + """ + # Load HuggingFace config + # TODO: do not depend on transformers.AutoConfig, use qwen_args directly + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + # Create model args + model_args = Qwen3ModelArgs( + dim=hf_config.hidden_size, + n_layers=hf_config.num_hidden_layers, + n_heads=hf_config.num_attention_heads, + n_kv_heads=hf_config.num_key_value_heads, + vocab_size=hf_config.vocab_size, + head_dim=getattr( + hf_config, + "head_dim", + hf_config.hidden_size // hf_config.num_attention_heads, + ), + hidden_dim=hf_config.intermediate_size, + norm_eps=hf_config.rms_norm_eps, + rope_theta=hf_config.rope_theta, + max_seq_len=getattr(hf_config, "max_position_embeddings", 32768), + qk_norm=True, + depth_init=True, + eos_id=getattr(hf_config, "eos_token_id", 151645), + ) + + # state_dict is in standard TorchTitan format (w1, w2, w3) + state_dict = load_file(checkpoint_path) + + if model_mode == ModelMode.UNIFIED: + from torchtitan.models.qwen3 import Qwen3Model + + model = Qwen3Model(model_args) + # Set global default dtype to bfloat16. This is needed because vLLM's Attention + # layer uses torch.get_default_dtype() and it doesn't support float32 + torch.set_default_dtype(torch.bfloat16) + replace_with_vllm_attention(model) + # Load standard TorchTitan format directly + model.load_state_dict(state_dict, strict=True) + elif model_mode == ModelMode.VLLM_COMPAT: + # Create and load model that has bitwise determinism between training and inference + from torchtitan.experiments.rl.vllm_compat.models.qwen3 import ( + Qwen3VLLMCompatModel, + ) + + model = Qwen3VLLMCompatModel(model_args) + # Convert to vLLM-compat format (merged gate_up_proj, down_proj) + vllm_compat_state = torchtitan_to_vllm_compat(state_dict) + model.load_state_dict(vllm_compat_state, strict=False) + else: + # Use standard TorchTitan model + from torchtitan.models.qwen3 import Qwen3Model + + model = Qwen3Model(model_args) + # Load standard TorchTitan format directly + model.load_state_dict(state_dict, strict=False) + + model.to(torch.bfloat16) + + return model diff --git a/torchtitan/experiments/rl/unified/vllm_wrapper.py b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py similarity index 87% rename from torchtitan/experiments/rl/unified/vllm_wrapper.py rename to torchtitan/experiments/rl/unified/models/vllm_wrapper.py index e92903c744..4faa9037dc 100644 --- a/torchtitan/experiments/rl/unified/vllm_wrapper.py +++ b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py @@ -21,7 +21,7 @@ StateDictOptions, ) -from torchtitan.experiments.deterministic_vllm_rl.unified.attention import VLLMAttention +from torchtitan.experiments.rl.unified.models.utils import replace_with_vllm_attention from torchtitan.models.qwen3.model.model import precompute_rope_cache from torchtitan.protocols.model import BaseModelArgs, ModelProtocol from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter @@ -30,7 +30,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from .utils import create_parallel_dims_from_vllm_config +from .parallelism_utils import create_parallel_dims_from_vllm_config logger = init_logger(__name__) @@ -83,7 +83,7 @@ def __init__( base=self.config.rope_theta, ) # Replace attention with vLLM paged attention - self._replace_with_vllm_attention(model_args) + replace_with_vllm_attention(self.model) # Create ParallelDims from vLLM config and apply parallelization # NOTE: We need to apply parallelize within model.__init__ because w @@ -104,39 +104,6 @@ def __init__( else: logger.info("Single GPU mode - no parallelization needed") - def _replace_with_vllm_attention(self, model_args): - """ - Replace TorchTitan attention with vLLM paged attention. - - Assumes model has .layers dict with .attention.inner_attention structure. - Override in subclass if different structure. - """ - assert hasattr( - self.model, "layers" - ), f"Model {type(self.model).__name__} must have .layers attribute" - - for layer_name, layer in self.model.layers.items(): - assert hasattr( - layer, "attention" - ), f"Layer {layer_name} must have .attention attribute" - - vllm_attn = VLLMAttention( - hidden_size=model_args.dim, - num_heads=model_args.n_heads, - num_kv_heads=model_args.n_heads, # Use n_heads (already replicated) - head_dim=model_args.head_dim, - layer_name=layer_name, - scale=model_args.head_dim**-0.5, - ) - - # Replace inner attention - layer.attention.inner_attention = vllm_attn - - logger.info( - f"Successfully replaced TorchTitan attention with VLLMAttention " - f"({len(self.model.layers)} layers)" - ) - def _extend_rope_cache_if_needed( self, rope_cache: torch.Tensor, max_position: int ) -> torch.Tensor: diff --git a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py new file mode 100644 index 0000000000..087e4f1e70 --- /dev/null +++ b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Multiprocess RL training loop using Monarch Actors. + +This demonstrates: +1. Distributed actor architecture with Generator (vLLM) and Trainer (TorchTitan) components +2. File based weight synchronization between trainer and generator + +The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training. + +Command to run: +VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +""" +import asyncio +import logging + +import torch +from monarch.actor import this_host +from monarch.utils import setup_env_for_distributed +from torchtitan.experiments.rl.unified.actors.generator import Generator +from torchtitan.experiments.rl.unified.actors.trainer import Trainer +from torchtitan.experiments.rl.unified.models.utils import ModelMode +from torchtitan.experiments.rl.vllm_compat.simple_rl import ( + download_and_convert_model, + load_gsm8k_dataset, +) +from vllm.model_executor.layers.batch_invariant import ( + init_batch_invariance, + vllm_is_batch_invariant, +) + +logger = logging.getLogger(__name__) + + +async def main(): + """Run the distributed RL training loop using Monarch.""" + # Model Config + model_name = "Qwen/Qwen3-1.7B" + cache_dir = "./models" + output_dir = "./converted" + + # Training config + group_size = 8 + num_steps = 10 + learning_rate = 1e-5 + max_new_tokens = 20 + + # GRPO config + use_stable_grpo = False + grpo_beta = 0.1 + + # Dataset config + use_real_dataset = False + num_dataset_samples = 5 + + # Parallelism sizes + trainer_ddp_size = 2 + trainer_tp_size = 1 + generator_tp_size = 1 + + init_batch_invariance() + batch_invariant = vllm_is_batch_invariant() + mode = ModelMode.VLLM_COMPAT + + # Set up batch invariant + if batch_invariant: + logger.info("Batch invariance detected - using vLLM-compatible model") + from torchtitan.experiments.rl.vllm_compat.batch_invariant_backward import ( + enable_batch_invariant_backward_mode, + ) + + enable_batch_invariant_backward_mode() + else: + raise RuntimeError("Batch invariance NOT detected - using standard model") + + # Download and convert model + titan_checkpoint_path, model_path = download_and_convert_model( + model_name, cache_dir, output_dir + ) + + # Load dataset + if use_real_dataset: + logger.info(f"Loading GSM8K dataset ({num_dataset_samples} samples)...") + # TODO: Refactor into loading torchtitan dataset + prompt_texts, expected_answers = load_gsm8k_dataset( + split="train", num_samples=num_dataset_samples + ) + if prompt_texts is None or len(prompt_texts) == 0: + use_real_dataset = False + + if not use_real_dataset: + logger.info("Using default prompts") + prompts_with_answers = [ + ("The capital of France is", "paris"), + ("What is 7 times 8?", "56"), + ("The first president of the United States was", "washington"), + ("The chemical symbol for water is", "h2o"), + ("The largest planet in our solar system is", "jupiter"), + ] + prompt_texts = [p[0] for p in prompts_with_answers] + expected_answers = [p[1] for p in prompts_with_answers] + + logger.info(f"Loaded {len(prompt_texts)} prompts") + + # Create process meshes + trainer_mesh = this_host().spawn_procs(per_host={"gpus": 2}) + gen_mesh = this_host().spawn_procs(per_host={"gpus": 1}) + + # Spawn actors on trainer and generator mesh + trainer = trainer_mesh.spawn( + "trainer", + Trainer, + titan_checkpoint_path, + model_path, + learning_rate, + mode, + trainer_ddp_size, + trainer_tp_size, + ) + + # Set up distributed env vars so that titan actors are connected via c10d + await setup_env_for_distributed( + trainer_mesh, + master_addr="localhost", # TODO: figure out what to set + master_port=29500, # TODO: figure out what to set + ) + + generator = gen_mesh.spawn( + "generator", + Generator, + model_path, + prompt_texts, + expected_answers, + group_size, + max_new_tokens, + 1.0, # temperature + use_real_dataset, + grpo_beta, + use_stable_grpo, + generator_tp_size, + ) + + # Initialize generator with trainer weights + initial_weights = trainer.get_weights.call().get().item(gpus=0) + await generator.update.call_one(0, initial_weights) + + # Training loop + logger.info("\n" + "=" * 80) + logger.info(f"Starting RL training for {num_steps} steps") + logger.info("=" * 80) + + for step in range(num_steps): + # Fully sync RL loop + batch = await generator.generate.call_one() + metrics = await trainer.step.call(batch) + metrics = metrics.item(gpus=0) + weights = (await trainer.get_weights.call()).item(gpus=0) + await generator.update.call_one(metrics["policy_version"], weights) + + logger.info( + f"\nStep {step:3d} | Loss: {metrics['loss']:.4f} | " + f"Reward: {metrics['reward_mean']:+.3f}" + ) + logger.info(f" Sample: {metrics['sample_completion']}...") + + # Check for divergence + if not torch.isfinite(torch.tensor(metrics["loss"])): + logger.info("\n" + "!" * 80) + logger.info("ERROR: Loss is NaN/Inf! Training diverged.") + logger.info("!" * 80) + break + + logger.info("\n" + "=" * 80) + logger.info("RL Training complete") + logger.info("=" * 80) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/torchtitan/experiments/rl/vllm_compat/simple_rl.py b/torchtitan/experiments/rl/vllm_compat/simple_rl.py index 508868c0d4..5e1fdd486b 100644 --- a/torchtitan/experiments/rl/vllm_compat/simple_rl.py +++ b/torchtitan/experiments/rl/vllm_compat/simple_rl.py @@ -481,7 +481,6 @@ def load_gsm8k_dataset(split: str = "train", num_samples: int = 100): def trivial_reward_function( completions: list[str], - tokenizer=None, expected_answers: list[str] | None = None, group_size: int = 4, ) -> torch.Tensor: @@ -494,7 +493,6 @@ def trivial_reward_function( Args: completions: List of completion strings - tokenizer: Tokenizer to count tokens expected_answers: List of expected answers (one per prompt, repeated for group_size) group_size: Number of samples per prompt @@ -891,12 +889,7 @@ def rl_update_step( ) # Compute rewards using provided reward function - if reward_fn == trivial_reward_function: - rewards = reward_fn(completions, tokenizer, expected_answers, group_size) - elif reward_fn == math_reward_function: - rewards = reward_fn(completions, expected_answers, group_size) - else: - rewards = reward_fn(completions, expected_answers, group_size) + rewards = reward_fn(completions, expected_answers, group_size) # Normalize rewards for stability (mean=0, std=1) reward_mean = rewards.mean() From 29aafb91b7fbffe2ee259919a3249a0eb1d70779 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 22 Dec 2025 21:49:20 -0800 Subject: [PATCH 082/127] Fix qwen3 attention scaling calculation (#2173) as titiled, missing scale as part of attention input. --- torchtitan/models/qwen3/model/model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 0683b4c42d..d7c88d4525 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -271,8 +271,11 @@ def forward( match self.attn_type: case "flex": assert isinstance(attention_masks, BlockMask), attention_masks - output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + output = self.inner_attention( + xq, xk, xv, block_mask=attention_masks, scale=self.scaling + ) case "varlen": + # TODO: pass self.scaling into varlen attention assert isinstance(attention_masks, VarlenMetadata), attention_masks output = self.inner_attention( xq, @@ -283,7 +286,7 @@ def forward( ) case "sdpa": assert attention_masks is None - output = self.inner_attention(xq, xk, xv) + output = self.inner_attention(xq, xk, xv, scale=self.scaling) case _: raise ValueError(f"Unknown attention type: {self.attn_type}") From a452121c0bee5055a9077a2e535bcc154b1ca3c2 Mon Sep 17 00:00:00 2001 From: akashveramd Date: Tue, 23 Dec 2025 18:46:46 -0800 Subject: [PATCH 083/127] Add rocm support for models, flux & torchft integration tests. (#2172) In this PR, adding rocm support for models, flux & torchft integration tests. Also, enabled model_only_hf_checkpoint features test for ROCm. --- .../integration_test_8gpu_models.yaml | 42 +++++++++++++------ .../integration_test_8gpu_torchft.yaml | 38 ++++++++++++----- .github/workflows/set-matrix.yaml | 4 +- tests/integration_tests/features.py | 1 - 4 files changed, 58 insertions(+), 27 deletions(-) diff --git a/.github/workflows/integration_test_8gpu_models.yaml b/.github/workflows/integration_test_8gpu_models.yaml index b673da5adf..acdbe9cb06 100644 --- a/.github/workflows/integration_test_8gpu_models.yaml +++ b/.github/workflows/integration_test_8gpu_models.yaml @@ -3,6 +3,8 @@ name: 8 GPU Model Tests on: push: branches: [ main ] + tags: + - ciflow/8gpu/* paths-ignore: - 'torchtitan/experiments/**' pull_request: @@ -21,18 +23,30 @@ defaults: run: shell: bash -l -eo pipefail {0} +permissions: + id-token: write + contents: read + jobs: + # Step 1: Dynamically compute the matrix based on conditions + set-matrix: + uses: ./.github/workflows/set-matrix.yaml + + # Step 2: Use the dynamic matrix in the build-test job build-test: + needs: set-matrix uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + strategy: + fail-fast: false + matrix: ${{ fromJSON(needs.set-matrix.outputs.matrix) }} with: - runner: linux.g5.48xlarge.nvidia.gpu - gpu-arch-type: cuda - gpu-arch-version: "12.6" - # This image is faster to clone than the default, but it lacks CC needed by triton - # (1m25s vs 2m37s). - docker-image: torchtitan-ubuntu-20.04-clang12 + runner: ${{ matrix.runner }} + gpu-arch-type: ${{ matrix.gpu-arch-type }} + gpu-arch-version: ${{ matrix.gpu-arch-version }} + docker-image: ${{ matrix.docker-image }} repository: pytorch/torchtitan upload-artifact: outputs + timeout: 45 script: | set -eux @@ -46,12 +60,14 @@ jobs: pip config --user set global.progress_bar off - python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + python -m pip install --force-reinstall --pre torch --index-url ${{ matrix.index-url }} + + USE_CPP=0 python -m pip install --pre torchao --index-url ${{ matrix.index-url }} - USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 + sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded" + sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded" - mkdir artifacts-to-be-uploaded - python -m tests.integration_tests.run_tests --test_suite models artifacts-to-be-uploaded --ngpu 8 - python -m tests.integration_tests.flux artifacts-to-be-uploaded/flux --ngpu 8 - rm -rf artifacts-to-be-uploaded/*/checkpoint - rm -rf artifacts-to-be-uploaded/flux/*/inference_results/ + python -m tests.integration_tests.run_tests --gpu_arch_type ${{ matrix.gpu-arch-type }} --test_suite models $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 + python -m tests.integration_tests.flux $RUNNER_TEMP/artifacts-to-be-uploaded/flux --ngpu 8 + rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint + rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/flux/*/inference_results/ diff --git a/.github/workflows/integration_test_8gpu_torchft.yaml b/.github/workflows/integration_test_8gpu_torchft.yaml index 23f59d8bba..0931fa75b7 100644 --- a/.github/workflows/integration_test_8gpu_torchft.yaml +++ b/.github/workflows/integration_test_8gpu_torchft.yaml @@ -3,6 +3,8 @@ name: TorchFT 8 GPU Integration Test on: push: branches: [ main ] + tags: + - ciflow/8gpu/* paths: - 'torchtitan/components/ft.py' - '.github/workflows/integration_test_8gpu_torchft.yaml' @@ -21,18 +23,30 @@ defaults: run: shell: bash -l -eo pipefail {0} +permissions: + id-token: write + contents: read + jobs: + # Step 1: Dynamically compute the matrix based on conditions + set-matrix: + uses: ./.github/workflows/set-matrix.yaml + + # Step 2: Use the dynamic matrix in the build-test job build-test: + needs: set-matrix uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + strategy: + fail-fast: false + matrix: ${{ fromJSON(needs.set-matrix.outputs.matrix) }} with: - runner: linux.g5.48xlarge.nvidia.gpu - gpu-arch-type: cuda - gpu-arch-version: "12.6" - # This image is faster to clone than the default, but it lacks CC needed by triton - # (1m25s vs 2m37s). - docker-image: torchtitan-ubuntu-20.04-clang12 + runner: ${{ matrix.runner }} + gpu-arch-type: ${{ matrix.gpu-arch-type }} + gpu-arch-version: ${{ matrix.gpu-arch-version }} + docker-image: ${{ matrix.docker-image }} repository: pytorch/torchtitan upload-artifact: outputs + timeout: 45 script: | set -eux @@ -47,14 +61,16 @@ jobs: pip config --user set global.progress_bar off python -m pip install torchft-nightly - python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 - USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 + python -m pip install --force-reinstall --pre torch --index-url ${{ matrix.index-url }} + USE_CPP=0 python -m pip install --pre torchao --index-url ${{ matrix.index-url }} + + sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded" + sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded" - mkdir artifacts-to-be-uploaded echo "torchft_lighthouse" RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 > /dev/null 2>&1 & echo "ft_integration_test" # Getting error - Cuda failure 217 'peer access is not supported between these two devices' - python -m tests.integration_tests.ft artifacts-to-be-uploaded --ngpu 8 + python -m tests.integration_tests.ft $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 # pkill -9 torchft_lighthouse - rm -rf artifacts-to-be-uploaded/*/checkpoint + rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint diff --git a/.github/workflows/set-matrix.yaml b/.github/workflows/set-matrix.yaml index 5564d8d70b..3d6704641c 100644 --- a/.github/workflows/set-matrix.yaml +++ b/.github/workflows/set-matrix.yaml @@ -27,9 +27,9 @@ jobs: "name": "rocm", "runner": "linux.rocm.gpu.gfx942.8", "gpu-arch-type": "rocm", - "gpu-arch-version": "7.0", + "gpu-arch-version": "7.1", "docker-image": "torchtitan-rocm-ubuntu-22.04-clang12", - "index-url": "https://download.pytorch.org/whl/nightly/rocm7.0" + "index-url": "https://download.pytorch.org/whl/nightly/rocm7.1" }' # Define CUDA matrix diff --git a/tests/integration_tests/features.py b/tests/integration_tests/features.py index 8e16ecb4fb..3662aa6bf6 100755 --- a/tests/integration_tests/features.py +++ b/tests/integration_tests/features.py @@ -121,7 +121,6 @@ def build_features_test_list() -> list[OverrideDefinitions]: ], "Checkpoint Integration Test - save load model only checkpoint in HF definition and format", "model_only_hf_checkpoint", - skip_rocm_test=True, ), OverrideDefinitions( [ From 30ab580cd6922173becb5b53f7dbe30afe636162 Mon Sep 17 00:00:00 2001 From: acisseJZhong <40467976+acisseJZhong@users.noreply.github.com> Date: Tue, 23 Dec 2025 21:24:41 -0800 Subject: [PATCH 084/127] [RL] Support Trainer and Generator Unified Model (#2174) Support using `Qwen3TorchTitanForCausalLM` for both trainer and generator. For train, we replace attention with `VLLMCompatibleFlashAttention` since `VLLMAttention` doesn't have backward yet. For inference, we replace attention with `VLLMAttention`. When TP=1, train and inference has bitwise determinism. After we verify the speed for using unified model, we can delete the `VLLM_COMPAT` code path. Command ``` VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py ``` --- .../rl/unified/actors/generator.py | 112 ++++++++++-------- .../experiments/rl/unified/models/utils.py | 35 +++++- .../rl/unified/models/vllm_wrapper.py | 2 +- .../rl/unified/simple_rl_multiprocess.py | 36 +++--- .../rl/vllm_compat/models/attention.py | 6 + 5 files changed, 123 insertions(+), 68 deletions(-) diff --git a/torchtitan/experiments/rl/unified/actors/generator.py b/torchtitan/experiments/rl/unified/actors/generator.py index d0ee5cf38f..45d89095b7 100644 --- a/torchtitan/experiments/rl/unified/actors/generator.py +++ b/torchtitan/experiments/rl/unified/actors/generator.py @@ -14,6 +14,11 @@ import torch from monarch.actor import Actor, endpoint from safetensors.torch import save_file +from torchtitan.config.job_config import Comm +from torchtitan.distributed import utils as dist_utils + +# Import unified module - this automatically registers TorchTitan models with vLLM +from torchtitan.experiments.rl import unified # noqa: F401 from torchtitan.experiments.rl.vllm_compat.simple_rl import ( compute_grpo_advantages, @@ -22,7 +27,6 @@ trivial_reward_function, ) from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm -from transformers import AutoTokenizer from vllm import LLM, SamplingParams logger = logging.getLogger(__name__) @@ -139,61 +143,70 @@ def update_weights(self, vllm_compat_state: dict) -> None: # TODO: need to replace this with Torchtitan's checkpoint save and load # right now we hardcoded to work with 2 safe tensor files which we only - # tested on Qwen3 1.7B model. In the longer term, need to use TorchStore + # tested on Qwen3 0.6B model. In the longer term, need to use TorchStore # to achieve the weight communication. - if len(shard_files) == 2 and os.path.exists(index_file): - # Load the index to see which weights go in which shard - with open(index_file, "r") as f: - index_data = json.load(f) - - weight_map = index_data["weight_map"] - - # Split weights according to the index - shard1_weights = {} - shard2_weights = {} - - for key, value in vllm_state.items(): - shard_file = weight_map.get(key, shard_files[0]) - if "model-00001-of-00002" in shard_file: - shard1_weights[key] = value - else: - shard2_weights[key] = value - - # Ensure weights stay in bfloat16 - shard1_weights = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in shard1_weights.items() - } - shard2_weights = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in shard2_weights.items() - } - - # Save to the shard files - save_file(shard1_weights, shard_files[0]) - save_file(shard2_weights, shard_files[1]) - else: - # Ensure weights stay in bfloat16 - vllm_state = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in vllm_state.items() - } - # Fallback: save as single file - save_file(vllm_state, checkpoint_path) + # only generator rank 0 saves the weight + if torch.distributed.get_rank() == 0: + logger.info(f"Saving weights to {checkpoint_path}") + if len(shard_files) == 2 and os.path.exists(index_file): + # Load the index to see which weights go in which shard + with open(index_file, "r") as f: + index_data = json.load(f) + + weight_map = index_data["weight_map"] + + # Split weights according to the index + shard1_weights = {} + shard2_weights = {} + + for key, value in vllm_state.items(): + shard_file = weight_map.get(key, shard_files[0]) + if "model-00001-of-00002" in shard_file: + shard1_weights[key] = value + else: + shard2_weights[key] = value + + # Ensure weights stay in bfloat16 + shard1_weights = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in shard1_weights.items() + } + shard2_weights = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in shard2_weights.items() + } + + # Save to the shard files + save_file(shard1_weights, shard_files[0]) + save_file(shard2_weights, shard_files[1]) + else: + # Ensure weights stay in bfloat16 + vllm_state = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in vllm_state.items() + } + # Fallback: save as single file + save_file(vllm_state, checkpoint_path) + + # Synchronize all ranks before reloading to ensure rank 0 finished writing + torch.distributed.barrier() + logger.info( + f"[Rank {torch.distributed.get_rank()}] Synchronized after weight save" + ) # First time: create the engine if self.llm is None: - # Disable distributed execution to avoid NCCL conflicts in Monarch actors - # Use single GPU mode - import os - - os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") self.llm = LLM( model=self.temp_model_dir, + hf_overrides={ + # Override architectures to use our registered TorchTitan model class + "architectures": ["Qwen3TorchTitanForCausalLM"], + }, trust_remote_code=True, max_model_len=2048, dtype="bfloat16", gpu_memory_utilization=0.1, # Reduced from 0.5 + distributed_executor_backend="external_launcher", # vllm do not spawn processes seed=42, # Fixed seed for determinism enforce_eager=True, tensor_parallel_size=self.tp_size, # Explicitly single GPU @@ -342,11 +355,12 @@ def __init__( self.use_stable_grpo = use_stable_grpo self.tp_size = tp_size + # Initialize distributed environment for SPMD generator + world_size = dist_utils.init_distributed( + Comm(), + ) # Initialize vLLM engine self.vllm_engine = VLLMRolloutEngine(model_path, tp_size=self.tp_size) - self.tokenizer = AutoTokenizer.from_pretrained( - model_path, trust_remote_code=True - ) # State machine self.state = GeneratorState.READY_TO_UPDATE diff --git a/torchtitan/experiments/rl/unified/models/utils.py b/torchtitan/experiments/rl/unified/models/utils.py index f954cf16e5..36de483283 100644 --- a/torchtitan/experiments/rl/unified/models/utils.py +++ b/torchtitan/experiments/rl/unified/models/utils.py @@ -9,9 +9,12 @@ import torch from safetensors.torch import load_file - from torchtitan.experiments.rl.unified.models.attention import VLLMAttention +from torchtitan.experiments.rl.vllm_compat.models.attention import ( + VLLMCompatibleFlashAttention, +) + from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( torchtitan_to_vllm_compat, ) @@ -40,7 +43,7 @@ class ModelMode(str, Enum): def replace_with_vllm_attention(model): """ - Replace TorchTitan attention with vLLM paged attention. + Replace TorchTitan attention with vLLM's Attention. Assumes model has .layers dict with .attention.inner_attention structure. """ @@ -71,6 +74,32 @@ def replace_with_vllm_attention(model): ) +def replace_with_vllm_compatible_flash_attention(model): + """ + Replace TorchTitan attention with vLLM compatible flash attention. + + Assumes model has .layers dict with .attention.inner_attention structure. + """ + if not hasattr(model, "layers"): + raise AttributeError( + f"Model {type(model).__name__} must have .layers attribute" + ) + + model_args = model.model_args + for layer_name, layer in model.layers.items(): + if not hasattr(layer, "attention"): + raise ValueError(f"Layer {layer_name} must have .attention attribute") + + vllm_attn = VLLMCompatibleFlashAttention() + + layer.attention.inner_attention = vllm_attn + + logger.info( + f"Successfully replaced TorchTitan attention with VLLMAttention " + f"({len(model.layers)} layers)" + ) + + def load_model( checkpoint_path: str, model_path: str, model_mode: str = ModelMode.VLLM_COMPAT ): @@ -121,7 +150,7 @@ def load_model( # Set global default dtype to bfloat16. This is needed because vLLM's Attention # layer uses torch.get_default_dtype() and it doesn't support float32 torch.set_default_dtype(torch.bfloat16) - replace_with_vllm_attention(model) + replace_with_vllm_compatible_flash_attention(model) # Load standard TorchTitan format directly model.load_state_dict(state_dict, strict=True) elif model_mode == ModelMode.VLLM_COMPAT: diff --git a/torchtitan/experiments/rl/unified/models/vllm_wrapper.py b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py index 4faa9037dc..8dbc5ce393 100644 --- a/torchtitan/experiments/rl/unified/models/vllm_wrapper.py +++ b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py @@ -82,7 +82,7 @@ def __init__( dim=self.config.head_dim, base=self.config.rope_theta, ) - # Replace attention with vLLM paged attention + # Replace attention with vLLM's attention replace_with_vllm_attention(self.model) # Create ParallelDims from vLLM config and apply parallelization diff --git a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py index 087e4f1e70..3e914f3778 100644 --- a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +++ b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py @@ -40,7 +40,7 @@ async def main(): """Run the distributed RL training loop using Monarch.""" # Model Config - model_name = "Qwen/Qwen3-1.7B" + model_name = "Qwen/Qwen3-0.6B" cache_dir = "./models" output_dir = "./converted" @@ -65,7 +65,7 @@ async def main(): init_batch_invariance() batch_invariant = vllm_is_batch_invariant() - mode = ModelMode.VLLM_COMPAT + mode = ModelMode.UNIFIED # Set up batch invariant if batch_invariant: @@ -111,6 +111,20 @@ async def main(): trainer_mesh = this_host().spawn_procs(per_host={"gpus": 2}) gen_mesh = this_host().spawn_procs(per_host={"gpus": 1}) + # Set up distributed env vars so that actors are connected via c10d + await setup_env_for_distributed( + trainer_mesh, + master_addr="localhost", # TODO: figure out what to set + master_port=29500, # TODO: figure out what to set + ) + + # Set up distributed env vars so that actors are connected via c10d + await setup_env_for_distributed( + gen_mesh, + master_addr="localhost", # TODO: figure out what to set + master_port=29501, # TODO: figure out what to set + ) + # Spawn actors on trainer and generator mesh trainer = trainer_mesh.spawn( "trainer", @@ -123,13 +137,6 @@ async def main(): trainer_tp_size, ) - # Set up distributed env vars so that titan actors are connected via c10d - await setup_env_for_distributed( - trainer_mesh, - master_addr="localhost", # TODO: figure out what to set - master_port=29500, # TODO: figure out what to set - ) - generator = gen_mesh.spawn( "generator", Generator, @@ -147,7 +154,7 @@ async def main(): # Initialize generator with trainer weights initial_weights = trainer.get_weights.call().get().item(gpus=0) - await generator.update.call_one(0, initial_weights) + await generator.update.call(0, initial_weights) # Training loop logger.info("\n" + "=" * 80) @@ -156,11 +163,10 @@ async def main(): for step in range(num_steps): # Fully sync RL loop - batch = await generator.generate.call_one() - metrics = await trainer.step.call(batch) - metrics = metrics.item(gpus=0) - weights = (await trainer.get_weights.call()).item(gpus=0) - await generator.update.call_one(metrics["policy_version"], weights) + batch = generator.generate.call().get().item(gpus=0) + metrics = trainer.step.call(batch).get().item(gpus=0) + weights = trainer.get_weights.call().get().item(gpus=0) + await generator.update.call(metrics["policy_version"], weights) logger.info( f"\nStep {step:3d} | Loss: {metrics['loss']:.4f} | " diff --git a/torchtitan/experiments/rl/vllm_compat/models/attention.py b/torchtitan/experiments/rl/vllm_compat/models/attention.py index 11e6d3af67..3bcbe3071a 100644 --- a/torchtitan/experiments/rl/vllm_compat/models/attention.py +++ b/torchtitan/experiments/rl/vllm_compat/models/attention.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. +import math + import torch from vllm.attention.utils.fa_utils import flash_attn_varlen_func @@ -53,6 +55,10 @@ def forward( 0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device=q.device ) + # Scaling factor applied prior to softmax. If none, the default value is set to :math:`\frac{1}{\sqrt{E}}`. + if scale is None: + scale = 1.0 / math.sqrt(q.size(-1)) + # Wrap Flash Attention with manual backward pass class FlashAttnWithBackward(torch.autograd.Function): @staticmethod From a95d20385e254c4905bed8c40bf11d1f0a69f145 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 25 Dec 2025 20:48:34 -0800 Subject: [PATCH 085/127] Support TP when using vLLM engine to run inference w/ torchtitan model definition (#2165) As tiled. To support TP, now we keep a seperate TP plan for Qwen3 model. There are 2 main difference with torchtitan core Qwen3 TP plan: 1. Use all DTensor in TP region 2. Add PrepareModuleInputOuput annotation for innner_attention (vllm.Attention()) TODO: Add numerics check --- torchtitan/experiments/rl/unified/README.md | 4 +- torchtitan/experiments/rl/unified/__init__.py | 8 +- torchtitan/experiments/rl/unified/infer.py | 9 +- .../{models => infra}/parallelism_utils.py | 47 +++++- .../rl/unified/infra/parallelize.py | 155 ++++++++++++++++++ .../experiments/rl/unified/models/utils.py | 19 ++- .../rl/unified/models/vllm_wrapper.py | 81 ++++++--- 7 files changed, 280 insertions(+), 43 deletions(-) rename torchtitan/experiments/rl/unified/{models => infra}/parallelism_utils.py (64%) create mode 100644 torchtitan/experiments/rl/unified/infra/parallelize.py diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index fa54a936da..27550e977c 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -55,12 +55,12 @@ python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torch 4. Run inference: ``` -python torchtitan/experiments/rl/unified/infer.py --model torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B +python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path ``` Run with TP: (work in progress) ``` -python torchtitan/experiments/rl/unified/infer.py --model torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B --tensor-parallel-size 2 +python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path --tensor-parallel-size 2 ``` diff --git a/torchtitan/experiments/rl/unified/__init__.py b/torchtitan/experiments/rl/unified/__init__.py index d5cfa6047d..430df3f268 100644 --- a/torchtitan/experiments/rl/unified/__init__.py +++ b/torchtitan/experiments/rl/unified/__init__.py @@ -11,12 +11,13 @@ Uses the canonical TorchTitan model definition directly with vLLM inference engine. """ +from torchtitan.experiments.rl.unified.infra.parallelize import parallelize_qwen3 from torchtitan.protocols.train_spec import get_train_spec, TrainSpec from vllm.logger import init_logger -from .models.parallelism_utils import create_parallel_dims_from_vllm_config -from .models.vllm_wrapper import TorchTitanVLLMModelWrapper +from .infra.parallelism_utils import create_parallel_dims_from_vllm_config +from .models.vllm_wrapper import TorchTitanVLLMModelWrapper logger = init_logger(__name__) @@ -57,7 +58,8 @@ def __init__(self, *, vllm_config, prefix=""): model_cls=train_spec.model_cls, model_args=model_args, state_dict_adapter=train_spec.state_dict_adapter, - parallelize_fn=train_spec.parallelize_fn, + # NOTE: This should be replaced with qwen3 parallelization plan in torchtitan core + parallelize_fn=parallelize_qwen3, vllm_config=vllm_config, prefix=prefix, ) diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py index 8c0b82edc0..43153fb70e 100755 --- a/torchtitan/experiments/rl/unified/infer.py +++ b/torchtitan/experiments/rl/unified/infer.py @@ -23,7 +23,7 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( - "--model_ckpt_path", + "--model-ckpt-path", type=str, default="torchtitan/experiments/rl/example_checkpoint", help="Path to TorchTitan checkpoint directory", @@ -55,7 +55,7 @@ def parse_args(): return parser.parse_args() -def main(): +def infer(): args = parse_args() logger.info("Initializing vLLM with TorchTitan model") @@ -69,6 +69,8 @@ def main(): # 3. Create JobConfig and ParallelDims from vLLM config # 4. Apply parallelization using parallelize_qwen3 # 5. Load model weights and prepare for inference + # The tensor_parallel_size will be used by vLLM to configure parallelization + # and will be available in vllm_config in worker processes logger.info("Creating vLLM LLM engine...") llm = LLM( @@ -81,6 +83,7 @@ def main(): trust_remote_code=True, enforce_eager=True, # Use eager mode tensor_parallel_size=args.tensor_parallel_size, + gpu_memory_utilization=0.5, ) logger.info("vLLM engine initialized successfully") @@ -112,4 +115,4 @@ def main(): if __name__ == "__main__": - main() + infer() diff --git a/torchtitan/experiments/rl/unified/models/parallelism_utils.py b/torchtitan/experiments/rl/unified/infra/parallelism_utils.py similarity index 64% rename from torchtitan/experiments/rl/unified/models/parallelism_utils.py rename to torchtitan/experiments/rl/unified/infra/parallelism_utils.py index bac266149d..cc57b5a85f 100644 --- a/torchtitan/experiments/rl/unified/models/parallelism_utils.py +++ b/torchtitan/experiments/rl/unified/infra/parallelism_utils.py @@ -11,8 +11,9 @@ tensor parallelism to TorchTitan models in vLLM using TorchTitan's ParallelDims. """ + import torch.distributed as dist -from torchtitan.config.job_config import Comm +from torchtitan.config.job_config import Comm, JobConfig, Model, Parallelism, Training from torchtitan.distributed import utils as dist_utils from torchtitan.distributed.parallel_dims import ParallelDims @@ -92,3 +93,47 @@ def create_trainer_parallel_dims(ddp_size, tp_size) -> ParallelDims: etp=1, world_size=world_size, ) + + +def create_job_config_from_vllm_config( + vllm_config: VllmConfig, + model_name: str = "qwen3", + hf_assets_path: str = "/path/to/hf/assets", +) -> JobConfig: + """ + Create TorchTitan JobConfig from vLLM configuration. + + Args: + vllm_config: vLLM configuration object containing model, parallel, and cache configs + model_name: Model name to use (default: "qwen3") + hf_assets_path: Path to HuggingFace assets directory (default: "/path/to/hf/assets") + + Returns: + JobConfig object with settings mapped from vLLM config + """ + # Create JobConfig with defaults + job_config = JobConfig() + + model_config = vllm_config.model_config + job_config.model = Model( + name=model_name, + hf_assets_path=hf_assets_path, + ) + + parallel_config = vllm_config.parallel_config + job_config.parallelism = Parallelism( + data_parallel_replicate_degree=parallel_config.data_parallel_size, + data_parallel_shard_degree=1, # vLLM doesn't use FSDP sharding in inference + context_parallel_degree=parallel_config.decode_context_parallel_size, + tensor_parallel_degree=parallel_config.tensor_parallel_size, + pipeline_parallel_degree=parallel_config.pipeline_parallel_size, + expert_parallel_degree=1, # Not used in vLLM inference yet + expert_tensor_parallel_degree=1, # Not used in vLLM inference yet + ) + + job_config.training = Training( + local_batch_size=1, # Inference typically processes one batch at a time + steps=1, # Single step for inference + ) + + return job_config diff --git a/torchtitan/experiments/rl/unified/infra/parallelize.py b/torchtitan/experiments/rl/unified/infra/parallelize.py new file mode 100644 index 0000000000..8cbeeed783 --- /dev/null +++ b/torchtitan/experiments/rl/unified/infra/parallelize.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D parallelisms (except pipeline parallelism) and various +# training techniques (e.g. activation checkpointing and compile) to the Llama model. + + +import torch.nn as nn + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + PrepareModuleInputOutput, + RowwiseParallel, + SequenceParallel, +) + +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims + + +def parallelize_qwen3( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Temporary helper to apply tensor parallelism to the Qwen3 dense model so vLLM can run the torchtitan model. + """ + + if parallel_dims.tp_enabled: + tp_mesh = parallel_dims.get_mesh("tp") + apply_non_moe_tp( + model, + tp_mesh, + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=False, + enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, + ) + + return model + + +def apply_non_moe_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, + enable_async_tp: bool, +): + """Apply tensor parallelism to the Qwen3 dense model. + + This is a temporary TP plan used while we resolve composability issues in the + main torchtitan codebase. Once DTensor is fully supported across the TP + region, this separate plan should be removed. + """ + + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + use_local_output=False, + ), + "norm": SequenceParallel( + use_local_output=False, + ), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Replicate(), + use_local_output=True, # return logits and plain tensor + ), + }, + ) + + # Apply tensor + sequence parallelism to every transformer block + # NOTE: At the cost of model code change, we can accelerate Sequence Parallel + # by folding (and unfolding) the batch dimension and the sequence dimension. + # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + # pyrefly: ignore [not-callable] + for transformer_block in model.layers.values(): + layer_plan = { + "attention_norm": SequenceParallel( + use_local_output=False, + ), + # NOTE: when the fourth argument (positions) is not None, its input layout + # and desired input layout should be Replicate() + "attention": PrepareModuleInput( + input_layouts=(Shard(1), Replicate(), None, Replicate()), + desired_input_layouts=(Replicate(), Replicate(), None, Replicate()), + ), + "attention.wq": ColwiseParallel(use_local_output=False), + "attention.wk": ColwiseParallel(use_local_output=False), + "attention.wv": ColwiseParallel(use_local_output=False), + "attention.q_norm": SequenceParallel( + sequence_dim=2, + use_local_output=False, + ), + "attention.k_norm": SequenceParallel( + sequence_dim=2, + use_local_output=False, + ), + # Apply on vllm.Attention() module to use local tensor + "attention.inner_attention": PrepareModuleInputOutput( + input_layouts=(Shard(1), Shard(1), Shard(1)), # xq, xk, xv + desired_input_layouts=(None, None, None), + use_local_input=True, # use local tensor for attention calculation + output_layouts=(Shard(1)), # output + desired_output_layouts=(Shard(1)), + use_local_output=False, + ), + "attention.wo": RowwiseParallel( + output_layouts=Shard(1), + use_local_output=False, + ), + "ffn_norm": SequenceParallel( + use_local_output=False, + ), + } + + # pyrefly: ignore [missing-attribute] + if not transformer_block.moe_enabled: + layer_plan.update( + { + "feed_forward": PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": ColwiseParallel(use_local_output=False), + "feed_forward.w2": RowwiseParallel( + output_layouts=Shard(1), use_local_output=False + ), + "feed_forward.w3": ColwiseParallel(use_local_output=False), + } + ) + else: + raise ValueError( + "Running vLLM inference with torchtitan Qwen3 MoE model is not supported yet." + ) + + parallelize_module( + # pyrefly: ignore [bad-argument-type] + module=transformer_block, + device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] + parallelize_plan=layer_plan, + ) diff --git a/torchtitan/experiments/rl/unified/models/utils.py b/torchtitan/experiments/rl/unified/models/utils.py index 36de483283..0e5d6cde52 100644 --- a/torchtitan/experiments/rl/unified/models/utils.py +++ b/torchtitan/experiments/rl/unified/models/utils.py @@ -4,13 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import logging from enum import Enum import torch from safetensors.torch import load_file -from torchtitan.experiments.rl.unified.models.attention import VLLMAttention +from torchtitan.experiments.rl.unified.models.attention import VLLMAttention from torchtitan.experiments.rl.vllm_compat.models.attention import ( VLLMCompatibleFlashAttention, ) @@ -41,7 +42,7 @@ class ModelMode(str, Enum): STANDARD = "standard" -def replace_with_vllm_attention(model): +def replace_with_vllm_attention(model, tp_degree=1): """ Replace TorchTitan attention with vLLM's Attention. @@ -59,8 +60,9 @@ def replace_with_vllm_attention(model): vllm_attn = VLLMAttention( hidden_size=model_args.dim, - num_heads=model_args.n_heads, - num_kv_heads=model_args.n_heads, # Use n_heads (already replicated) + num_heads=model_args.n_heads // tp_degree, + num_kv_heads=model_args.n_heads + // tp_degree, # Use n_heads (already replicated) head_dim=model_args.head_dim, layer_name=layer_name, scale=model_args.head_dim**-0.5, @@ -95,7 +97,7 @@ def replace_with_vllm_compatible_flash_attention(model): layer.attention.inner_attention = vllm_attn logger.info( - f"Successfully replaced TorchTitan attention with VLLMAttention " + f"Successfully replaced TorchTitan attention with VLLMCompatibleFlashAttention " f"({len(model.layers)} layers)" ) @@ -104,7 +106,7 @@ def load_model( checkpoint_path: str, model_path: str, model_mode: str = ModelMode.VLLM_COMPAT ): """ - Load TorchTitan model from checkpoint. + Load TorchTitan model from checkpoint for trainer. Args: checkpoint_path: Path to TorchTitan checkpoint @@ -113,7 +115,7 @@ def load_model( or plain Torchtitan model Returns: - model: Loaded TorchTitan model + model: Loaded TorchTitan model for trainer. """ # Load HuggingFace config # TODO: do not depend on transformers.AutoConfig, use qwen_args directly @@ -150,7 +152,10 @@ def load_model( # Set global default dtype to bfloat16. This is needed because vLLM's Attention # layer uses torch.get_default_dtype() and it doesn't support float32 torch.set_default_dtype(torch.bfloat16) + # NOTE: Override attention to vllm compatible attention for backward capability. + # Only patch to vllm compatible attention for training. replace_with_vllm_compatible_flash_attention(model) + # Load standard TorchTitan format directly model.load_state_dict(state_dict, strict=True) elif model_mode == ModelMode.VLLM_COMPAT: diff --git a/torchtitan/experiments/rl/unified/models/vllm_wrapper.py b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py index 8dbc5ce393..f3ae7f348a 100644 --- a/torchtitan/experiments/rl/unified/models/vllm_wrapper.py +++ b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py @@ -21,6 +21,11 @@ StateDictOptions, ) +from torchtitan.experiments.rl.unified.infra.parallelism_utils import ( + create_job_config_from_vllm_config, + create_parallel_dims_from_vllm_config, +) + from torchtitan.experiments.rl.unified.models.utils import replace_with_vllm_attention from torchtitan.models.qwen3.model.model import precompute_rope_cache from torchtitan.protocols.model import BaseModelArgs, ModelProtocol @@ -30,20 +35,21 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from .parallelism_utils import create_parallel_dims_from_vllm_config - logger = init_logger(__name__) class TorchTitanVLLMModelWrapper(nn.Module): """ - Generic vLLM-compatible model wrapper for TorchTitan models. + Generic vLLM-compatible model wrapper for TorchTitan models. Implemented + required interface required by vLLM Engine. + Doc: https://docs.vllm.ai/en/latest/contributing/model/basic/ + Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py The wrapper handles: - Direct usage of TorchTitan model args (no HF config mapping needed) - Attention replacement with vLLM paged attention - - Tensor parallelism setup + - Parallelism setup and DTensor conversion between torchtitan and vLLM - Weight loading from HF checkpoints - vLLM forward/compute_logits interface """ @@ -82,27 +88,31 @@ def __init__( dim=self.config.head_dim, base=self.config.rope_theta, ) - # Replace attention with vLLM's attention - replace_with_vllm_attention(self.model) - - # Create ParallelDims from vLLM config and apply parallelization - # NOTE: We need to apply parallelize within model.__init__ because w - parallel_dims = create_parallel_dims_from_vllm_config(vllm_config) - if parallel_dims.tp_enabled: - self.world_mesh = parallel_dims.world_mesh - tp_mesh = self.world_mesh["tp"] - parallelize_fn( - model=self.model, - tp_mesh=tp_mesh, - loss_parallel=False, - enable_float8_tensorwise_tp=False, - enable_async_tp=False, - ) - logger.info( - f"Successfully initialized model with with TP={parallel_dims.tp}" - ) - else: - logger.info("Single GPU mode - no parallelization needed") + + # Create ParallelDims and JobConfig from vLLM config at runtime + # vLLM config contains the tensor_parallel_size from command-line args + # and this will be consistent across all worker processes + self.parallel_dims = create_parallel_dims_from_vllm_config(vllm_config) + self.parallel_config = create_job_config_from_vllm_config( + vllm_config=vllm_config, + ) + # Replace attention with vLLM paged attention + tp_size = self.parallel_dims.tp + if tp_size > 1: + assert ( + model_args.n_heads % tp_size == 0 + ), "Only support when n_heads can be divided by tp_size" + + replace_with_vllm_attention(self.model, tp_degree=tp_size) + + # NOTE: We need to apply parallelize within model.__init__ because vllm + # doesn't separate model creation and parallelism application and instead + # requires parallelization to be done inside model constructor. + self.model = parallelize_fn( + model=self.model, + parallel_dims=self.parallel_dims, + job_config=self.parallel_config, + ) def _extend_rope_cache_if_needed( self, rope_cache: torch.Tensor, max_position: int @@ -117,8 +127,6 @@ def _extend_rope_cache_if_needed( Returns: Extended RoPE cache if needed, otherwise original cache """ - from torch.distributed._tensor import DTensor, Replicate - required_len = max_position + 1 # No extension needed @@ -230,6 +238,12 @@ def forward( for layer in self.model.layers.values(): h = layer(h, rope_cache, attention_masks=None, positions=positions) + # When parallelism is applied, get full tensor before return to vLLM Engine + # The original placement is Shard(1) (shard on sequence dimension, as it will prepare for sequence parallel in `self.norm`). + # vLLM’s engine expects plain, non-distributed tensors to slice the last token for each request. + if isinstance(h, DTensor): + h = h.full_tensor() + # Convert to vLLM format: [total_tokens, hidden_size] if h.dim() == 3: batch_size, seq_len, hidden_size = h.shape @@ -243,6 +257,19 @@ def compute_logits( sampling_metadata=None, ) -> torch.Tensor | None: """Compute logits from hidden states.""" + + # When TP is applied, we return the full tensor (plain tensor) to vLLM engine + # at the end of TorchTitanVLLMModelWrapper.forward(). + # We need to wrap the input from vLLM engine back to DTensor with Replicate() placement. + if self.parallel_dims.tp_enabled: + hidden_states = DTensor.from_local( + hidden_states, + device_mesh=self.parallel_dims.get_mesh("tp"), + placements=[ + Replicate(), + ], + ) + h = self.model.norm(hidden_states) logits = self.model.output(h) From 5077be6ef889de93fc7a246197e33967d4307531 Mon Sep 17 00:00:00 2001 From: liangel-02 Date: Fri, 26 Dec 2025 16:01:49 -0500 Subject: [PATCH 086/127] add safety checks for varlen (#2179) since varlen attn is not supported for deepseek v3 and llama4, we should raise an error if someone attempts to use it --- torchtitan/models/deepseek_v3/model/model.py | 6 +++++- torchtitan/models/llama3/model/model.py | 4 +++- torchtitan/models/llama4/model/model.py | 6 +++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index fdc6ef56a7..fc846d4098 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -237,9 +237,13 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): match self.attn_type: case "flex": self.inner_attention = FlexAttentionWrapper() - case _: + case "sdpa": # pyrefly: ignore [bad-assignment] self.inner_attention = ScaledDotProductAttentionWrapper() + case "varlen": + raise ValueError("Varlen attention is not supported with Deepseek V3.") + case _: + raise ValueError(f"Unknown attention type: {self.attn_type}") def forward( self, diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index cafd58a52e..40767c06d2 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -225,9 +225,11 @@ def __init__(self, model_args: TransformerModelArgs): case "varlen": # pyrefly: ignore [bad-assignment] self.inner_attention = VarlenAttentionWrapper() - case _: + case "sdpa": # pyrefly: ignore [bad-assignment] self.inner_attention = ScaledDotProductAttentionWrapper() + case _: + raise ValueError(f"Unknown attention type: {self.attn_type}") def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index e08f733f28..8ee54cae5b 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -230,9 +230,13 @@ def __init__( match self.attn_type: case "flex": self.inner_attention = FlexAttentionWrapper() - case _: + case "sdpa": # pyrefly: ignore [bad-assignment] self.inner_attention = ScaledDotProductAttentionWrapper() + case "varlen": + raise ValueError("Varlen attention is not supported with Llama 4.") + case _: + raise ValueError(f"Unknown attention type: {self.attn_type}") def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): From 64b5e15fadeb6f71c43269d5ab69dc93db5aa0bc Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Fri, 26 Dec 2025 13:47:39 -0800 Subject: [PATCH 087/127] Bump torchtitan version to v0.2.1 (#2180) as titiled --- assets/version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assets/version.txt b/assets/version.txt index 0ea3a944b3..0c62199f16 100644 --- a/assets/version.txt +++ b/assets/version.txt @@ -1 +1 @@ -0.2.0 +0.2.1 From 81af8833ddeff9b5f1874dc7e20594aa17da6b86 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Fri, 26 Dec 2025 15:27:56 -0800 Subject: [PATCH 088/127] Remove psutil as part of requirements (#2181) --- .ci/docker/requirements.txt | 1 - pyproject.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index b63653bb53..5925bfd1d3 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -8,6 +8,5 @@ fsspec tyro tokenizers >= 0.15.0 safetensors -psutil einops pillow diff --git a/pyproject.toml b/pyproject.toml index aa5a93fd7c..7be33b0f7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ dependencies = [ "fsspec", "tyro", "tensorboard", - "psutil", "einops", "pillow", ] From 5dd9f4ca0946aa973a0798ddda9d284de4011684 Mon Sep 17 00:00:00 2001 From: liangel-02 Date: Mon, 29 Dec 2025 13:22:42 -0500 Subject: [PATCH 089/127] add attention scaling to varlen for qwen3 (#2178) as title, fixes [#2170](https://github.com/pytorch/torchtitan/issues/2170) --- torchtitan/models/attention.py | 2 ++ torchtitan/models/qwen3/model/model.py | 1 + 2 files changed, 3 insertions(+) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index b04a6a136e..e1255f1e94 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -60,6 +60,7 @@ def forward( xv: torch.Tensor, head_dim: torch.Tensor, attention_masks: VarlenMetadata, + scale: float | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: cu_seq_q = attention_masks.cu_seq_q cu_seq_k = attention_masks.cu_seq_k @@ -83,6 +84,7 @@ def forward( max_q, max_k, is_causal=True, + scale=scale, ) diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index d7c88d4525..389f0fc9f7 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -283,6 +283,7 @@ def forward( xv, self.head_dim, attention_masks, + scale=self.scaling, ) case "sdpa": assert attention_masks is None From 62f5806ed48fde29a4a44cb8edd569fc4b200514 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 29 Dec 2025 16:13:54 -0700 Subject: [PATCH 090/127] make get tp mesh optional in llama4 parallelize (#2185) Fixes #2184 get_mesh() throws an exception if the mesh dim is none, so the current code implicitly requires TP > 1. we should use get_optional_mesh() so the user isn't required to use TP. --- torchtitan/models/llama4/infra/parallelize.py | 5 +---- torchtitan/models/qwen3/infra/parallelize.py | 3 +-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 454679ff55..154c48d992 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -133,12 +133,9 @@ def parallelize_llama( if parallel_dims.tp_enabled or parallel_dims.ep_enabled: dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) - # tp_mesh might have been set above if tp_enabled, otherwise get it here - if tp_mesh is None: - tp_mesh = parallel_dims.get_mesh("tp") apply_moe_ep_tp( model, - tp_mesh=tp_mesh, + tp_mesh=parallel_dims.get_optional_mesh("tp"), ep_mesh=parallel_dims.get_optional_mesh("ep"), etp_mesh=parallel_dims.get_optional_mesh("etp"), ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 4c7f43f426..28a3ba3304 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -102,10 +102,9 @@ def parallelize_qwen3( if parallel_dims.tp_enabled or parallel_dims.ep_enabled: dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) - tp_mesh = parallel_dims.get_mesh("tp") apply_moe_ep_tp( model, - tp_mesh=tp_mesh, + tp_mesh=parallel_dims.get_optional_mesh("tp"), ep_mesh=parallel_dims.get_optional_mesh("ep"), etp_mesh=parallel_dims.get_optional_mesh("etp"), ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), From 7e4ab85998576c68902603058adada28fb0ed226 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 29 Dec 2025 17:34:49 -0800 Subject: [PATCH 091/127] Add docs to explain COMM_MODE (#2162) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * __->__ #2162 As title --- docs/debugging.md | 63 +++++++++++++++++++++++++++++++++++++++++++++++ run_train.sh | 18 +++++++++++--- 2 files changed, 78 insertions(+), 3 deletions(-) diff --git a/docs/debugging.md b/docs/debugging.md index 4deb20bbac..7a14606b51 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -54,6 +54,69 @@ python -m torchtitan.config.manager --help This will print a structured configuration to `stdout`, allowing you to verify that overrides are being applied correctly. +## Communication Mode (COMM_MODE) for Debugging + +The `COMM_MODE` environment variable provides specialized debugging modes that allow you to test and validate your training setup without requiring full multi-GPU distributed execution. This is particularly useful for rapid iteration during development and debugging. + +### Available Modes + +#### 1. `fake_backend` - Configuration Validation Mode + +This mode enables dry-run validation of your configuration, model setup, and rank-0 program logic without actual distributed communication: + +```bash +NGPU=32 COMM_MODE="fake_backend" ./run_train.sh +``` + +**What it does:** +- Uses fake process groups that simulate distributed communication without actual data transfer +- Runs on a single GPU without `torchrun` or NCCL initialization +- Validates configuration parsing, model initialization, and overall training workflow +- Executes only one training step by default + +**When to use it:** +- Quick validation of configuration files before launching expensive multi-GPU jobs +- Debugging training and parallelism logic that doesn't require actual communication. Note that No data-dependent logic should be validated with "fake_backend". + +**Example use case:** +```bash +# Validate a 128-GPU configuration on a single GPU +NGPU=128 COMM_MODE="fake_backend" CONFIG_FILE="./train_configs/llama3_70b.toml" ./run_train.sh +``` + +#### 2. `local_tensor` - Single-GPU Distributed Simulation + +This mode simulates the full distributed training workflow on a single GPU by executing all communication and computation locally: + +```bash +NGPU=32 COMM_MODE="local_tensor" ./run_train.sh +``` + +**What it does:** +- Simulates multi-GPU behavior on a single shared GPU +- Executes all collectives (all-reduce, all-gather, etc.) locally without network communication +- Maintains the same code paths as distributed training for accurate debugging +- Runs only one training step by default + +**When to use it:** +- Debugging distributed training logic (FSDP, TP, PP, CP, EP) with data dependencies without multi-GPU setup. Note that local tensor doesn't support FSDP2 but should support SimpleFSDP. +- Verifying correctness of parallelism strategies locally +- Testing gradient synchronization and communication patterns +- Reproducing distributed training bugs in a simplified environment + +**Example use case:** +```bash +# Debug 8-way TP + 2-way FSDP on a single GPU +NGPU=16 COMM_MODE="local_tensor" ./run_train.sh \ + --parallelism.tensor_parallel_degree 8 \ + --parallelism.data_parallel_shard_degree 2 +``` + +### Limitations + +- **Performance testing**: Neither mode provides accurate performance metrics; use actual distributed runs for benchmarking +- **Memory requirement**: Local tensor runs require more memory on a single GPU than the actual distributed runs + ## Troubleshooting jobs that timeout If you encounter jobs that timeout, you'll need to debug them to identify the root cause. To help with this process, we've enabled Flight Recorder, a tool that continuously collects diagnostic information about your jobs. diff --git a/run_train.sh b/run_train.sh index 87558a782d..069ea084b0 100755 --- a/run_train.sh +++ b/run_train.sh @@ -10,13 +10,25 @@ set -ex # use envs as local overwrites for convenience # e.g. # LOG_RANK=0,1 NGPU=4 ./run_train.sh -# COMM_MODE="fake_backend" ./run_train.sh # for config validation without GPU -# COMM_MODE="local_tensor" ./run_train.sh # for local tensor debugging mode +# +# COMM_MODE options for debugging: +# +# 1. "fake_backend" - Dry-run mode for config validation without GPU execution +# - Uses fake process groups (no actual communication) +# - Runs on a single GPU without torchrun or NCCL initialization +# - Useful for validating configuration and model setup +# Example: NGPU=32 COMM_MODE="fake_backend" ./run_train.sh +# +# 2. "local_tensor" - Single-GPU debugging mode with simulated multi-GPU behavior +# - All communication and computation execute on a single shared GPU +# - Simulates the full training workflow without actual distributed communication +# - Useful for debugging distributed training logic locally +# Example: NGPU=32 COMM_MODE="local_tensor" ./run_train.sh + NGPU=${NGPU:-"8"} export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"} TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"} -# COMM_MODE options: "fake_backend" (dry run), "local_tensor" (debug mode), or empty for normal training COMM_MODE=${COMM_MODE:-""} TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} From e16af85901aa7bf8042399afe6c0fcebd43c8beb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Brya=CE=B7=20C=2E?= <101939095+BryanBradfo@users.noreply.github.com> Date: Tue, 6 Jan 2026 17:57:58 +0100 Subject: [PATCH 092/127] [docs] Fix missing --model.flavor flags in compiler_toolkit README (#2168) (#2201) Adds the missing `--model.flavor=debugmodel_flex_attn` flag to Llama 3 FlexAttention commands in `compiler_toolkit/README.md`. Fixes #2168. cc @yiming0416 --- torchtitan/experiments/compiler_toolkit/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index 620911ce60..9c1833b55f 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ b/torchtitan/experiments/compiler_toolkit/README.md @@ -47,17 +47,17 @@ NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./to **SimpleFSDP + TP + FlexAttention + auto-bucketing + regional-inductor** ```shell -NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor --model.flavor=debugmodel_flex_attn ``` **SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor** ```shell -NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor --model.flavor=debugmodel_flex_attn ``` **SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor + cudagraph** ```shell -NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor,cudagraph +NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor,cudagraph --model.flavor=debugmodel_flex_attn ``` From 795a7a027eccf282c51a5ae1cc0c1c3459120c9b Mon Sep 17 00:00:00 2001 From: Shuhua Yu <18108279+shuhuayu@users.noreply.github.com> Date: Tue, 6 Jan 2026 09:56:23 -0800 Subject: [PATCH 093/127] [GPT-OSS] Graduate from experiments to main (#2203) As titled. Unit test with fsdp=4, tp=2, ep=4, passed locally. Pasted Graphic 1 --- tests/integration_tests/models.py | 16 ++++++++++++++ torchtitan/experiments/__init__.py | 1 - torchtitan/models/__init__.py | 2 +- .../{experiments => models}/gpt_oss/README.md | 3 --- .../gpt_oss/__init__.py | 0 .../gpt_oss/infra/expert_parallel.py | 4 ++++ .../gpt_oss/infra/parallelize.py | 21 ++++++++++++++++++- .../gpt_oss/model/args.py | 1 + .../gpt_oss/model/model.py | 3 +++ .../gpt_oss/model/moe.py | 13 ++++++++++++ .../gpt_oss/model/state_dict_adapter.py | 2 ++ .../gpt_oss/train_configs/debug_model.toml | 0 12 files changed, 60 insertions(+), 6 deletions(-) rename torchtitan/{experiments => models}/gpt_oss/README.md (75%) rename torchtitan/{experiments => models}/gpt_oss/__init__.py (100%) rename torchtitan/{experiments => models}/gpt_oss/infra/expert_parallel.py (91%) rename torchtitan/{experiments => models}/gpt_oss/infra/parallelize.py (91%) rename torchtitan/{experiments => models}/gpt_oss/model/args.py (99%) rename torchtitan/{experiments => models}/gpt_oss/model/model.py (98%) rename torchtitan/{experiments => models}/gpt_oss/model/moe.py (94%) rename torchtitan/{experiments => models}/gpt_oss/model/state_dict_adapter.py (98%) rename torchtitan/{experiments => models}/gpt_oss/train_configs/debug_model.toml (100%) diff --git a/tests/integration_tests/models.py b/tests/integration_tests/models.py index 606ecfe4bd..665c9ef44e 100755 --- a/tests/integration_tests/models.py +++ b/tests/integration_tests/models.py @@ -125,6 +125,22 @@ def build_model_tests_list() -> list[OverrideDefinitions]: "llama4_pp+fsdp+tp+ep+compile", ngpu=8, ), + # Integration Test Cases for gpt-oss + OverrideDefinitions( + [ + [ + "--model.name gpt_oss", + "--parallelism.data_parallel_shard_degree 4", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.expert_parallel_degree 4", + "--parallelism.expert_tensor_parallel_degree 1", + "--compile.enable", + ], + ], + "Gpt-oss FSDP+TP+EP+compile", + "gpt_oss_fsdp+tp+ep+compile", + ngpu=8, + ), ] return model_tests diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 10f9030c1d..5989025d4f 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -6,7 +6,6 @@ _supported_experiments = frozenset( [ - "gpt_oss", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm", diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py index f4afa71d97..f372d29461 100644 --- a/torchtitan/models/__init__.py +++ b/torchtitan/models/__init__.py @@ -5,5 +5,5 @@ # LICENSE file in the root directory of this source tree. _supported_models = frozenset( - ["deepseek_v3", "flux", "llama3", "llama3_ft", "llama4", "qwen3"] + ["deepseek_v3", "flux", "gpt_oss", "llama3", "llama3_ft", "llama4", "qwen3"] ) diff --git a/torchtitan/experiments/gpt_oss/README.md b/torchtitan/models/gpt_oss/README.md similarity index 75% rename from torchtitan/experiments/gpt_oss/README.md rename to torchtitan/models/gpt_oss/README.md index a8283ab7b6..c16898bd80 100644 --- a/torchtitan/experiments/gpt_oss/README.md +++ b/torchtitan/models/gpt_oss/README.md @@ -12,6 +12,3 @@ CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" ./ ## TODO 1. More parallelism support: CP, PP -2. Conversion between HF weights (StateDictAdapter) -3. Forward parity verification -4. CI support diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/models/gpt_oss/__init__.py similarity index 100% rename from torchtitan/experiments/gpt_oss/__init__.py rename to torchtitan/models/gpt_oss/__init__.py diff --git a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py b/torchtitan/models/gpt_oss/infra/expert_parallel.py similarity index 91% rename from torchtitan/experiments/gpt_oss/infra/expert_parallel.py rename to torchtitan/models/gpt_oss/infra/expert_parallel.py index 1e8054a481..33de706648 100644 --- a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py +++ b/torchtitan/models/gpt_oss/infra/expert_parallel.py @@ -43,24 +43,28 @@ def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> N mod.register_parameter( "mlp1_weight", nn.Parameter( + # pyrefly: ignore [bad-argument-type] distribute_tensor(mod.mlp1_weight, device_mesh, [Shard(0), Shard(1)]) ), ) # Column-wise sharding mod.register_parameter( "mlp1_bias", nn.Parameter( + # pyrefly: ignore [bad-argument-type] distribute_tensor(mod.mlp1_bias, device_mesh, [Shard(0), Shard(1)]) ), ) # Column-wise sharding mod.register_parameter( "mlp2_weight", nn.Parameter( + # pyrefly: ignore [bad-argument-type] distribute_tensor(mod.mlp2_weight, device_mesh, [Shard(0), Shard(2)]) ), ) # Row-wise sharding mod.register_parameter( "mlp2_bias", nn.Parameter( + # pyrefly: ignore [bad-argument-type] distribute_tensor(mod.mlp2_bias, device_mesh, [Shard(0), Replicate()]) ), ) # Replicate diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/models/gpt_oss/infra/parallelize.py similarity index 91% rename from torchtitan/experiments/gpt_oss/infra/parallelize.py rename to torchtitan/models/gpt_oss/infra/parallelize.py index 4768dab659..f97d15369d 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/models/gpt_oss/infra/parallelize.py @@ -52,6 +52,7 @@ # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + # pyrefly: ignore [missing-attribute] torch._higher_order_ops.inductor_compiled_code, } @@ -106,7 +107,7 @@ def parallelize_gptoss( model, tp_mesh=parallel_dims.get_optional_mesh("tp"), ep_mesh=parallel_dims.get_optional_mesh("ep"), - ep_etp_mesh=parallel_dims.get_optional_mesh("ep_etp"), + ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), etp_enabled=parallel_dims.etp_enabled, dual_pipe_v=dual_pipe_v, ) @@ -116,6 +117,7 @@ def parallelize_gptoss( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, ) @@ -200,21 +202,28 @@ def apply_non_moe_tp( ) # Apply tensor + sequence parallelism to every transformer block + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), "attention": PrepareModuleInput( + # pyrefly: ignore [bad-argument-type] input_layouts=(Shard(1), Replicate(), None), + # pyrefly: ignore [bad-argument-type] desired_input_layouts=(Replicate(), Replicate(), None), ), "attention.wq": ColwiseParallel(use_local_output=False), "attention.wk": ColwiseParallel(use_local_output=False), "attention.wv": ColwiseParallel(use_local_output=False), "attention.inner_attention": PrepareModuleInputOutput( + # pyrefly: ignore [bad-argument-type] input_layouts=(Shard(1), Shard(1), Shard(1)), + # pyrefly: ignore [bad-argument-type] desired_input_layouts=(Shard(1), Shard(1), Shard(1)), use_local_input=True, + # pyrefly: ignore [bad-argument-type] output_layouts=(Shard(1), Shard(1)), + # pyrefly: ignore [bad-argument-type] desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False, ), @@ -223,6 +232,7 @@ def apply_non_moe_tp( } # shard attention.sinks across heads + # pyrefly: ignore [missing-attribute] attn = transformer_block.attention attn.register_parameter( "sinks", @@ -230,14 +240,17 @@ def apply_non_moe_tp( ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) if enable_async_tp: from torch.distributed._symmetric_memory import enable_symm_mem_for_group + # pyrefly: ignore [implicit-import] torch._inductor.config._micro_pipeline_tp = True enable_symm_mem_for_group(tp_mesh.get_group().group_name) @@ -257,7 +270,9 @@ def apply_moe_ep_tp( ): assert ep_mesh is not None or tp_mesh is not None + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: continue @@ -279,11 +294,14 @@ def apply_moe_ep_tp( # If TP is borrowed for EP, then split the tokens across TP ranks so that # the reorderer, the all-to-all comms, and routed experts computation # are effectively running Sequence Parallel (split along the folded bs*slen dim) + # pyrefly: ignore [no-matching-overload] moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()}) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=moe_layer_plan, ) @@ -304,6 +322,7 @@ def apply_moe_ep_tp( experts_plan = DualPipeExpertParallel(experts_plan) parallelize_module( + # pyrefly: ignore [missing-attribute] module=transformer_block.moe.experts, device_mesh=experts_mesh, parallelize_plan=experts_plan, diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/models/gpt_oss/model/args.py similarity index 99% rename from torchtitan/experiments/gpt_oss/model/args.py rename to torchtitan/models/gpt_oss/model/args.py index af4c51eadc..2a9aa970e4 100644 --- a/torchtitan/experiments/gpt_oss/model/args.py +++ b/torchtitan/models/gpt_oss/model/args.py @@ -91,6 +91,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: "CP support for gpt-oss model is still in progress." ) + # pyrefly: ignore [bad-override] def get_nparams_and_flops( self, model: nn.Module, seq_len: int ) -> tuple[int, float]: diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/models/gpt_oss/model/model.py similarity index 98% rename from torchtitan/experiments/gpt_oss/model/model.py rename to torchtitan/models/gpt_oss/model/model.py index 1fcd12eaa9..9db5818b3f 100644 --- a/torchtitan/experiments/gpt_oss/model/model.py +++ b/torchtitan/models/gpt_oss/model/model.py @@ -250,8 +250,10 @@ def forward( """ # Extract the appropriate mask for this layer if self.use_sliding_attention: + # pyrefly: ignore [missing-attribute] layer_mask = attention_masks.get("sliding_window_mask", None) else: + # pyrefly: ignore [missing-attribute] layer_mask = attention_masks.get("basic_mask", None) assert layer_mask is not None @@ -304,6 +306,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights(buffer_device=buffer_device) if self.norm is not None: self.norm.reset_parameters() diff --git a/torchtitan/experiments/gpt_oss/model/moe.py b/torchtitan/models/gpt_oss/model/moe.py similarity index 94% rename from torchtitan/experiments/gpt_oss/model/moe.py rename to torchtitan/models/gpt_oss/model/moe.py index 94cd266761..2a60751ec3 100644 --- a/torchtitan/experiments/gpt_oss/model/moe.py +++ b/torchtitan/models/gpt_oss/model/moe.py @@ -25,6 +25,7 @@ class ScaleBiasForward(torch.autograd.Function): """ @staticmethod + # pyrefly: ignore [bad-override] def forward(ctx, bias, tp_degree): ctx.tp_degree = tp_degree if tp_degree > 1: @@ -32,6 +33,7 @@ def forward(ctx, bias, tp_degree): return bias @staticmethod + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): # Don't scale the gradient - pass it through as-is return grad_output, None @@ -101,6 +103,7 @@ def _run_experts_for_loop( tp_degree: int = 1, ) -> torch.Tensor: # NOTE: this would incur a synchronization between device and host + # pyrefly: ignore [bad-assignment] num_tokens_per_expert = num_tokens_per_expert.tolist() # side-effect code due to the usage of generate_permute_indices @@ -108,8 +111,10 @@ def _run_experts_for_loop( # a tuple of tensors indexed by experts # each with shape (tokens_per_expert(varying), dim) + # pyrefly: ignore [bad-assignment] x = torch.split( x[: sum(num_tokens_per_expert)], + # pyrefly: ignore [bad-argument-type] split_size_or_sections=num_tokens_per_expert, dim=0, ) @@ -127,6 +132,7 @@ def _run_experts_for_loop( out = torch.cat(out_experts_splits, dim=0) # side-effect code due to the usage of generate_permute_indices + # pyrefly: ignore [no-matching-overload] out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) return out @@ -201,8 +207,11 @@ def forward( # Convert parameters from DTensors to plain Tensors, to work with # dynamic-shape inputs in EP which cannot be easily expressed as DTensors. mlp1_weight = self.mlp1_weight.to_local() + # pyrefly: ignore [missing-attribute] mlp1_bias = self.mlp1_bias.to_local() + # pyrefly: ignore [missing-attribute] mlp2_weight = self.mlp2_weight.to_local() + # pyrefly: ignore [missing-attribute] mlp2_bias = self.mlp2_bias.to_local() else: mlp1_weight = self.mlp1_weight @@ -214,13 +223,16 @@ def forward( tp_degree = 1 if isinstance(self.mlp1_weight, DTensor): mesh_dim_names = self.mlp1_weight.device_mesh.mesh_dim_names + # pyrefly: ignore [not-iterable] if "tp" in mesh_dim_names: + # pyrefly: ignore [missing-attribute] tp_dim_idx = mesh_dim_names.index("tp") tp_degree = self.mlp1_weight.device_mesh.size(tp_dim_idx) if self.use_grouped_mm: if ( not isinstance(self.mlp1_weight, DTensor) + # pyrefly: ignore [not-iterable] or "ep" not in self.mlp1_weight.device_mesh.mesh_dim_names ): run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm) @@ -266,6 +278,7 @@ def __init__(self, model_args: GptOssModelArgs, dim: int, hidden_dim: int): super().__init__(moe_args, dim, hidden_dim) # Override the base GroupedExperts with GptOssGroupedExperts + # pyrefly: ignore [bad-assignment] self.experts = GptOssGroupedExperts( dim=dim, hidden_dim=hidden_dim, diff --git a/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py b/torchtitan/models/gpt_oss/model/state_dict_adapter.py similarity index 98% rename from torchtitan/experiments/gpt_oss/model/state_dict_adapter.py rename to torchtitan/models/gpt_oss/model/state_dict_adapter.py index ca85789baf..9198505257 100644 --- a/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py +++ b/torchtitan/models/gpt_oss/model/state_dict_adapter.py @@ -82,6 +82,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) if abstract_key not in to_hf_map: continue + # pyrefly: ignore layer_num = re.search(r"\d+", key).group(0) hf_key = to_hf_map[abstract_key] hf_key = hf_key.format(layer_num) @@ -103,6 +104,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in hf_state_dict.items(): if "layers" in key: + # pyrefly: ignore layer_num = re.search(r"\d+", key).group(0) abstract_key = re.sub(r"(\d+)", "{}", key, count=1) tt_key = self.from_hf_map[abstract_key] diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/models/gpt_oss/train_configs/debug_model.toml similarity index 100% rename from torchtitan/experiments/gpt_oss/train_configs/debug_model.toml rename to torchtitan/models/gpt_oss/train_configs/debug_model.toml From 9f211ec199bc887901b874edd6af5a20527a4175 Mon Sep 17 00:00:00 2001 From: Aditya Venkataraman Date: Wed, 7 Jan 2026 08:45:04 -0800 Subject: [PATCH 094/127] [Compiler Toolkit] Add option for full inductor. (#2150) Being able to compile fw/bw graphs using compile_fx_inner could help with establishing perf rooflines. Full inductor compilation is achieved using `compile_fx_inner`, however, it requires the graph to have been decomposed using Inductor's default decomposition table. We apply this decomposition as a pass on the joint graph. We need to be careful to suitably unwrap the primals/tangents before running this decomposition. Manual testing: NGPU=4 \ CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml \ TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train \ ./run_train.sh \ --model.name $MODEL_NAME \ --parallelism.data_parallel_shard_degree=2 \ --parallelism.tensor_parallel_degree=2 \ --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config \ --compile.joint_passes inductor_decomposition \ --compile.passes full_inductor_compilation --- .../experiments/compiler_toolkit/README.md | 6 + .../deepseek_v3/parallelize.py | 3 +- .../compiler_toolkit/graph_utils.py | 139 +++++++++++++-- .../compiler_toolkit/job_config.py | 17 +- .../compiler_toolkit/llama3/parallelize.py | 3 +- .../experiments/compiler_toolkit/passes.py | 164 ++++++++++++++++++ .../tests/integration_tests.py | 16 +- 7 files changed, 326 insertions(+), 22 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index 9c1833b55f..9fc6660245 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ b/torchtitan/experiments/compiler_toolkit/README.md @@ -61,3 +61,9 @@ NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./to ```shell NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor,cudagraph --model.flavor=debugmodel_flex_attn ``` + +**SimpleFSDP + TP + Full Inductor compilation** + +```shell +NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train ./run_train.sh --model.name $MODEL_NAME compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.joint_passes inductor_decomposition --compile.passes full_inductor_compilation +``` diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py index 011bbe402a..13e6689563 100644 --- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py @@ -82,7 +82,7 @@ def parallelize_deepseekv3( # Get compiler passes from config compiler_passes = get_compiler_passes_from_config(model, job_config) - # Create compilers with specified passes (defaults to no passes) + # Create compilers with specified passes fw_compiler, bw_compiler = make_compiler_with_passes( compiler_passes, dump_folder=job_config.job.dump_folder ) @@ -94,6 +94,7 @@ def parallelize_deepseekv3( bw_compiler=bw_compiler, joint_custom_passes=joint_custom_passes, dump_folder=job_config.job.dump_folder, + job_config=job_config, ) # TODO: CompiledModule should take sample input as well, so that we can diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 551bf695c5..64dc03c312 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -39,6 +39,15 @@ def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> No def export_joint( model, args, kwargs=None, dump_folder: str | None = None ) -> tuple[JointWithDescriptors, TracingContext]: + """ + Export joint forward-backward graph with AOT Autograd. + + Args: + model: The model to export + args: Tuple of input arguments + kwargs: Dict of keyword arguments for the model + dump_folder: Optional folder to dump the graph to + """ if kwargs is None: kwargs = {} assert isinstance(args, tuple) @@ -68,6 +77,14 @@ def export_joint( def aot_export_joint_with_descriptors_alone(model, args, kwargs=None): + """ + Export joint forward-backward graph with AOT Autograd. + + Args: + model: The model to export + args: Tuple of input arguments + kwargs: Dict of keyword arguments for the model + """ if kwargs is None: kwargs = {} assert isinstance(args, tuple) @@ -79,6 +96,7 @@ def aot_export_joint_with_descriptors_alone(model, args, kwargs=None): args, kwargs, ) + return joint_with_descriptors @@ -90,6 +108,7 @@ def joint_graph_builder( bw_compiler: Optional[Callable] = None, joint_custom_passes: Optional[List[Callable]] = None, dump_folder: str | None = None, + job_config: Optional["JobConfig"] = None, ): """ Build a joint forward-backward graph for the model with optional custom compilers. @@ -102,16 +121,41 @@ def joint_graph_builder( bw_compiler: Optional custom backward compiler function joint_custom_passes: list of custom passes to run on the joint graph dump_folder: Optional folder to dump the graph to + job_config: Job configuration """ assert isinstance(model_args, tuple) for idx, arg in enumerate(model_args): assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}" # get joint graph - ( - joint_with_descriptors, - tracing_context, - ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder) + (joint_with_descriptors, tracing_context,) = export_joint( + model, + model_args, + model_kwargs, + dump_folder=dump_folder, + ) + + # Check if inductor_decomposition is configured and create the pass with proper context + if job_config is not None: + joint_pass_names = getattr(job_config.compile, "joint_passes", []) + if "inductor_decomposition" in joint_pass_names: + from torchtitan.experiments.compiler_toolkit.passes import ( + inductor_decomposition_pass, + ) + + # Create the decomposition pass with context + decomp_pass = functools.partial( + inductor_decomposition_pass, + model=model, + joint_with_descriptors=joint_with_descriptors, + forward_inputs=model_args, + tracing_context=tracing_context, + ) + + # Prepend to joint_custom_passes + if joint_custom_passes is None: + joint_custom_passes = [] + joint_custom_passes = [decomp_pass] + joint_custom_passes # run custom passes on joint-graph before partitioner if joint_custom_passes is not None: @@ -259,28 +303,36 @@ def compiler( logger.info(f"Applying pass: {pass_name}") gm = pass_fn(gm, example_inputs) - logger.debug(f"{name} after compiler:") - logger.debug( - gm.print_readable(print_output=False, include_stride=True, include_device=True) - ) - _dump_gm(dump_folder, gm, f"{name}_after_compiler") + # Only try to print/dump if gm is still a GraphModule + # (compile_fx_inner returns a CompiledFxGraph which doesn't have print_readable) + if hasattr(gm, "print_readable"): + logger.debug(f"{name} after compiler:") + logger.debug( + gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + ) + _dump_gm(dump_folder, gm, f"{name}_after_compiler") + return gm def make_compiler_with_passes( - passes: List[Callable] = None, dump_folder: str | None = None + passes: List[Callable] = None, + dump_folder: str | None = None, ): """ Create forward and backward compilers with specified passes. Args: passes: List of compiler pass functions to apply. If None, uses DEFAULT_COMPILER_PASSES. + dump_folder: Optional folder to dump graphs Returns: Tuple of (fw_compiler, bw_compiler) functions """ - def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: + def fw_compiler(gm: torch.fx.GraphModule, example_inputs): return compiler( "fwd_gm", gm, @@ -290,7 +342,7 @@ def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: is_forward=True, ) - def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: + def bw_compiler(gm: torch.fx.GraphModule, example_inputs): return compiler( "bwd_gm", gm, @@ -303,7 +355,17 @@ def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: return fw_compiler, bw_compiler -def validate_pass_names(pass_names: list[str]) -> None: +def validate_pass_names(pass_names: list[str], joint_pass_names: list[str]) -> None: + """ + Validate compiler and joint pass names and their dependencies. + + Args: + pass_names: List of compiler pass names + joint_pass_names: List of joint custom pass names + + Raises: + ValueError: If pass configuration is invalid + """ if "cudagraph" in pass_names: assert ( pass_names[-1] == "cudagraph" @@ -317,13 +379,22 @@ def validate_pass_names(pass_names: list[str]) -> None: "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!" ) + # Validate that full_inductor_compilation requires inductor_decomposition + if "full_inductor_compilation" in pass_names: + if "inductor_decomposition" not in joint_pass_names: + raise ValueError( + "full_inductor_compilation pass requires inductor_decomposition to be " + "specified in joint_passes. Please add --compile.joint_passes inductor_decomposition" + ) + def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig): """ Extract and validate compiler passes from job config. Args: - job_config: Job configuration containing compile.passes + model: The model being compiled + job_config: Job configuration containing compile.passes and compile.joint_passes Returns: List of compiler pass functions @@ -334,9 +405,18 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi ) pass_names = getattr(job_config.compile, "passes", []) - validate_pass_names(pass_names) + joint_pass_names = getattr(job_config.compile, "joint_passes", []) + + validate_pass_names(pass_names, joint_pass_names) compiler_passes = [] + # Warn if full Inductor compilation is enabled + if "full_inductor_compilation" in pass_names: + logger.warning( + "Full Inductor compilation is enabled. Note that Inductor may change numerics " + "and does not guarantee bitwise equivalent results compared to eager mode." + ) + for pass_name in pass_names: if pass_name not in AVAILABLE_COMPILER_PASSES: raise ValueError( @@ -360,18 +440,26 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi def get_joint_custom_passes_from_config( - parallel_dims: ParallelDims, job_config: JobConfig + parallel_dims: ParallelDims, + job_config: JobConfig, ): """ Extract and validate joint custom passes from job config. + Note: The inductor_decomposition pass is handled separately in joint_graph_builder + because it requires context (model, joint_with_descriptors, etc.) that's only + available at graph capture time. + Args: + parallel_dims: Parallelism dimensions job_config: Job configuration containing parallelism.fsdp_reshard_after_forward + and compile.joint_passes Returns: List of joint custom pass functions """ from torchtitan.experiments.compiler_toolkit.passes import ( + AVAILABLE_JOINT_PASSES, fsdp_reshard_after_fwd_pass, validate_flex_attn_annotation_pass, ) @@ -379,6 +467,25 @@ def get_joint_custom_passes_from_config( joint_custom_passes = [] joint_custom_passes.append(validate_flex_attn_annotation_pass) + # Handle joint passes from config (excluding inductor_decomposition) + joint_pass_names = getattr(job_config.compile, "joint_passes", []) + for pass_name in joint_pass_names: + if pass_name not in AVAILABLE_JOINT_PASSES: + raise ValueError( + f"Unknown joint pass: {pass_name}. " + f"Available joint passes: {list(AVAILABLE_JOINT_PASSES.keys())}" + ) + + # Skip inductor_decomposition - it's handled in joint_graph_builder + if pass_name == "inductor_decomposition": + continue + + joint_custom_passes.append(AVAILABLE_JOINT_PASSES[pass_name]) + + if joint_pass_names: + logger.info(f"Using joint passes from config: {joint_pass_names}") + + # Handle FSDP reshard after forward match job_config.parallelism.fsdp_reshard_after_forward: case "always": fsdp_reshard_after_forward = True diff --git a/torchtitan/experiments/compiler_toolkit/job_config.py b/torchtitan/experiments/compiler_toolkit/job_config.py index ec5829a6c9..7db461b984 100644 --- a/torchtitan/experiments/compiler_toolkit/job_config.py +++ b/torchtitan/experiments/compiler_toolkit/job_config.py @@ -10,11 +10,22 @@ @dataclass class Compile: """ - List of compiler pass names to apply in the compiler toolkit workflow. - By default, no passes are applied. - Example: --compile.passes autobucketing_reordering,regional_inductor + Compiler configuration for the compiler toolkit workflow. + + - joint_passes: List of joint graph pass names to apply on the joint forward-backward + graph before partitioning. + + Example: --compile.joint_passes inductor_decomposition + + - passes: List of compiler pass names to apply to the partitioned forward/backward graphs. + + Example: --compile.passes full_inductor_compilation + + Note: If "full_inductor_compilation" is specified, "inductor_decomposition" must + be included in joint_passes. """ + joint_passes: list[str] = field(default_factory=list) passes: list[str] = field(default_factory=list) diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index 68fa7443f4..c955dc02f0 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -69,7 +69,7 @@ def parallelize_llama( # Get compiler passes from config compiler_passes = get_compiler_passes_from_config(model, job_config) - # Create compilers with specified passes (defaults to no passes) + # Create compilers with specified passes fw_compiler, bw_compiler = make_compiler_with_passes( compiler_passes, dump_folder=job_config.job.dump_folder ) @@ -81,6 +81,7 @@ def parallelize_llama( bw_compiler=bw_compiler, joint_custom_passes=joint_custom_passes, dump_folder=job_config.job.dump_folder, + job_config=job_config, ) # TODO: CompiledModule should take sample input as well, so that we can diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py index 5657eb2b2b..1e7354deff 100644 --- a/torchtitan/experiments/compiler_toolkit/passes.py +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -9,11 +9,18 @@ This module provides various compiler passes that can be applied to graph modules during compilation. Passes can be selected and configured via job config. + +Pass Types: +- Joint custom passes: Applied to the joint forward-backward graph before partitioning +- Compiler passes: Applied to the partitioned forward/backward graphs """ from typing import Any, Sequence import torch +from torch._functorch.aot_autograd import JointWithDescriptors +from torch._guards import TracingContext +from torch._inductor.compile_fx import compile_fx_inner from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing from torch.fx.passes.regional_inductor import regional_inductor @@ -24,6 +31,7 @@ from torchtitan.experiments.simple_fsdp.reshard_after_forward import ( annotate_fsdp_all_gather, ) +from torchtitan.tools.logging import logger def autobucketing_reordering_pass( @@ -106,10 +114,166 @@ def fsdp_reshard_after_fwd_pass( return gm +def inductor_decomposition_pass( + gm: torch.fx.GraphModule, + model: torch.nn.Module, + joint_with_descriptors: JointWithDescriptors, + forward_inputs: tuple, + tracing_context: TracingContext, +) -> torch.fx.GraphModule: + """ + Apply Inductor decompositions to the joint graph. + + This pass applies decompositions to the joint forward-backward graph using make_fx. + It unwraps tensor subclasses (like DTensor) and retraces the graph with decompositions + applied, while preserving metadata required by the partitioner. + + Args: + gm: The joint graph module + model: The parallelized model + joint_with_descriptors: The joint graph with descriptors + forward_inputs: Forward input arguments (may be DTensors) + tracing_context: The tracing context from original joint graph capture + + Returns: + The joint graph with decompositions applied + """ + from torch._functorch._aot_autograd.descriptors import DummyAOTInput + from torch._functorch._aot_autograd.subclass_utils import unwrap_tensor_subclasses + from torch._inductor.decomposition import select_decomp_table + from torch.fx.experimental.proxy_tensor import make_fx + + logger.info("Applying decompositions to joint graph") + + decomp_table = select_decomp_table() + + # Get traced tangents metadata + traced_tangents = joint_with_descriptors._aot_state.fw_metadata.traced_tangents + + # Collect all inputs: params, buffers, forward inputs, tangents + param_inputs = list(model.parameters()) + buffer_inputs = list(model.buffers()) + primals = param_inputs + buffer_inputs + list(forward_inputs) + tangents = list(traced_tangents) + + # Create dummy descriptors for unwrapping + primals_descs = [DummyAOTInput(i) for i in range(len(primals))] + tangents_descs = [DummyAOTInput(i + len(primals)) for i in range(len(tangents))] + + # Unwrap tensor subclasses (DTensor -> _local_tensor) + primals_unwrapped, _ = unwrap_tensor_subclasses( + primals, primals_descs, append_symints=False + ) + tangents_unwrapped, _ = unwrap_tensor_subclasses( + tangents, tangents_descs, append_symints=False + ) + + # Verify unwrapped tensor shapes match joint graph placeholders + all_inputs = primals_unwrapped + tangents_unwrapped + placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"] + + if len(all_inputs) != len(placeholders): + raise RuntimeError( + f"Input count mismatch: {len(all_inputs)} inputs vs {len(placeholders)} placeholders" + ) + + shape_mismatches = [] + for i, (inp, ph) in enumerate(zip(all_inputs, placeholders)): + if hasattr(inp, "shape") and "val" in ph.meta: + expected_shape = ph.meta["val"].shape + actual_shape = inp.shape + if expected_shape != actual_shape: + shape_mismatches.append( + f" {ph.target}: expected {expected_shape}, got {actual_shape}" + ) + + if shape_mismatches: + logger.error(f"Shape mismatches found ({len(shape_mismatches)}):") + for msg in shape_mismatches: + logger.error(msg) + raise RuntimeError( + "Unwrapped tensor shapes don't match joint graph placeholders." + ) + + # Get the FakeTensorMode from the original joint graph + fake_mode = None + for node in gm.graph.nodes: + if node.op == "placeholder" and "val" in node.meta: + val = node.meta["val"] + if hasattr(val, "fake_mode"): + fake_mode = val.fake_mode + break + + if fake_mode is None: + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(primals_unwrapped) + + # Use make_fx with the original fake mode to retrace with decompositions + with fake_mode: + decomposed_gm = make_fx( + gm, + decomposition_table=decomp_table, + _allow_non_fake_inputs=False, + )(primals_unwrapped, tangents_unwrapped) + + # Copy metadata from original placeholders to decomposed placeholders + orig_placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"] + decomp_placeholders = [ + n for n in decomposed_gm.graph.nodes if n.op == "placeholder" + ] + + if len(orig_placeholders) != len(decomp_placeholders): + raise RuntimeError( + f"Placeholder count mismatch: {len(orig_placeholders)} vs {len(decomp_placeholders)}" + ) + + for orig, decomp in zip(orig_placeholders, decomp_placeholders): + # Copy all metadata from original to decomposed + for key, value in orig.meta.items(): + if key not in decomp.meta: + decomp.meta[key] = value + + # Rename decomposed placeholder to match original name + decomp.target = orig.target + decomp.name = orig.name + + decomposed_gm.recompile() + logger.info("Decompositions applied successfully to joint graph") + + return decomposed_gm + + +def full_inductor_compilation_pass( + gm: torch.fx.GraphModule, example_inputs +) -> torch.fx.GraphModule: + """ + Apply full Inductor compilation with code generation. + + This pass uses compile_fx_inner to generate optimized code for the graph. + + Args: + gm: The graph module (forward or backward) + example_inputs: Example inputs for compilation + + Returns: + The compiled graph module + """ + return compile_fx_inner(gm, example_inputs) + + # Registry mapping pass names to pass functions AVAILABLE_COMPILER_PASSES = { "autobucketing_reordering": autobucketing_reordering_pass, "transformer_block_bucketing": transformer_block_bucketing_reordering_pass, "regional_inductor": regional_inductor_pass, "cudagraph": cudagraph_pass, + "full_inductor_compilation": full_inductor_compilation_pass, +} + +# Registry for joint custom passes (applied before partitioning) +AVAILABLE_JOINT_PASSES = { + "inductor_decomposition": inductor_decomposition_pass, + "fsdp_reshard_after_fwd": fsdp_reshard_after_fwd_pass, + "validate_flex_attn_annotation": validate_flex_attn_annotation_pass, } diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py index 9527d7dd23..8053efe3d4 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py +++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py @@ -122,7 +122,21 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "--model.name compiler_toolkit.llama3", "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", - "--model.flavor debugmodel_flex_attn", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.joint_passes inductor_decomposition", + "--compile.passes full_inductor_compilation", + ], + ], + "llama3 full_inductor_compilation", + "llama3_full_inductor_compilation", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", "--compile.passes transformer_block_bucketing,regional_inductor", ], From ec246c97c3f53365a5ab8f42f0b586c1f9edfd7b Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Wed, 14 Jan 2026 12:50:50 -0800 Subject: [PATCH 095/127] [autoparallel] Update local_map_deepseek_v3 device mesh usage (#2231) Stacked PRs: * __->__#2231 --- --- --- Follow the new device mesh convention from https://github.com/pytorch/torchtitan/pull/1660/ for the local_map_deepseek_v3 config (wasn't covered by CI) This lets the configs get further, but it still errors --- torchtitan/experiments/autoparallel/README.md | 2 +- .../parallelize_deepseekv3.py | 24 +++++++++++-------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/torchtitan/experiments/autoparallel/README.md b/torchtitan/experiments/autoparallel/README.md index 570237b4d9..54f3c95fb7 100644 --- a/torchtitan/experiments/autoparallel/README.md +++ b/torchtitan/experiments/autoparallel/README.md @@ -16,7 +16,7 @@ Requires installing [git@github.com:meta-pytorch/autoparallel.git](https://githu **DeepSeekv3** -`CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name autoparallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config` +`NGPU=2 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name autoparallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config` **DeepSeekv3 local_map** diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py index eb400484f6..5db38e841f 100644 --- a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py @@ -48,20 +48,24 @@ def parallelize_deepseekv3( job_config.experimental.comms_bucket_reorder_strategy ) - world_mesh = parallel_dims.world_mesh + # Build the sparse mesh for MoE expert parallelism + # Filter to only include enabled mesh dimensions + sparse_names = ["dp_replicate", "efsdp", "ep", "etp"] + sparse_names = [ + name + for name in sparse_names + if parallel_dims.get_optional_mesh(name) is not None + ] + sparse_mesh = parallel_dims.get_mesh(sparse_names) # Update me when changing dsv3.py - assert world_mesh.ndim == 2, "AP dsv3.py's local_map is specialized on 2 dims" - assert world_mesh.mesh_dim_names == ( - "dp_shard_mod_ep", - "dp_shard_in_ep", - ), "Current setup assumes these specific meshes" + assert sparse_mesh.ndim == 2, "AP dsv3.py's local_map is specialized on 2 dims" # Provide AP MoE with mesh for layer in model.layers.values(): if layer.moe_enabled: - layer.moe.mesh = world_mesh - layer.moe.axis_name = "dp_shard_in_ep" + layer.moe.mesh = sparse_mesh + layer.moe.axis_name = "ep" def input_fn(): global_batch_size = job_config.training.global_batch_size @@ -89,7 +93,7 @@ def input_fn(): with AutoParallel( model, input_fn, - world_mesh, + sparse_mesh, mp_policy=mp_policy, compile=should_compile, dynamic=True, @@ -122,7 +126,7 @@ def input_fn(): # it would require putting the loss inside the model as well def _return_as_dtensor_for_loss_parallel(module, args, output): return torch.distributed.tensor.DTensor.from_local( - output, world_mesh["tp"], (Shard(2),) + output, sparse_mesh["etp"], (Shard(2),) ) # not keeping a reference to the hook, don't plan on From c26ea602b6be3159ca9ab06cbbde18e3703b4c33 Mon Sep 17 00:00:00 2001 From: Jeffrey Wan Date: Wed, 14 Jan 2026 18:32:28 -0500 Subject: [PATCH 096/127] Disable dynamo LRU cache when AC is enabled (#2204) Previously this config is flipped whenever compile is enabled. This PR updates it so that the config is flipped whenever AC is applied, so that the only flex compiled case is also covered. Code comment: ``` # Disable dynamo LRU cache to workaround an interaction between SAC, PP, and Flex: # # When forward runs with a second PP microbatch, it triggers recompilation with dynamic # shapes enabled. Now there are two valid compiled graphs. By default, dynamo selects # the latest one (the dynamic shapes version), so the runtime wrapper expects an extra # symint output. When SAC caches the inductor HOP output from the static graph for # batch_idx=0, it would miss that symint and cause an assertion failure. The workaround # here is to disable the LRU cache, and select graphs in insertion order instead. # # Also see: https://github.com/pytorch/pytorch/issues/166926 ``` --- torchtitan/distributed/activation_checkpoint.py | 13 +++++++++++++ torchtitan/models/llama4/infra/parallelize.py | 3 --- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index c0b550a5c1..9107b8bc73 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -210,6 +210,19 @@ def apply_ac( Returns: None """ + # Disable dynamo LRU cache to workaround an interaction between SAC, PP, and Flex: + # + # When forward runs with a second PP microbatch, it triggers recompilation with dynamic + # shapes enabled. Now there are two valid compiled graphs. By default, dynamo selects + # the latest one (the dynamic shapes version), so the runtime wrapper expects an extra + # symint output. When SAC caches the inductor HOP output from the static graph for + # batch_idx=0, it would miss that symint and cause an assertion failure. The workaround + # here is to disable the LRU cache, and select graphs in insertion order instead. + # + # Also see: https://github.com/pytorch/pytorch/issues/166926 + # pyrefly: ignore [missing-attribute] + torch._C._dynamo.eval_frame._set_lru_cache(False) + if ac_config.mode == "memory_budget": assert model_compile_enabled, "Memory budget mode requires model to be compiled" if ac_config.visualize_memory_budget_pareto: diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 154c48d992..ca878eab8e 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -592,9 +592,6 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE # but it is experimental. torch._dynamo.config.capture_scalar_outputs = True - # Workaround for https://github.com/pytorch/pytorch/issues/166926 - # pyrefly: ignore [missing-attribute] - torch._C._dynamo.eval_frame._set_lru_cache(False) # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.named_children(): if transformer_block.moe_enabled: From 6408426b63e0f20a4d625c4ff78c0b35e9faa934 Mon Sep 17 00:00:00 2001 From: Frost Mitchell Date: Thu, 15 Jan 2026 05:09:05 -0500 Subject: [PATCH 097/127] Enable memory snapshot for generic devices (#2228) Pytorch recently added `_record_memory_history` for Intel XPU. This PR makes the memory snapshot device-generic. @pkourdis @githubsgi --- scripts/generate/test_generate.py | 1 - torchtitan/components/metrics.py | 7 ------- torchtitan/distributed/dual_pipe_v.py | 3 +-- torchtitan/distributed/utils.py | 2 -- torchtitan/tools/profiling.py | 5 +++-- torchtitan/tools/utils.py | 3 ++- torchtitan/train.py | 1 - 7 files changed, 6 insertions(+), 16 deletions(-) diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index ef310b5996..e91b4c34a0 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -98,7 +98,6 @@ def test_generate( world_size = int(os.environ.get("WORLD_SIZE", 1)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) device = torch.device(f"{device_type}:{local_rank}") - # pyrefly: ignore [missing-attribute] device_module.set_device(device) device_memory_monitor = build_device_memory_monitor() diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 93c1e3b10f..82332d54ee 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -42,19 +42,14 @@ class DeviceMemoryMonitor: def __init__(self, device: str = f"{device_type}:0"): # pyrefly: ignore [read-only] self.device = torch.device(device) # device object - # pyrefly: ignore [missing-attribute] self.device_name = device_module.get_device_name(self.device) - # pyrefly: ignore [missing-attribute] self.device_index = device_module.current_device() - # pyrefly: ignore [missing-attribute] self.device_capacity = device_module.get_device_properties( self.device ).total_memory self.device_capacity_gib = self._to_gib(self.device_capacity) - # pyrefly: ignore [missing-attribute] device_module.reset_peak_memory_stats() - # pyrefly: ignore [missing-attribute] device_module.empty_cache() def _to_gib(self, memory_in_bytes): @@ -67,7 +62,6 @@ def _to_pct(self, memory): return 100 * memory / self.device_capacity def get_peak_stats(self): - # pyrefly: ignore [missing-attribute] device_info = device_module.memory_stats(self.device) max_active = device_info.get("active_bytes.all.peak", -1) @@ -98,7 +92,6 @@ def get_peak_stats(self): ) def reset_peak_stats(self): - # pyrefly: ignore [missing-attribute] device_module.reset_peak_memory_stats() diff --git a/torchtitan/distributed/dual_pipe_v.py b/torchtitan/distributed/dual_pipe_v.py index 5def0e40e6..ab168c377b 100644 --- a/torchtitan/distributed/dual_pipe_v.py +++ b/torchtitan/distributed/dual_pipe_v.py @@ -260,14 +260,13 @@ def overlap_callback(action: _Action, ctx: _PipelineContext): ) # PP computation ======================================================== _hook_coordinator.enable_coordination(num_layers=min_num_layers) - main_stream = torch.accelerator.current_stream(device_module) + main_stream = torch.accelerator.current_stream(device_type) # Shared container for exception from backward thread def run_backward(): # pyrefly: ignore [missing-attribute] schedule._assert_unsharded(backward_stage) # Set the backward thread to use the same stream as forward - # pyrefly: ignore [missing-attribute] device_module.set_stream(main_stream) with record_function( f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}" diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 7790ab6683..ca25134ec6 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -382,9 +382,7 @@ def set_pg_timeouts( # otherwise, some ranks may issue collectives with the new/shorter timeout and # those may time out, before other ranks have finished with initialization done # under the old/slow timeout. - # pyrefly: ignore [missing-attribute] torch.distributed.barrier(device_ids=[device_module.current_device()]) - # pyrefly: ignore [missing-attribute] device_module.synchronize() # None represents the 'default' PG, not part of the mesh diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 5c2b40b217..da91afc3c4 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -13,6 +13,7 @@ from torchtitan.config import Profiling as ProfilingConfig from torchtitan.tools.logging import logger +from torchtitan.tools.utils import device_module # how much memory allocation/free ops to record in memory snapshots MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 @@ -104,7 +105,7 @@ def maybe_enable_memory_snapshot( class MemoryProfiler: def __init__(self, step_num: int, freq: int): - torch.cuda.memory._record_memory_history( + device_module.memory._record_memory_history( max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES ) # when resume training, we start from the last step @@ -131,7 +132,7 @@ def step(self, exit_ctx: bool = False): curr_snapshot_dir, f"rank{rank}_memory_snapshot.pickle" ) with open(output_file, "wb") as output: - pickle.dump(torch.cuda.memory._snapshot(), output) + pickle.dump(device_module.memory._snapshot(), output) logger.info( f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds" ) diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index d2fa409223..b7bf3e44de 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -9,6 +9,7 @@ import subprocess import time from dataclasses import dataclass +from types import ModuleType from typing import Generator, Optional import torch @@ -24,7 +25,7 @@ def has_cuda_capability(major: int, minor: int) -> bool: ) -def get_device_info() -> tuple[str, torch.device]: +def get_device_info() -> tuple[str, ModuleType]: device_type = _get_available_device_type() or "cuda" device_module = _get_device_module(device_type) # default device_module:torch.cuda return device_type, device_module diff --git a/torchtitan/train.py b/torchtitan/train.py index 8455e54eb9..a4f966f5e0 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -85,7 +85,6 @@ def __init__(self, job_config: JobConfig): # pyrefly: ignore [read-only] self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") # Device has to be set before creating TorchFT manager. - # pyrefly: ignore [missing-attribute] device_module.set_device(self.device) # init distributed and build meshes From 92401723a616463268a205f7cb96bfdae559b504 Mon Sep 17 00:00:00 2001 From: Shuhua Yu <18108279+shuhuayu@users.noreply.github.com> Date: Thu, 15 Jan 2026 10:27:58 -0800 Subject: [PATCH 098/127] Add test for dsv3 with flexattn + fsdp + ep + pp + sac op (#2234) The titled composability issue was fixed by commit https://github.com/pytorch/torchtitan/commit/c26ea602b6be3159ca9ab06cbbde18e3703b4c33, and this pr adds a corresponding test as a guard. --- tests/integration_tests/models.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/integration_tests/models.py b/tests/integration_tests/models.py index 665c9ef44e..5ba1c18c59 100755 --- a/tests/integration_tests/models.py +++ b/tests/integration_tests/models.py @@ -79,6 +79,23 @@ def build_model_tests_list() -> list[OverrideDefinitions]: "deepseek_v3_pp+fsdp+tp+ep+etp", ngpu=8, ), + OverrideDefinitions( + [ + [ + "--model.name deepseek_v3", + "--model.flavor debugmodel_flex_attn", + "--parallelism.data_parallel_shard_degree 4", + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.pipeline_parallel_schedule Interleaved1F1B", + "--parallelism.expert_parallel_degree 4", + "--activation_checkpoint.mode 'selective'", + "--activation_checkpoint.selective_ac_option 'op'", + ], + ], + "DeepSeek V3 Flex+PP+FSDP+EP+SACOP", + "deepseek_v3_flex+pp+fsdp+ep+sacop", + ngpu=8, + ), # Integration Test Cases for Qwen3 dense and MoE model OverrideDefinitions( [ From 5ef90faf24f2da07a93c286aa3727177aaf94257 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 15 Jan 2026 15:19:15 -0800 Subject: [PATCH 099/127] [lint] ignore all existing pyrefly errors (#2240) [lint] ignore all existing pyrefly errors given that we tell people to mute them when developing locally --- scripts/checkpoint_conversion/convert_from_hf.py | 3 ++- scripts/checkpoint_conversion/convert_to_hf.py | 3 ++- .../numerical_tests_example.py | 3 +++ scripts/estimate/estimation.py | 1 + scripts/generate/test_generate.py | 5 +++-- torchtitan/components/checkpoint.py | 5 ++++- torchtitan/components/tokenizer.py | 7 ++++--- torchtitan/config/manager.py | 4 ++-- torchtitan/distributed/deepep/deepep.py | 7 ++----- torchtitan/distributed/dual_pipe_v.py | 12 +++++++----- torchtitan/distributed/expert_parallel.py | 5 +++-- torchtitan/distributed/utils.py | 16 +++++++++------- torchtitan/models/flux/inference/sampling.py | 1 + torchtitan/models/flux/model/autoencoder.py | 8 ++++---- torchtitan/models/flux/model/hf_embedder.py | 2 ++ torchtitan/models/flux/tokenizer.py | 2 ++ torchtitan/models/flux/train.py | 2 +- torchtitan/models/flux/validate.py | 4 ++-- torchtitan/models/gpt_oss/model/moe.py | 4 ++-- torchtitan/models/llama3/infra/parallelize.py | 4 +++- torchtitan/models/llama3/model/model.py | 6 +++--- torchtitan/models/llama4/infra/parallelize.py | 1 + torchtitan/models/llama4/model/model.py | 6 +++--- torchtitan/models/moe/moe.py | 3 ++- torchtitan/models/qwen3/model/model.py | 10 +++++----- torchtitan/protocols/model_converter.py | 4 +++- torchtitan/train.py | 2 ++ 27 files changed, 78 insertions(+), 52 deletions(-) diff --git a/scripts/checkpoint_conversion/convert_from_hf.py b/scripts/checkpoint_conversion/convert_from_hf.py index 77bfeddd59..1d475802b6 100644 --- a/scripts/checkpoint_conversion/convert_from_hf.py +++ b/scripts/checkpoint_conversion/convert_from_hf.py @@ -21,11 +21,12 @@ def convert_from_hf(input_dir, output_dir, model_name, model_flavor): model_args = train_spec.model_args[model_flavor] with torch.device("cpu"): + # pyrefly: ignore[bad-instantiation] model = train_spec.model_cls(model_args) # pyrefly: ignore [bad-argument-type] model = ModelWrapper(model) - # pyrefly: ignore [not-callable] + # pyrefly: ignore[bad-instantiation, not-callable] sd_adapter = train_spec.state_dict_adapter(model_args, None) assert ( sd_adapter is not None diff --git a/scripts/checkpoint_conversion/convert_to_hf.py b/scripts/checkpoint_conversion/convert_to_hf.py index e68a6d2acc..ed28ffc28a 100644 --- a/scripts/checkpoint_conversion/convert_to_hf.py +++ b/scripts/checkpoint_conversion/convert_to_hf.py @@ -29,11 +29,12 @@ def convert_to_hf( model_args = train_spec.model_args[model_flavor] with torch.device("cpu"): + # pyrefly: ignore[bad-instantiation] model = train_spec.model_cls(model_args) # pyrefly: ignore [bad-argument-type] model = ModelWrapper(model) - # pyrefly: ignore [not-callable] + # pyrefly: ignore[bad-instantiation, not-callable] sd_adapter = train_spec.state_dict_adapter(model_args, hf_assets_path) assert ( sd_adapter is not None diff --git a/scripts/checkpoint_conversion/numerical_tests_example.py b/scripts/checkpoint_conversion/numerical_tests_example.py index f52851ef9b..f3290b7e2a 100644 --- a/scripts/checkpoint_conversion/numerical_tests_example.py +++ b/scripts/checkpoint_conversion/numerical_tests_example.py @@ -14,6 +14,8 @@ from torchtitan.config import ConfigManager from torchtitan.protocols.train_spec import get_train_spec from torchtitan.tools.logging import logger + +# pyrefly: ignore[import-error] from transformers import AutoModelForCausalLM device_type = "cuda" if torch.cuda.is_available() else "cpu" @@ -71,6 +73,7 @@ def forward_tt(config_path, checkpoint_path, test_set): model_args = train_spec.model_args[config.model.flavor] model_args.update_from_config(config) + # pyrefly: ignore[bad-instantiation] model = train_spec.model_cls(model_args) # materalize model diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index bfa9dddfd2..8f390c3de0 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -94,6 +94,7 @@ def estimate_memory(job_config: JobConfig): f"Building {job_config.model.name} {job_config.model.flavor} with {model_args}" ) with torch.device("meta"): + # pyrefly: ignore[bad-instantiation] model = train_spec.model_cls(model_args) # Build the collection of model converters. No-op if `model.converters` empty diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index e91b4c34a0..085ef18e3a 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -36,7 +36,7 @@ wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -# pyrefly: ignore [missing-import] +# pyrefly: ignore[import-error] from generate._generation import generate @@ -115,6 +115,7 @@ def test_generate( init_device = "meta" if world_size > 1 else device with torch.device(init_device): logger.info(f"Init model on init_device: {init_device}") + # pyrefly: ignore[bad-instantiation] model = train_spec.model_cls(model_args) parallel_dims = None @@ -233,7 +234,7 @@ def test_generate( "input_text": input_text, "output_text": output_text, } - output_data["responses"].append(_data) + output_data["responses"].append(_data) # pyrefly: ignore[missing-attribute] logger.info(f"{r}\n{input_text}{b}{output_text}\n{color.reset}") diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 7a7149a061..c12e2269ef 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -521,6 +521,7 @@ def save(self, curr_step: int, last_step: bool = False) -> None: if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: GarbageCollection.collect("GC collection invoked by checkpointer.") if self.stager is None: + # pyrefly: ignore[bad-assignment] self.stager = DefaultStager(StagingOptions(True, True, True, True)) result = self.dcp_save( states, @@ -534,6 +535,7 @@ def save(self, curr_step: int, last_step: bool = False) -> None: self.staging = True elif self.async_mode == AsyncMode.ASYNC: GarbageCollection.collect("GC collection invoked by checkpointer.") + # pyrefly: ignore[bad-assignment] self.save_future = self.dcp_save( states, checkpoint_id=checkpoint_id, async_mode=self.async_mode ) @@ -711,6 +713,7 @@ def _ft_save(self, step: int) -> None: begin = time.monotonic() self._async_wait() checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) + # pyrefly: ignore[bad-assignment] self.save_future = self.dcp_save( self.ft_states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.ASYNC ) @@ -836,7 +839,7 @@ def _async_wait(self) -> None: ): if self.save_future is not None: self.save_future.result() - self.save_future = None + self.save_future = None # pyrefly: ignore[bad-assignment] elif self.save_future is not None: raise RuntimeError( "self.save_future is not None, but self.async_mode is not enabled " diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index aca2300abe..0931fae09f 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -221,14 +221,15 @@ def _process_special_token( # Store BOS/EOS tokens as class attributes if they match if token_str == config_bos_token: - self.bos_token = token_str + self.bos_token = token_str # pyrefly: ignore[bad-assignment] self.bos_id = ( + # pyrefly: ignore[bad-assignment] token_id if token_id is not None else self.tokenizer.token_to_id(token_str) ) elif token_str == config_eos_token: - self.eos_token = token_str + self.eos_token = token_str # pyrefly: ignore[bad-assignment] self.eos_id = ( token_id if token_id is not None @@ -320,7 +321,7 @@ def _infer_should_add_bos_eos(self): # First, determine if underlying tokenizer auto-adds BOS/EOS tokens empirically encoded_empty_str = self.tokenizer.encode("").ids if self.bos_id is not None and self.bos_id in encoded_empty_str: - self.hf_adds_bos = True + self.hf_adds_bos = True # pyrefly: ignore[bad-assignment] if self.eos_id is not None and self.eos_id in encoded_empty_str: self.hf_adds_eos = True diff --git a/torchtitan/config/manager.py b/torchtitan/config/manager.py index 79d95c350e..694d6fe8d0 100644 --- a/torchtitan/config/manager.py +++ b/torchtitan/config/manager.py @@ -16,7 +16,7 @@ try: import tomllib except ModuleNotFoundError: - # pyrefly: ignore [missing-import] + # pyrefly: ignore[import-error] import tomli as tomllib from torchtitan.tools.logging import logger @@ -179,7 +179,7 @@ def _dict_to_dataclass(self, cls, data: dict[str, Any]) -> Any: result[f.name] = self._dict_to_dataclass(f.type, value) else: result[f.name] = value - return cls(**result) + return cls(**result) # pyrefly: ignore[not-callable, bad-instantiation] def _validate_config(self) -> None: if self.config.experimental.custom_args_module: diff --git a/torchtitan/distributed/deepep/deepep.py b/torchtitan/distributed/deepep/deepep.py index 9389fac5c5..99f7469930 100644 --- a/torchtitan/distributed/deepep/deepep.py +++ b/torchtitan/distributed/deepep/deepep.py @@ -18,11 +18,8 @@ from torch.distributed import ProcessGroup try: - from deep_ep import Buffer # pyrefly: ignore [missing-import] - from deep_ep.utils import ( # pyrefly: ignore [missing-import] - EventHandle, - EventOverlap, - ) + from deep_ep import Buffer # pyrefly: ignore[import-error] + from deep_ep.utils import EventHandle, EventOverlap # pyrefly: ignore[import-error] except ImportError as e: raise ImportError( "DeepEP is required for this module. " diff --git a/torchtitan/distributed/dual_pipe_v.py b/torchtitan/distributed/dual_pipe_v.py index ab168c377b..81d471a649 100644 --- a/torchtitan/distributed/dual_pipe_v.py +++ b/torchtitan/distributed/dual_pipe_v.py @@ -125,7 +125,7 @@ def enable_coordination(self, num_layers: Optional[int] = None): # Reset barrier self._execution_barrier = threading.Barrier(2) - self._num_layers = num_layers + self._num_layers = num_layers # pyrefly: ignore[bad-assignment] def disable_coordination(self): self._coordination_enabled = False @@ -300,10 +300,12 @@ def run_forward(): output = forward_stage.forward_one_chunk( # pyrefly: ignore [bad-argument-type] forward_mb_index, - # pyrefly: ignore [bad-index, unsupported-operation] - arg_mbs[forward_mb_index], - # pyrefly: ignore [bad-index, unsupported-operation] - kwarg_mbs[forward_mb_index], + arg_mbs[ + forward_mb_index + ], # pyrefly: ignore[index-error, unsupported-operation] + kwarg_mbs[ + forward_mb_index + ], # pyrefly: ignore[index-error, unsupported-operation] ) schedule._maybe_compute_loss( forward_stage, output, ctx.target_mbs, forward_mb_index diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index 8ee53e754e..4395a4a3ea 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -155,9 +155,9 @@ def _token_dispatch( # of GroupedExperts, as it does not need padding. ( - self.input_shape, + self.input_shape, # pyrefly: ignore[bad-assignment] routed_input, - self.permuted_indices, + self.permuted_indices, # pyrefly: ignore[bad-assignment] num_tokens_per_expert_group, ) = _permute( routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts @@ -343,6 +343,7 @@ def _token_dispatch(self, mod, inputs, device_mesh): num_local_experts = mod.w1.shape[0] ep_group = device_mesh.get_group() + # pyrefly: ignore[bad-assignment] hidden_states, tokens_per_expert, self._state = dispatch_tokens( hidden_states, selected_experts_indices, diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index ca25134ec6..ded4181006 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -467,9 +467,9 @@ def clip_grad_norm_( if math.isinf(norm_type): dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) else: - total_norm **= norm_type + total_norm **= norm_type # pyrefly: ignore[unsupported-operation] dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) - total_norm **= 1.0 / norm_type + total_norm **= 1.0 / norm_type # pyrefly: ignore[unsupported-operation] torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) return total_norm @@ -493,7 +493,7 @@ def _clip_grad_norm_with_ep( if p.grad is None: continue assert isinstance(p, DTensor) and isinstance(p.grad, DTensor) - # pyrefly: ignore [not-iterable] + # pyrefly: ignore[unsupported-operation] if "ep" in p.device_mesh.mesh_dim_names: ep_params.append(p) ep_grads.append(p.grad) @@ -517,17 +517,19 @@ def _clip_grad_norm_with_ep( total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm) else: total_norm = ( - ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type + # pyrefly: ignore[unsupported-operation] + ep_grads_total_norm**norm_type + + non_ep_grads_total_norm**norm_type ) - total_norm **= 1.0 / norm_type + total_norm **= 1.0 / norm_type # pyrefly: ignore[unsupported-operation] if pp_mesh is not None: if math.isinf(norm_type): dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) else: - total_norm **= norm_type + total_norm **= norm_type # pyrefly: ignore[unsupported-operation] dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) - total_norm **= 1.0 / norm_type + total_norm **= 1.0 / norm_type # pyrefly: ignore[unsupported-operation] torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, foreach) torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach) diff --git a/torchtitan/models/flux/inference/sampling.py b/torchtitan/models/flux/inference/sampling.py index 5ee48ab60f..4c8c4ce993 100644 --- a/torchtitan/models/flux/inference/sampling.py +++ b/torchtitan/models/flux/inference/sampling.py @@ -36,6 +36,7 @@ def time_shift(mu: float, sigma: float, t: Tensor): + # pyrefly: ignore[unsupported-operation] return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) diff --git a/torchtitan/models/flux/model/autoencoder.py b/torchtitan/models/flux/model/autoencoder.py index 9ca46dff96..959c5a179c 100644 --- a/torchtitan/models/flux/model/autoencoder.py +++ b/torchtitan/models/flux/model/autoencoder.py @@ -191,11 +191,11 @@ def forward(self, x: Tensor) -> Tensor: hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): - # pyrefly: ignore [bad-index, not-callable] + # pyrefly: ignore[index-error, not-callable] h = self.down[i_level].block[i_block](hs[-1]) # pyrefly: ignore [bad-argument-type] if len(self.down[i_level].attn) > 0: - # pyrefly: ignore [bad-index, not-callable] + # pyrefly: ignore[index-error, not-callable] h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: @@ -295,11 +295,11 @@ def forward(self, z: Tensor) -> Tensor: # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): - # pyrefly: ignore [bad-index, not-callable] + # pyrefly: ignore[index-error, not-callable] h = self.up[i_level].block[i_block](h) # pyrefly: ignore [bad-argument-type] if len(self.up[i_level].attn) > 0: - # pyrefly: ignore [bad-index, not-callable] + # pyrefly: ignore[index-error, not-callable] h = self.up[i_level].attn[i_block](h) if i_level != 0: # pyrefly: ignore [not-callable] diff --git a/torchtitan/models/flux/model/hf_embedder.py b/torchtitan/models/flux/model/hf_embedder.py index 89bed4d248..56d82b1d2f 100644 --- a/torchtitan/models/flux/model/hf_embedder.py +++ b/torchtitan/models/flux/model/hf_embedder.py @@ -7,6 +7,8 @@ import os from torch import nn, Tensor + +# pyrefly: ignore[import-error] from transformers import CLIPTextModel, T5EncoderModel diff --git a/torchtitan/models/flux/tokenizer.py b/torchtitan/models/flux/tokenizer.py index 06fbde2bbb..5d7544c4d1 100644 --- a/torchtitan/models/flux/tokenizer.py +++ b/torchtitan/models/flux/tokenizer.py @@ -11,6 +11,8 @@ from typing import List import torch + +# pyrefly: ignore[import-error] from transformers import CLIPTokenizer, T5Tokenizer from torchtitan.components.tokenizer import BaseTokenizer, HuggingFaceTokenizer diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py index b6fea6dbe7..0519d26703 100644 --- a/torchtitan/models/flux/train.py +++ b/torchtitan/models/flux/train.py @@ -172,7 +172,7 @@ def forward_backward_step( loss = self.loss_fn(latent_noise_pred, target) # latent_noise_pred.shape=(bs, seq_len, vocab_size) # need to free to before bwd to avoid peaking memory - # pyrefly: ignore [unsupported-delete] + # pyrefly: ignore[delete-error] del (latent_noise_pred, noise, target) loss.backward() diff --git a/torchtitan/models/flux/validate.py b/torchtitan/models/flux/validate.py index 3bfa204e74..9deb12a195 100644 --- a/torchtitan/models/flux/validate.py +++ b/torchtitan/models/flux/validate.py @@ -146,7 +146,7 @@ def validate( job_config=self.job_config, # pyrefly: ignore [bad-argument-type] model=model, - prompt=p, + prompt=p, # pyrefly: ignore[bad-argument-type] autoencoder=self.autoencoder, t5_tokenizer=self.t5_tokenizer, clip_tokenizer=self.clip_tokenizer, @@ -163,7 +163,7 @@ def validate( ), x=image, add_sampling_metadata=True, - prompt=p, + prompt=p, # pyrefly: ignore[bad-argument-type] ) save_img_count -= 1 diff --git a/torchtitan/models/gpt_oss/model/moe.py b/torchtitan/models/gpt_oss/model/moe.py index 2a60751ec3..6f8cd03c5c 100644 --- a/torchtitan/models/gpt_oss/model/moe.py +++ b/torchtitan/models/gpt_oss/model/moe.py @@ -223,7 +223,7 @@ def forward( tp_degree = 1 if isinstance(self.mlp1_weight, DTensor): mesh_dim_names = self.mlp1_weight.device_mesh.mesh_dim_names - # pyrefly: ignore [not-iterable] + # pyrefly: ignore[unsupported-operation] if "tp" in mesh_dim_names: # pyrefly: ignore [missing-attribute] tp_dim_idx = mesh_dim_names.index("tp") @@ -232,7 +232,7 @@ def forward( if self.use_grouped_mm: if ( not isinstance(self.mlp1_weight, DTensor) - # pyrefly: ignore [not-iterable] + # pyrefly: ignore[unsupported-operation] or "ep" not in self.mlp1_weight.device_mesh.mesh_dim_names ): run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm) diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index a77a87a921..b6c339fa67 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -286,7 +286,7 @@ def apply_fsdp( mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: - # pyrefly: ignore [bad-typed-dict-key] + # pyrefly: ignore[unsupported-operation] fsdp_config["offload_policy"] = CPUOffloadPolicy() match reshard_after_forward_policy: @@ -312,6 +312,7 @@ def apply_fsdp( ) # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.items(): + # pyrefly: ignore[no-matching-overload] fully_shard( transformer_block, **fsdp_config, @@ -326,6 +327,7 @@ def apply_fsdp( **fsdp_config, reshard_after_forward=reshard_after_forward_policy == "always", ) + # pyrefly: ignore[no-matching-overload] fully_shard(model, **fsdp_config) diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 40767c06d2..0c201188a4 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -574,15 +574,15 @@ def forward( """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages - # pyrefly: ignore [not-callable] + # pyrefly: ignore[not-callable, invalid-argument] h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): h = layer( h, self.freqs_cis, attention_masks=attention_masks, positions=positions ) - # pyrefly: ignore [not-callable] + # pyrefly: ignore[not-callable, invalid-argument] h = self.norm(h) if self.norm else h - # pyrefly: ignore [not-callable] + # pyrefly: ignore[not-callable, invalid-argument] output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index ca878eab8e..c62635e2eb 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -594,6 +594,7 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b torch._dynamo.config.capture_scalar_outputs = True # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.named_children(): + # pyrefly: ignore[missing-attribute] if transformer_block.moe_enabled: # If it is a MoE layer, FSDP(GroupedExperts) will cause a graph break # So we must weave compile wrappers around those FSDP hooks to diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index 8ee54cae5b..ef17f6e087 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -596,14 +596,14 @@ def forward( """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages - # pyrefly: ignore [not-callable] + # pyrefly: ignore[not-callable, invalid-argument] h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): h = layer(h, self.freqs_cis, attention_masks, positions) - # pyrefly: ignore [not-callable] + # pyrefly: ignore[not-callable, invalid-argument] h = self.norm(h) if self.norm else h - # pyrefly: ignore [not-callable] + # pyrefly: ignore[not-callable, invalid-argument] output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 90e6418972..6a1c220047 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -166,7 +166,7 @@ def forward( # otherwise, EP will handle the padding. if ( not isinstance(self.w1, DTensor) - # pyrefly: ignore [not-iterable] + # pyrefly: ignore[unsupported-operation] or "ep" not in self.w1.device_mesh.mesh_dim_names ): run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm) @@ -563,6 +563,7 @@ def init_weights( self.experts.num_experts, dtype=torch.float32 ) if self.load_balance_coeff is not None: + # pyrefly: ignore[bad-assignment] self.expert_bias = torch.zeros( self.experts.num_experts, dtype=torch.float32 ) diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 389f0fc9f7..ff18bf9211 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -252,9 +252,9 @@ def forward( # Adding the q_norm and k_norm here # Last layer of adding q-k norm - if self.q_norm: + if self.q_norm: # pyrefly: ignore[invalid-argument] xq = self.q_norm(xq) - if self.k_norm: + if self.k_norm: # pyrefly: ignore[invalid-argument] xk = self.k_norm(xk) # Apply rotary embedding @@ -577,14 +577,14 @@ def forward( """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages - # pyrefly: ignore [not-callable] + # pyrefly: ignore[not-callable, invalid-argument] h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): h = layer(h, self.rope_cache, attention_masks, positions) - # pyrefly: ignore [not-callable] + # pyrefly: ignore[not-callable, invalid-argument] h = self.norm(h) if self.norm else h - # pyrefly: ignore [not-callable] + # pyrefly: ignore[not-callable, invalid-argument] output = self.output(h) if self.output else h return output diff --git a/torchtitan/protocols/model_converter.py b/torchtitan/protocols/model_converter.py index dbfc3a99c3..cb4804be6f 100644 --- a/torchtitan/protocols/model_converter.py +++ b/torchtitan/protocols/model_converter.py @@ -62,7 +62,9 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): _registry_model_converter_cls[name] for name in job_config.model.converters ] self.converters = [ - mh_cls(job_config, parallel_dims) for mh_cls in converter_classes + # pyrefly: ignore[bad-instantiation] + mh_cls(job_config, parallel_dims) + for mh_cls in converter_classes ] self.print_after_conversion = job_config.model.print_after_conversion diff --git a/torchtitan/train.py b/torchtitan/train.py index a4f966f5e0..3c255ccdaa 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -143,6 +143,7 @@ def __init__(self, job_config: JobConfig): torch.device("meta"), utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), ): + # pyrefly: ignore[bad-instantiation] model = self.train_spec.model_cls(model_args) # Build the collection of model converters. No-op if `model.converters` empty @@ -306,6 +307,7 @@ def __init__(self, job_config: JobConfig): states={"train_state": self}, checkpoint_config=job_config.checkpoint, sd_adapter=( + # pyrefly: ignore[bad-instantiation] self.train_spec.state_dict_adapter( model_args, job_config.model.hf_assets_path ) From 15569716e3736531ea49f0908bcf563c962584a5 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Thu, 15 Jan 2026 16:33:49 -0800 Subject: [PATCH 100/127] [Experimental][rl][vllm compat] Update simple_rl example to work with vLLM nightly (#2219) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See title; running with the latest version of vLLM nightly ran into a few issues as these APIs have changed and within torch titan some paths have changed ### Test plan ``` VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.rl.vllm_compat.simple_rl ``` ### Results ``` ✓ vLLM-TorchTitan bitwise determinism verified: 100 tokens match exactly Step 2 | Loss: -0.0033 | Reward: +0.213 | Samples: 160 Sample: Natalia sold 48 clips in April. In May, she sold half as many as she sold in Ap... Converted to 311 vLLM weights ``` --- .../experiments/rl/vllm_compat/README.md | 19 ++++++++++++++----- .../vllm_compat/batch_invariant_backward.py | 6 +++++- .../rl/vllm_compat/models/attention.py | 4 ++-- .../experiments/rl/vllm_compat/simple_rl.py | 9 ++++++--- 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/torchtitan/experiments/rl/vllm_compat/README.md b/torchtitan/experiments/rl/vllm_compat/README.md index bf56f4afbe..84df62d3ed 100644 --- a/torchtitan/experiments/rl/vllm_compat/README.md +++ b/torchtitan/experiments/rl/vllm_compat/README.md @@ -51,8 +51,12 @@ Note: Currently supports single-device training only. ### Prerequisites ```bash -# Install vLLM with deterministic support -pip install vllm +# Install vLLM with deterministic support (from source) +git clone https://github.com/vllm-project/vllm.git +cd vllm +python use_existing_torch.py +uv pip install -r requirements/build.txt +uv pip install --no-build-isolation -e . # Install TorchTitan (from the repository root) pip install -e . @@ -75,15 +79,17 @@ init_batch_invariance() ### Quick Start ```python -import torch from vllm.model_executor.layers.batch_invariant import init_batch_invariance +from vllm.v1.attention.backends.registry import AttentionBackendEnum + +import torch from torchtitan.experiments.rl.vllm_compat import ( enable_batch_invariant_backward_mode, Qwen3VLLMCompatModel, ) # 1. Enable deterministic mode -init_batch_invariance() +init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) enable_batch_invariant_backward_mode() # 2. Load model @@ -95,7 +101,7 @@ model_args = Qwen3ModelArgs( n_kv_heads=2, vocab_size=151936, ) -model = Qwen3VLLMCompatModel(model_args) +model = Qwen3VLLMCompatModel(model_args).to('cuda').to(torch.bfloat16) # 3. Forward pass (deterministic) input_ids = torch.randint(0, 151936, (2, 128), device='cuda') @@ -104,6 +110,9 @@ logits = model(input_ids) # 4. Backward pass loss = logits.sum() loss.backward() + +print("Done running simple model") + ``` ### Full RL Training diff --git a/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py b/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py index faccf8265d..b67244478e 100644 --- a/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py +++ b/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py @@ -62,10 +62,14 @@ def forward(ctx, x): Returns: output: silu(gate) * up, shape [..., hidden_dim] """ + from vllm.config import set_current_vllm_config, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul as VLLMSiluAndMul # Use vLLM's implementation for forward - vllm_silu_and_mul = VLLMSiluAndMul() + # vLLM custom ops require a config context to be set + # Since these are parameter free we instantiate default config + with set_current_vllm_config(VllmConfig()): + vllm_silu_and_mul = VLLMSiluAndMul() output = vllm_silu_and_mul(x) # Save for backward diff --git a/torchtitan/experiments/rl/vllm_compat/models/attention.py b/torchtitan/experiments/rl/vllm_compat/models/attention.py index 3bcbe3071a..752b416922 100644 --- a/torchtitan/experiments/rl/vllm_compat/models/attention.py +++ b/torchtitan/experiments/rl/vllm_compat/models/attention.py @@ -8,7 +8,7 @@ import math import torch -from vllm.attention.utils.fa_utils import flash_attn_varlen_func +from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func class VLLMCompatibleFlashAttention(torch.nn.Module): @@ -17,8 +17,8 @@ class VLLMCompatibleFlashAttention(torch.nn.Module): def __init__(self) -> None: super().__init__() self.flash_attn_varlen_func = flash_attn_varlen_func - from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant + from vllm.v1.attention.backends.fa_utils import get_flash_attn_version self.vllm_is_batch_invariant = vllm_is_batch_invariant self.fa_version = get_flash_attn_version() diff --git a/torchtitan/experiments/rl/vllm_compat/simple_rl.py b/torchtitan/experiments/rl/vllm_compat/simple_rl.py index 5e1fdd486b..bd9225bbe6 100644 --- a/torchtitan/experiments/rl/vllm_compat/simple_rl.py +++ b/torchtitan/experiments/rl/vllm_compat/simple_rl.py @@ -39,8 +39,10 @@ from vllm import LLM, SamplingParams from vllm.model_executor.layers.batch_invariant import init_batch_invariance +from vllm.v1.attention.backends.registry import AttentionBackendEnum -init_batch_invariance() + +init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) class VLLMRolloutEngine: @@ -170,6 +172,7 @@ def update_weights(self, vllm_compat_state: dict) -> None: gpu_memory_utilization=0.3, # Reduced from 0.5 seed=42, # Fixed seed for determinism enforce_eager=True, + attention_config={"backend": AttentionBackendEnum.FLASH_ATTN}, ) print("✓ Created new vLLM engine") else: @@ -340,7 +343,7 @@ def load_model(checkpoint_path: str, model_path: str, use_vllm_compat: bool = Tr if use_vllm_compat: # Create and load model (using vLLM-compat for bitwise determinism) - from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.models.qwen3 import ( + from torchtitan.experiments.rl.vllm_compat.models.qwen3 import ( Qwen3VLLMCompatModel, ) @@ -1051,7 +1054,7 @@ def main(): print("✓ Batch invariance detected - using vLLM-compatible model") # Add backward pass support to vLLM's batch_invariant mode print(" Adding gradient support to vLLM's batch_invariant mode...") - from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( + from torchtitan.experiments.rl.vllm_compat import ( enable_batch_invariant_backward_mode, ) From a085b0e217e3218328b8e47eb15684f302d4c181 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Sun, 18 Jan 2026 22:33:57 -0800 Subject: [PATCH 101/127] [Experimental][rl][unified] Update infer.py example to work with vLLM nightly (#2226) NOTE: This PR should be landed on top of https://github.com/pytorch/torchtitan/pull/2219 Few changes seem to be needed to get working; namely: 1) Update to use `spawn` method for CUDA fork 2) Modify attention import to be correct after refactor 3) Modify attention dimensions to be correct after refactor ### Test Plan Follow steps in the readme then run ``` python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B ``` ### Output ``` [2026-01-13 15:20:58] INFO infer.py:113: Generation complete Prompt: Hello, my name is Generated text: " Josh and I'm in the middle of a project to develop a hybrid mobile app. I'm looking for guidance on how to go about using modular frameworks. I want to use React and Vue. I need to decide on the framework to use. Can you help me choose the right framework and suggest some best practices for using them?\n\nAdditionally, I want to know what are the best practices for using a modular framework in the context of a web application? Also, what are the best practices for using a" ``` --- torchtitan/experiments/rl/unified/infer.py | 8 +++++++- .../rl/unified/models/attention.py | 19 +++++++++++++++---- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py index 43153fb70e..3e9470bf5d 100755 --- a/torchtitan/experiments/rl/unified/infer.py +++ b/torchtitan/experiments/rl/unified/infer.py @@ -5,11 +5,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os + +# Must set spawn method before any CUDA operations or vLLM imports +# CUDA cannot be re-initialized in forked subprocesses +# See also https://docs.vllm.ai/en/v0.8.3/design/multiprocessing.html#python-multiprocessing +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + import argparse # Import unified module - this automatically registers TorchTitan models with vLLM from torchtitan.experiments.rl import unified # noqa: F401 - from vllm import LLM, SamplingParams from vllm.logger import init_logger diff --git a/torchtitan/experiments/rl/unified/models/attention.py b/torchtitan/experiments/rl/unified/models/attention.py index 1a03b882cb..0492af2ffd 100644 --- a/torchtitan/experiments/rl/unified/models/attention.py +++ b/torchtitan/experiments/rl/unified/models/attention.py @@ -75,17 +75,28 @@ def forward( output: [batch, num_heads, seq_len, head_dim] """ # Input is (batch, num_heads, seq_len, head_dim) + # TODO: may be good to use einops in future as we can explicitly reshape + # with dimension names - see https://github.com/arogozhnikov/einops batch_size, num_heads, seq_len, head_dim = q.shape + _, num_kv_heads, _, _ = k.shape - # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM + # vLLM expects (num_tokens, num_heads, head_dim) where num_tokens = batch * seq_len + # First transpose to (batch, seq_len, num_heads, head_dim) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) - output_varlen = self.vllm_attn(q, k, v) + # TODO: reimplement as a 4d tensor once vLLM fix has landed + # Then flatten batch and seq_len: (batch * seq_len, num_heads, head_dim) + q = q.reshape(batch_size * seq_len, num_heads, head_dim) + k = k.reshape(batch_size * seq_len, num_kv_heads, head_dim) + v = v.reshape(batch_size * seq_len, num_kv_heads, head_dim) - # Reshape back to batch format - output = output_varlen.view(batch_size, seq_len, num_heads, head_dim) + # vLLM attention returns (num_tokens, hidden_size) where hidden_size = num_heads * head_dim + output_flat = self.vllm_attn(q, k, v) + + # Output is (batch * seq_len, num_heads * head_dim), reshape to (batch, seq_len, num_heads, head_dim) + output = output_flat.view(batch_size, seq_len, num_heads, head_dim) # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim) output = output.transpose(1, 2) From 09c6d74bf0535a283c5a2cf58dee5d5f701f39a0 Mon Sep 17 00:00:00 2001 From: francesco-bertolotti Date: Mon, 19 Jan 2026 08:08:03 +0100 Subject: [PATCH 102/127] fix sdpa-varlen attention mismatch in qwen3 (#2229) Fix for https://github.com/pytorch/torchtitan/issues/2223 --- torchtitan/models/llama3/model/model.py | 9 ++++++--- torchtitan/models/qwen3/model/model.py | 10 ++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 0c201188a4..b8fef9bdb7 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -281,6 +281,9 @@ def forward( case "flex": assert isinstance(attention_masks, BlockMask), attention_masks output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) case "varlen": assert isinstance(attention_masks, VarlenMetadata), attention_masks output = self.inner_attention( @@ -293,12 +296,12 @@ def forward( case "sdpa": assert attention_masks is None output = self.inner_attention(xq, xk, xv) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) case _: raise ValueError(f"Unknown attention type: {self.attn_type}") - output = output.transpose( - 1, 2 - ).contiguous() # (bs, seqlen, n_local_heads, head_dim) output = output.view(bs, seqlen, -1) return self.wo(output) diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index ff18bf9211..a6a6e0407c 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -274,6 +274,9 @@ def forward( output = self.inner_attention( xq, xk, xv, block_mask=attention_masks, scale=self.scaling ) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) case "varlen": # TODO: pass self.scaling into varlen attention assert isinstance(attention_masks, VarlenMetadata), attention_masks @@ -288,13 +291,12 @@ def forward( case "sdpa": assert attention_masks is None output = self.inner_attention(xq, xk, xv, scale=self.scaling) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) case _: raise ValueError(f"Unknown attention type: {self.attn_type}") - output = output.transpose( - 1, 2 - ).contiguous() # (bs, seqlen, n_local_heads, head_dim) - output = output.view(bs, seqlen, -1) return self.wo(output) From 2a642d022de2eb0e63be4ddd4d95619c3da0b518 Mon Sep 17 00:00:00 2001 From: dmahan93 <44207705+dmahan93@users.noreply.github.com> Date: Mon, 19 Jan 2026 18:39:32 -0600 Subject: [PATCH 103/127] Update README with libnvshmem_host.so troubleshooting Added troubleshooting tip for missing libnvshmem_host.so. --- torchtitan/distributed/deepep/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchtitan/distributed/deepep/README.md b/torchtitan/distributed/deepep/README.md index 94012de27e..3f05d5161a 100644 --- a/torchtitan/distributed/deepep/README.md +++ b/torchtitan/distributed/deepep/README.md @@ -175,6 +175,9 @@ uv pip install git+https://github.com/deepseek-ai/DeepEP.git --no-build-isolatio > > See [GitHub Issue #224](https://github.com/deepseek-ai/DeepEP/issues/224#issuecomment-2985783610) +> If you see /usr/bin/ld: cannot find -l:libnvshmem_host.so: No such file or directory +> try ln -s /path/to/libnvshmem_host.so.3 /path/to/libnvshmem_host.so + ### Step 3: Verify Installation ```bash From a25dd8f880e0759c1496d07ff0ca6973c6dc8b65 Mon Sep 17 00:00:00 2001 From: RuibinCheung Date: Tue, 20 Jan 2026 12:30:27 +0800 Subject: [PATCH 104/127] [ROCm] Support mxfp8 on gfx950. (#2222) * Support mxfp8 on gfx950. It depends on TorchAO (https://github.com/pytorch/ao/pull/3620). --- torchtitan/components/quantization/mx.py | 8 ++++---- torchtitan/tools/utils.py | 8 ++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index f1c0e09574..fbd69dbf63 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -19,7 +19,7 @@ from torchtitan.models.moe.utils import set_token_group_alignment_size_m from torchtitan.protocols.model_converter import register_model_converter from torchtitan.tools.logging import logger -from torchtitan.tools.utils import has_cuda_capability +from torchtitan.tools.utils import has_cuda_capability, has_rocm_capability from .utils import module_filter_fn @@ -39,9 +39,9 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): ) # Can be removed if we enable the emulated versions - assert has_cuda_capability( - 10, 0 - ), "MXFP8 is only supported on SM100 or architectures" + assert has_cuda_capability(10, 0) or has_rocm_capability( + 9, 5 + ), "MXFP8 is only supported on CUDA SM100 or later, or ROCm gfx950 or later" # TP not yet supported with torch.compile model_compile_enabled = ( diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index b7bf3e44de..6af30cd130 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -25,6 +25,14 @@ def has_cuda_capability(major: int, minor: int) -> bool: ) +def has_rocm_capability(major: int, minor: int) -> bool: + is_rocm = torch.cuda.is_available() and torch.version.hip is not None + return is_rocm and torch.cuda.get_device_capability() >= ( + major, + minor, + ) + + def get_device_info() -> tuple[str, ModuleType]: device_type = _get_available_device_type() or "cuda" device_module = _get_device_module(device_type) # default device_module:torch.cuda From 7fde8b6a27c4f04713ed6a72c962fb84cdafa1fb Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 20 Jan 2026 12:44:06 -0800 Subject: [PATCH 105/127] [Typing] Fix CI Typing Issues (#2245) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom): * #2260 * #2246 * __->__ #2245 As title --- scripts/generate/test_generate.py | 2 +- torchtitan/config/manager.py | 21 ++++++++++++------- torchtitan/distributed/deepep/deepep.py | 7 +++++-- torchtitan/distributed/dual_pipe_v.py | 10 ++++----- torchtitan/distributed/utils.py | 2 +- torchtitan/models/flux/model/autoencoder.py | 8 +++---- torchtitan/models/flux/tokenizer.py | 2 +- torchtitan/models/flux/train.py | 2 +- torchtitan/models/gpt_oss/model/moe.py | 4 ++-- torchtitan/models/llama3/infra/parallelize.py | 2 +- torchtitan/models/moe/moe.py | 2 +- 11 files changed, 35 insertions(+), 27 deletions(-) diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index 085ef18e3a..c6c8612afa 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -36,7 +36,7 @@ wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -# pyrefly: ignore[import-error] +# pyrefly: ignore[missing-import] from generate._generation import generate diff --git a/torchtitan/config/manager.py b/torchtitan/config/manager.py index 694d6fe8d0..b0e7c6561e 100644 --- a/torchtitan/config/manager.py +++ b/torchtitan/config/manager.py @@ -254,13 +254,20 @@ def list_str_rule(type_info: tyro.constructors.PrimitiveTypeInfo): # # ----------------------------------------------------------------------------- - # pyrefly: ignore [missing-import] - from rich import print as rprint + try: - # pyrefly: ignore [missing-import] - from rich.pretty import Pretty + # pyrefly: ignore[missing-import] + from rich import print as rprint - config_manager = ConfigManager() - config = config_manager.parse_args() + # pyrefly: ignore[missing-import] + from rich.pretty import Pretty - rprint(Pretty(config)) + config_manager = ConfigManager() + config = config_manager.parse_args() + + rprint(Pretty(config)) + except ImportError: + config_manager = ConfigManager() + config = config_manager.parse_args() + logger.info(config) + logger.warning("rich is not installed, show the raw config") diff --git a/torchtitan/distributed/deepep/deepep.py b/torchtitan/distributed/deepep/deepep.py index 99f7469930..ce44fc232e 100644 --- a/torchtitan/distributed/deepep/deepep.py +++ b/torchtitan/distributed/deepep/deepep.py @@ -18,8 +18,11 @@ from torch.distributed import ProcessGroup try: - from deep_ep import Buffer # pyrefly: ignore[import-error] - from deep_ep.utils import EventHandle, EventOverlap # pyrefly: ignore[import-error] + from deep_ep import Buffer # pyrefly: ignore[missing-import] + from deep_ep.utils import ( # pyrefly: ignore[missing-import] + EventHandle, + EventOverlap, + ) except ImportError as e: raise ImportError( "DeepEP is required for this module. " diff --git a/torchtitan/distributed/dual_pipe_v.py b/torchtitan/distributed/dual_pipe_v.py index 81d471a649..9f13b7f958 100644 --- a/torchtitan/distributed/dual_pipe_v.py +++ b/torchtitan/distributed/dual_pipe_v.py @@ -300,12 +300,10 @@ def run_forward(): output = forward_stage.forward_one_chunk( # pyrefly: ignore [bad-argument-type] forward_mb_index, - arg_mbs[ - forward_mb_index - ], # pyrefly: ignore[index-error, unsupported-operation] - kwarg_mbs[ - forward_mb_index - ], # pyrefly: ignore[index-error, unsupported-operation] + # pyrefly: ignore[bad-index, unsupported-operation] + arg_mbs[forward_mb_index], + # pyrefly: ignore[bad-index, unsupported-operation] + kwarg_mbs[forward_mb_index], ) schedule._maybe_compute_loss( forward_stage, output, ctx.target_mbs, forward_mb_index diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index ded4181006..cd2d797c3a 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -493,7 +493,7 @@ def _clip_grad_norm_with_ep( if p.grad is None: continue assert isinstance(p, DTensor) and isinstance(p.grad, DTensor) - # pyrefly: ignore[unsupported-operation] + # pyrefly: ignore[not-iterable] if "ep" in p.device_mesh.mesh_dim_names: ep_params.append(p) ep_grads.append(p.grad) diff --git a/torchtitan/models/flux/model/autoencoder.py b/torchtitan/models/flux/model/autoencoder.py index 959c5a179c..a50e4a5ba3 100644 --- a/torchtitan/models/flux/model/autoencoder.py +++ b/torchtitan/models/flux/model/autoencoder.py @@ -191,11 +191,11 @@ def forward(self, x: Tensor) -> Tensor: hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): - # pyrefly: ignore[index-error, not-callable] + # pyrefly: ignore[bad-index, not-callable] h = self.down[i_level].block[i_block](hs[-1]) # pyrefly: ignore [bad-argument-type] if len(self.down[i_level].attn) > 0: - # pyrefly: ignore[index-error, not-callable] + # pyrefly: ignore[bad-index, not-callable] h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: @@ -295,11 +295,11 @@ def forward(self, z: Tensor) -> Tensor: # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): - # pyrefly: ignore[index-error, not-callable] + # pyrefly: ignore[bad-index, not-callable] h = self.up[i_level].block[i_block](h) # pyrefly: ignore [bad-argument-type] if len(self.up[i_level].attn) > 0: - # pyrefly: ignore[index-error, not-callable] + # pyrefly: ignore[bad-index, not-callable] h = self.up[i_level].attn[i_block](h) if i_level != 0: # pyrefly: ignore [not-callable] diff --git a/torchtitan/models/flux/tokenizer.py b/torchtitan/models/flux/tokenizer.py index 5d7544c4d1..aab4d68140 100644 --- a/torchtitan/models/flux/tokenizer.py +++ b/torchtitan/models/flux/tokenizer.py @@ -137,7 +137,7 @@ def decode(self, t: List[int]) -> str: """ Decode function. This function will not be called. """ - return self._tokenizer.decode(t) + return self._tokenizer.decode(t) # pyrefly: ignore[bad-return] def build_flux_tokenizer(job_config: JobConfig) -> tuple[BaseTokenizer, BaseTokenizer]: diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py index 0519d26703..91f2c9b5c2 100644 --- a/torchtitan/models/flux/train.py +++ b/torchtitan/models/flux/train.py @@ -172,7 +172,7 @@ def forward_backward_step( loss = self.loss_fn(latent_noise_pred, target) # latent_noise_pred.shape=(bs, seq_len, vocab_size) # need to free to before bwd to avoid peaking memory - # pyrefly: ignore[delete-error] + # pyrefly: ignore[unsupported-delete] del (latent_noise_pred, noise, target) loss.backward() diff --git a/torchtitan/models/gpt_oss/model/moe.py b/torchtitan/models/gpt_oss/model/moe.py index 6f8cd03c5c..f9f5b085bd 100644 --- a/torchtitan/models/gpt_oss/model/moe.py +++ b/torchtitan/models/gpt_oss/model/moe.py @@ -223,7 +223,7 @@ def forward( tp_degree = 1 if isinstance(self.mlp1_weight, DTensor): mesh_dim_names = self.mlp1_weight.device_mesh.mesh_dim_names - # pyrefly: ignore[unsupported-operation] + # pyrefly: ignore[not-iterable] if "tp" in mesh_dim_names: # pyrefly: ignore [missing-attribute] tp_dim_idx = mesh_dim_names.index("tp") @@ -232,7 +232,7 @@ def forward( if self.use_grouped_mm: if ( not isinstance(self.mlp1_weight, DTensor) - # pyrefly: ignore[unsupported-operation] + # pyrefly: ignore[not-iterable] or "ep" not in self.mlp1_weight.device_mesh.mesh_dim_names ): run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm) diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index b6c339fa67..01f76b71fc 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -286,7 +286,7 @@ def apply_fsdp( mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: - # pyrefly: ignore[unsupported-operation] + # pyrefly: ignore[bad-typed-dict-key] fsdp_config["offload_policy"] = CPUOffloadPolicy() match reshard_after_forward_policy: diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 6a1c220047..922b08c7a3 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -166,7 +166,7 @@ def forward( # otherwise, EP will handle the padding. if ( not isinstance(self.w1, DTensor) - # pyrefly: ignore[unsupported-operation] + # pyrefly: ignore[not-iterable] or "ep" not in self.w1.device_mesh.mesh_dim_names ): run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm) From 42fd9036e5cc3c010a00ea4a1eb2c38b84490d9a Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 20 Jan 2026 14:33:24 -0800 Subject: [PATCH 106/127] [Typing] Improve ModelProtocol typing (#2246) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom): * #2260 * __->__ #2246 * #2245 --- scripts/checkpoint_conversion/convert_from_hf.py | 1 - scripts/checkpoint_conversion/convert_to_hf.py | 1 - scripts/checkpoint_conversion/numerical_tests_example.py | 4 ---- scripts/estimate/estimation.py | 8 -------- scripts/generate/test_generate.py | 4 ---- tests/unit_tests/test_train_spec.py | 4 ++-- torchtitan/components/checkpoint.py | 1 - torchtitan/components/ft/manager.py | 1 - torchtitan/components/quantization/mx.py | 5 ----- torchtitan/components/tokenizer.py | 3 --- .../rl/vllm_compat/models/qwen3/model_vllm_compat.py | 4 ++-- torchtitan/models/deepseek_v3/model/model.py | 4 ++-- torchtitan/models/flux/model/model.py | 4 ++-- torchtitan/models/flux/tokenizer.py | 2 +- torchtitan/models/gpt_oss/infra/parallelize.py | 7 ------- torchtitan/models/gpt_oss/model/model.py | 4 ++-- torchtitan/models/llama3/infra/parallelize.py | 1 - torchtitan/models/llama3/model/model.py | 4 ++-- torchtitan/models/llama4/infra/parallelize.py | 1 - torchtitan/models/llama4/model/model.py | 4 ++-- torchtitan/models/qwen3/infra/parallelize.py | 1 - torchtitan/models/qwen3/model/model.py | 4 ++-- torchtitan/protocols/model.py | 8 +++++--- torchtitan/train.py | 2 -- 24 files changed, 22 insertions(+), 60 deletions(-) diff --git a/scripts/checkpoint_conversion/convert_from_hf.py b/scripts/checkpoint_conversion/convert_from_hf.py index 1d475802b6..fdc2abf8b2 100644 --- a/scripts/checkpoint_conversion/convert_from_hf.py +++ b/scripts/checkpoint_conversion/convert_from_hf.py @@ -23,7 +23,6 @@ def convert_from_hf(input_dir, output_dir, model_name, model_flavor): with torch.device("cpu"): # pyrefly: ignore[bad-instantiation] model = train_spec.model_cls(model_args) - # pyrefly: ignore [bad-argument-type] model = ModelWrapper(model) # pyrefly: ignore[bad-instantiation, not-callable] diff --git a/scripts/checkpoint_conversion/convert_to_hf.py b/scripts/checkpoint_conversion/convert_to_hf.py index ed28ffc28a..c5b45b2e4a 100644 --- a/scripts/checkpoint_conversion/convert_to_hf.py +++ b/scripts/checkpoint_conversion/convert_to_hf.py @@ -31,7 +31,6 @@ def convert_to_hf( with torch.device("cpu"): # pyrefly: ignore[bad-instantiation] model = train_spec.model_cls(model_args) - # pyrefly: ignore [bad-argument-type] model = ModelWrapper(model) # pyrefly: ignore[bad-instantiation, not-callable] diff --git a/scripts/checkpoint_conversion/numerical_tests_example.py b/scripts/checkpoint_conversion/numerical_tests_example.py index f3290b7e2a..45214732d2 100644 --- a/scripts/checkpoint_conversion/numerical_tests_example.py +++ b/scripts/checkpoint_conversion/numerical_tests_example.py @@ -78,13 +78,10 @@ def forward_tt(config_path, checkpoint_path, test_set): # materalize model device = torch.device(device_type) - # pyrefly: ignore [missing-attribute] model.to_empty(device=device) model.init_weights(buffer_device=device) - # pyrefly: ignore [missing-attribute] model.eval() - # pyrefly: ignore [bad-argument-type] modelWrapper = ModelWrapper(model) state_dict = modelWrapper._get_state_dict() @@ -100,7 +97,6 @@ def forward_tt(config_path, checkpoint_path, test_set): input_ids = input_ids.unsqueeze(0) # obtains the logits of only the last token in the predictions - # pyrefly: ignore [not-callable] predictions = model(input_ids)[:, -1, :].unsqueeze(1) output_list.append(predictions) diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 8f390c3de0..d9c2a93693 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -99,21 +99,17 @@ def estimate_memory(job_config: JobConfig): # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) - # pyrefly: ignore [bad-argument-type] model_converters.convert(model) # apply PT-D DP/TP parallelisms and activation checkpointing train_spec.parallelize_fn(model, parallel_dims, job_config) - # pyrefly: ignore [missing-attribute] model.to_empty(device="cuda") if not active_fake_mode(): model.init_weights() - # pyrefly: ignore [missing-attribute] model.train() # build optimizer after applying parallelisms to the model - # pyrefly: ignore [bad-argument-type] optimizers = build_optimizers([model], job_config.optimizer, parallel_dims) lr_schedulers = build_lr_schedulers( # pyrefly: ignore [bad-argument-type] @@ -125,7 +121,6 @@ def estimate_memory(job_config: JobConfig): # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 # where it issues a single all-reduce for all parameters at once for better performance optimizers.register_step_post_hook( - # pyrefly: ignore [bad-argument-type] lambda *args, **kwargs: model_converters.post_optimizer_hook(model) ) @@ -150,7 +145,6 @@ def estimate_memory(job_config: JobConfig): device="cuda", ), ) - # pyrefly: ignore [bad-argument-type] fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0]) fsdp_memtracker.track_inputs(batch) @@ -160,7 +154,6 @@ def estimate_memory(job_config: JobConfig): input_ids, labels = batch # train step with train_context(): - # pyrefly: ignore [not-callable] pred = model(input_ids) loss = loss_fn(pred, labels) del pred @@ -168,7 +161,6 @@ def estimate_memory(job_config: JobConfig): # clip gradients torch.nn.utils.clip_grad_norm_( - # pyrefly: ignore [missing-attribute] model.parameters(), job_config.training.max_norm, foreach=True, diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index c6c8612afa..0db1dacc9d 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -135,7 +135,6 @@ def test_generate( # apply_tp (with Sequence Parallel) on unevenly sharded # sequences would require https://github.com/pytorch/torchtitan/pull/686 - # pyrefly: ignore [bad-argument-type] apply_tp_minus_sp(model, parallel_dims.get_mesh("tp")) else: parallel_dims = ParallelDims( @@ -158,14 +157,11 @@ def test_generate( ) # materalize model - # pyrefly: ignore [missing-attribute] model.to_empty(device=device_type) with torch.no_grad(): model.init_weights() - # pyrefly: ignore [missing-attribute] model.eval() - # pyrefly: ignore [missing-attribute] state_dict = model.state_dict() # Checkpoint Loading diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py index 2f8986705e..07d5cd94e6 100644 --- a/tests/unit_tests/test_train_spec.py +++ b/tests/unit_tests/test_train_spec.py @@ -26,9 +26,9 @@ ) -class FakeModel(nn.Module, ModelProtocol): +class FakeModel(ModelProtocol): def __init__(self, model_args: BaseModelArgs) -> None: - super().__init__() + super().__init__(model_args) self.linear = nn.Linear(8, 8) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index c12e2269ef..f571c406bb 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -214,7 +214,6 @@ def __init__( ) if self.ft_manager and not self.enable_ft_dataloader_checkpoints: - # pyrefly: ignore [deprecated] logger.warning( "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. " "This means replicas can retrain over the same data multiple times, which can result in overfitting." diff --git a/torchtitan/components/ft/manager.py b/torchtitan/components/ft/manager.py index d95470c47d..03778dd6d0 100644 --- a/torchtitan/components/ft/manager.py +++ b/torchtitan/components/ft/manager.py @@ -165,5 +165,4 @@ def maybe_semi_sync_training( raise ValueError( f"Unknown training method: {semi_sync_method}, only 'diloco' and 'local_sgd' are supported." ) - # pyrefly: ignore [no-matching-overload] return nullcontext() diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index fbd69dbf63..3bdd250c15 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -57,19 +57,14 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): MXLinearConfig as TorchAOMXLinearConfig, ) - # pyrefly: ignore [bad-assignment] mx_job_config: TorchAOMXLinearConfig = job_config.quantize.linear.mx - # pyrefly: ignore [missing-attribute] config = TorchAOMXLinearConfig.from_recipe_name(mx_job_config.recipe_name) - # pyrefly: ignore [missing-attribute] config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[ mx_job_config.mxfp8_dim1_cast_kernel_choice.upper() ] - # pyrefly: ignore [missing-attribute] self.filter_fqns = mx_job_config.filter_fqns self.config = config self.enabled = True - # pyrefly: ignore [missing-attribute] logger.info(f"MX training active with recipe {mx_job_config.recipe_name}") def convert(self, model: nn.Module): diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index 0931fae09f..6956d3298f 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -145,13 +145,10 @@ def _load_tokenizer_from_path(self, tokenizer_path: str) -> Tokenizer: tokenizer = Tokenizer(bpe_model) # Configure GPT-2 style components for proper space handling - # pyrefly: ignore [read-only] tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel( add_prefix_space=False ) - # pyrefly: ignore [read-only] tokenizer.decoder = decoders.ByteLevel() - # pyrefly: ignore [read-only] tokenizer.post_processor = processors.ByteLevel(trim_offsets=True) return tokenizer diff --git a/torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py b/torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py index 2c9742b1fa..8dab20908c 100644 --- a/torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py +++ b/torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py @@ -277,14 +277,14 @@ def init_weights(self, buffer_device: torch.device): self.feed_forward.init_weights(self.weight_init_std) -class Qwen3VLLMCompatModel(nn.Module, ModelProtocol): +class Qwen3VLLMCompatModel(ModelProtocol): """ Qwen3 model with vLLM-compatible implementation. Uses merged gate_up projections and vLLM Flash Attention. """ def __init__(self, model_args: Qwen3ModelArgs): - super().__init__() + super().__init__(model_args) self.model_args = model_args self.vocab_size = model_args.vocab_size self.n_layers = model_args.n_layers diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index fc846d4098..92c374b8c6 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -404,13 +404,13 @@ def init_weights(self, buffer_device: torch.device): self.feed_forward.init_weights(self.weight_init_std) -class DeepSeekV3Model(nn.Module, ModelProtocol): +class DeepSeekV3Model(ModelProtocol): """ DeepSeek-V3 Transformer model with attention and feed-forward layers. """ def __init__(self, model_args: DeepSeekV3ModelArgs): - super().__init__() + super().__init__(model_args) self.max_seq_len = model_args.max_seq_len self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) self.register_buffer( diff --git a/torchtitan/models/flux/model/model.py b/torchtitan/models/flux/model/model.py index d0f5592871..1f661c074b 100644 --- a/torchtitan/models/flux/model/model.py +++ b/torchtitan/models/flux/model/model.py @@ -21,7 +21,7 @@ from .args import FluxModelArgs -class FluxModel(nn.Module, ModelProtocol): +class FluxModel(ModelProtocol): """ Transformer model for flow matching on sequences. @@ -33,7 +33,7 @@ class FluxModel(nn.Module, ModelProtocol): """ def __init__(self, model_args: FluxModelArgs): - super().__init__() + super().__init__(model_args) self.model_args = model_args diff --git a/torchtitan/models/flux/tokenizer.py b/torchtitan/models/flux/tokenizer.py index aab4d68140..bf99026047 100644 --- a/torchtitan/models/flux/tokenizer.py +++ b/torchtitan/models/flux/tokenizer.py @@ -133,7 +133,7 @@ def encode( return tokens # pyrefly: ignore [bad-override] - def decode(self, t: List[int]) -> str: + def decode(self, t: list[int]) -> list[str] | str: """ Decode function. This function will not be called. """ diff --git a/torchtitan/models/gpt_oss/infra/parallelize.py b/torchtitan/models/gpt_oss/infra/parallelize.py index f97d15369d..0ef4004595 100644 --- a/torchtitan/models/gpt_oss/infra/parallelize.py +++ b/torchtitan/models/gpt_oss/infra/parallelize.py @@ -52,7 +52,6 @@ # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, - # pyrefly: ignore [missing-attribute] torch._higher_order_ops.inductor_compiled_code, } @@ -207,23 +206,17 @@ def apply_non_moe_tp( layer_plan = { "attention_norm": SequenceParallel(), "attention": PrepareModuleInput( - # pyrefly: ignore [bad-argument-type] input_layouts=(Shard(1), Replicate(), None), - # pyrefly: ignore [bad-argument-type] desired_input_layouts=(Replicate(), Replicate(), None), ), "attention.wq": ColwiseParallel(use_local_output=False), "attention.wk": ColwiseParallel(use_local_output=False), "attention.wv": ColwiseParallel(use_local_output=False), "attention.inner_attention": PrepareModuleInputOutput( - # pyrefly: ignore [bad-argument-type] input_layouts=(Shard(1), Shard(1), Shard(1)), - # pyrefly: ignore [bad-argument-type] desired_input_layouts=(Shard(1), Shard(1), Shard(1)), use_local_input=True, - # pyrefly: ignore [bad-argument-type] output_layouts=(Shard(1), Shard(1)), - # pyrefly: ignore [bad-argument-type] desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False, ), diff --git a/torchtitan/models/gpt_oss/model/model.py b/torchtitan/models/gpt_oss/model/model.py index 9db5818b3f..44e91f6fb6 100644 --- a/torchtitan/models/gpt_oss/model/model.py +++ b/torchtitan/models/gpt_oss/model/model.py @@ -268,13 +268,13 @@ def init_weights(self, buffer_device: torch.device): self.moe.init_weights(self.weight_init_std, buffer_device) -class GptOssModel(nn.Module, ModelProtocol): +class GptOssModel(ModelProtocol): """ GPT-OSS Transformer model with attention and feed-forward layers. """ def __init__(self, model_args: GptOssModelArgs): - super().__init__() + super().__init__(model_args) self.model_args = model_args self.max_seq_len = model_args.max_seq_len self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 01f76b71fc..f504cbcb63 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -231,7 +231,6 @@ def apply_tp( # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, - # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index b8fef9bdb7..317bd1eae8 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -422,7 +422,7 @@ def init_weights(self): self.feed_forward.init_weights(self.weight_init_std) -class Transformer(nn.Module, ModelProtocol): +class Transformer(ModelProtocol): """ Transformer Module @@ -442,7 +442,7 @@ class Transformer(nn.Module, ModelProtocol): """ def __init__(self, model_args: TransformerModelArgs): - super().__init__() + super().__init__(model_args) self.model_args = model_args self.vocab_size = model_args.vocab_size self.n_layers = model_args.n_layers diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index c62635e2eb..085300f220 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -301,7 +301,6 @@ def apply_non_moe_tp( # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, - # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index ef17f6e087..2ab55637d3 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -459,7 +459,7 @@ def init_weights(self, buffer_device: torch.device): self.feed_forward.init_weights(self.weight_init_std) -class Transformer(nn.Module, ModelProtocol): +class Transformer(ModelProtocol): """ Transformer Module @@ -479,7 +479,7 @@ class Transformer(nn.Module, ModelProtocol): """ def __init__(self, model_args: TransformerModelArgs): - super().__init__() + super().__init__(model_args) self.model_args = model_args self.vocab_size = model_args.vocab_size self.n_layers = model_args.n_layers diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 28a3ba3304..b7a21390d7 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -274,7 +274,6 @@ def apply_non_moe_tp( # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, - # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index a6a6e0407c..4aecf0f52a 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -425,7 +425,7 @@ def init_weights(self, buffer_device: torch.device): self.feed_forward.init_weights(self.weight_init_std) -class Qwen3Model(nn.Module, ModelProtocol): +class Qwen3Model(ModelProtocol): """ Qwen3Model Module @@ -445,7 +445,7 @@ class Qwen3Model(nn.Module, ModelProtocol): """ def __init__(self, model_args: Qwen3ModelArgs): - super().__init__() + super().__init__(model_args) self.model_args = model_args self.vocab_size = model_args.vocab_size self.n_layers = model_args.n_layers diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py index 712449f2f6..99e4c34dc0 100644 --- a/torchtitan/protocols/model.py +++ b/torchtitan/protocols/model.py @@ -6,7 +6,6 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Protocol import torch import torch.nn as nn @@ -41,15 +40,18 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in pass -class ModelProtocol(Protocol): +class ModelProtocol(nn.Module): """Defines the interface for a model class. This is used to enforce that all model classes have some methods that are required by the trainer. + + NOTE: We keep protocol name for backward compatibility even though it is + not a Protocol anymore. """ def __init__(self, model_args: BaseModelArgs) -> None: - pass + super().__init__() @abstractmethod def init_weights(self, buffer_device: torch.device | None = None) -> None: diff --git a/torchtitan/train.py b/torchtitan/train.py index 3c255ccdaa..3d77eeb425 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -148,7 +148,6 @@ def __init__(self, job_config: JobConfig): # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) - # pyrefly: ignore [bad-argument-type] model_converters.convert(model) # metrics logging @@ -166,7 +165,6 @@ def __init__(self, job_config: JobConfig): ( model_param_count, self.metrics_processor.num_flops_per_token, - # pyrefly: ignore [bad-argument-type] ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) logger.info( From 69cf2075922195477310f7cf55fcd870ffbf39fb Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 20 Jan 2026 14:34:03 -0800 Subject: [PATCH 107/127] [Typing] Remove deprecated enable_symm_mem_for_group (#2260) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom): * __->__ #2260 * #2246 * #2245 This API is not needed any more. pyrefly will give us a warning. --- torchtitan/distributed/tensor_parallel.py | 3 --- torchtitan/models/gpt_oss/infra/parallelize.py | 5 +---- torchtitan/models/qwen3/infra/parallelize.py | 3 --- 3 files changed, 1 insertion(+), 10 deletions(-) diff --git a/torchtitan/distributed/tensor_parallel.py b/torchtitan/distributed/tensor_parallel.py index 59fffc86a2..60101a2862 100644 --- a/torchtitan/distributed/tensor_parallel.py +++ b/torchtitan/distributed/tensor_parallel.py @@ -22,9 +22,6 @@ def maybe_enable_async_tp(job_config: JobConfig, tp_mesh: DeviceMesh): "Async TP requires 'model' in --compile.components and --compile.enable" ) - from torch.distributed._symmetric_memory import enable_symm_mem_for_group - torch._inductor.config._micro_pipeline_tp = True - enable_symm_mem_for_group(tp_mesh.get_group().group_name) logger.info("Async TP is enabled") diff --git a/torchtitan/models/gpt_oss/infra/parallelize.py b/torchtitan/models/gpt_oss/infra/parallelize.py index 0ef4004595..338092fb7a 100644 --- a/torchtitan/models/gpt_oss/infra/parallelize.py +++ b/torchtitan/models/gpt_oss/infra/parallelize.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import torch +import torch._inductor.config import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh @@ -241,11 +242,7 @@ def apply_non_moe_tp( ) if enable_async_tp: - from torch.distributed._symmetric_memory import enable_symm_mem_for_group - - # pyrefly: ignore [implicit-import] torch._inductor.config._micro_pipeline_tp = True - enable_symm_mem_for_group(tp_mesh.get_group().group_name) logger.info( f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index b7a21390d7..63e7a0ba7c 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -278,10 +278,7 @@ def apply_non_moe_tp( ) if enable_async_tp: - from torch.distributed._symmetric_memory import enable_symm_mem_for_group - torch._inductor.config._micro_pipeline_tp = True - enable_symm_mem_for_group(tp_mesh.get_group().group_name) logger.info( f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" From 1e8f9acd194a896efeb5f417894c305783c7b138 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 20 Jan 2026 16:51:54 -0800 Subject: [PATCH 108/127] [CP] Refactor Context Parallel to use new PyTorch CP APIs (#2144) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom): * #2145 * __->__ #2144 **Summary** 1. Refactored CP Dispatching: - New apply_cp() function uses PyTorch's _ContextParallel parallelization plan to dispatch attention call. - Enables CP dispatcher for SDPA attention type inside apply_cp() 2. New CP Data Sharding Approach: - Added a cp_shard() helper function that wraps PyTorch's _context_parallel_shard API - Uses _HeadTailLoadBalancer for SDPA attention load balancing - FlexAttention CP support deferred to a future PR - CP sharding now happens explicitly in post_dataloading_process() where inputs, labels, and positions are sharded - The new positions argument allows us to not shard the freqs_cis. Note that this PR require https://github.com/pytorch/pytorch/pull/170200 **Test** ``` -> % python3 scripts/loss_compare.py . chienchin/loss_compare --baseline-options="--parallelism.context_parallel_degree=8" --test-options="--parallelism.context_parallel_degree=8" --steps=100 --assert-equal pick 5903566a Improve the loss_compare.sh logic [LOSS_COMPARE] [LOSS_COMPARE] Asserting losses are equal... [LOSS_COMPARE] Baseline log: /tmp/baseline_training.log [LOSS_COMPARE] Test log: /tmp/test_training.log [LOSS_COMPARE] Extracted 100 steps from baseline log [LOSS_COMPARE] Extracted 100 steps from test log test_losses_equal (__main__.assert_losses_equal..LossEqualityTest.test_losses_equal) ... ok ---------------------------------------------------------------------- Ran 1 test in 0.000s OK [LOSS_COMPARE] All losses are equal. Assertion passed! [LOSS_COMPARE] ========================================== [LOSS_COMPARE] LOSS COMPARISON ANALYSIS [LOSS_COMPARE] ========================================== [LOSS_COMPARE] Step-by-step loss comparison: [LOSS_COMPARE] Step Baseline Loss Test Loss Difference [LOSS_COMPARE] ---- ------------- --------- ---------- [LOSS_COMPARE] 1 8.1309 8.1309 0.000000 [LOSS_COMPARE] 2 7.8268 7.8268 0.000000 [LOSS_COMPARE] 3 7.2284 7.2284 0.000000 [LOSS_COMPARE] 4 6.4669 6.4669 0.000000 [LOSS_COMPARE] 5 5.4017 5.4017 0.000000 [LOSS_COMPARE] 6 4.7656 4.7656 0.000000 [LOSS_COMPARE] 7 4.3587 4.3587 0.000000 [LOSS_COMPARE] 8 4.0938 4.0938 0.000000 [LOSS_COMPARE] 9 4.4019 4.4019 0.000000 [LOSS_COMPARE] 10 3.7451 3.7451 0.000000 .... [LOSS_COMPARE] 90 2.802 2.802 0.000000 [LOSS_COMPARE] 91 2.7207 2.7207 0.000000 [LOSS_COMPARE] 92 2.7454 2.7454 0.000000 [LOSS_COMPARE] 93 2.6992 2.6992 0.000000 [LOSS_COMPARE] 94 2.743 2.743 0.000000 [LOSS_COMPARE] 95 2.7534 2.7534 0.000000 [LOSS_COMPARE] 96 2.8403 2.8403 0.000000 [LOSS_COMPARE] 97 2.783 2.783 0.000000 [LOSS_COMPARE] 98 3.0892 3.0892 0.000000 [LOSS_COMPARE] 99 2.7905 2.7905 0.000000 [LOSS_COMPARE] 100 2.733 2.733 0.000000 [LOSS_COMPARE] [LOSS_COMPARE] Summary statistics: [LOSS_COMPARE] Average baseline loss: 3.1414940000000002 [LOSS_COMPARE] Average test loss: 3.1414940000000002 [LOSS_COMPARE] Average difference: 0.000000 [LOSS_COMPARE] [LOSS_COMPARE] Loss comparison complete. No results saved (no output folder specified). ``` **TODO** - This PR will invalidate torch.compile + CP due to https://github.com/pytorch/pytorch/issues/170110. We will have to wait for Dynamo to fix the issue or refactor nn.Module core logic to avoid check hook_id. --- torchtitan/components/validate.py | 104 ++++++++-- torchtitan/distributed/context_parallel.py | 192 ++++++++++++++++++ torchtitan/distributed/utils.py | 10 +- torchtitan/experiments/forge/example_train.py | 60 ++++-- .../simple_fsdp/deepseek_v3/parallelize.py | 1 + .../models/deepseek_v3/infra/parallelize.py | 32 ++- torchtitan/models/deepseek_v3/model/args.py | 6 +- torchtitan/models/flux/infra/parallelize.py | 50 ++++- torchtitan/models/flux/model/layers.py | 12 +- torchtitan/models/flux/train.py | 37 ++-- torchtitan/models/flux/validate.py | 61 +++--- torchtitan/models/llama3/infra/parallelize.py | 20 +- torchtitan/models/llama3/model/args.py | 4 +- torchtitan/models/llama3/model/model.py | 7 +- torchtitan/models/llama4/infra/parallelize.py | 16 +- torchtitan/models/llama4/model/args.py | 6 +- torchtitan/models/llama4/model/model.py | 3 - torchtitan/models/qwen3/infra/parallelize.py | 32 ++- torchtitan/models/qwen3/model/model.py | 4 +- torchtitan/train.py | 36 ++-- 20 files changed, 514 insertions(+), 179 deletions(-) create mode 100644 torchtitan/distributed/context_parallel.py diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 3beae2e216..be10d4c6fe 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -6,7 +6,7 @@ from collections.abc import Callable from contextlib import AbstractContextManager -from typing import TypeAlias +from typing import Any, TypeAlias import torch import torch.nn as nn @@ -17,14 +17,12 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.distributed.context_parallel import prepare_context_parallel_input from torchtitan.hf_datasets.text_datasets import build_text_validation_dataloader from torchtitan.tools import utils from torchtitan.tools.logging import logger -ValidationContext: TypeAlias = Callable[ - [AbstractContextManager[None] | None], - AbstractContextManager[None], -] +ValidationContext: TypeAlias = Callable[[], AbstractContextManager[None]] class BaseValidator: @@ -67,6 +65,7 @@ def __init__( pp_has_last_stage: bool | None = None, ): self.job_config = job_config + self.tokenizer = tokenizer self.parallel_dims = parallel_dims self.loss_fn = loss_fn self.validation_dataloader = build_text_validation_dataloader( @@ -89,6 +88,70 @@ def __init__( "unequal sample counts across ranks when dataset is exhausted." ) + def post_dataloading_process( + self, + input_dict: dict[str, torch.Tensor], + labels: torch.Tensor, + model_parts: list[nn.Module], + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]: + """ + Post-processing hook after data loading and before model forward pass. + + This method processes the raw data from the dataloader and prepares it for + the model's forward pass. It separates the main input tensor from auxiliary + inputs and constructs additional keyword arguments (e.g., attention masks). + + Args: + input_dict: Dictionary containing tensors from the dataloader. Must + contain an "input" key with the main input tensor. May contain + additional keys for auxiliary inputs (e.g., position ids). + labels: Target labels for the batch. + model_parts: List of model parts for accessing model methods. + + Returns: + A tuple of (inputs, labels, extra_inputs, extra_kwargs) where: + - inputs: Main input tensor extracted from input_dict["input"]. + - labels: Target labels (potentially modified by CP sharding). + - extra_inputs: Dict of auxiliary input tensors (all keys except + "input" from input_dict). These are passed to the model forward + but are NOT forwarded across pipeline parallel stages. + - extra_kwargs: Dict of additional keyword arguments for model forward. + These ARE forwarded across pipeline parallel stages. Contains + attention_masks if flex attention is enabled. + + Note: + The distinction between extra_inputs and extra_kwargs is important for + pipeline parallelism: extra_kwargs are forwarded to all pipeline stages, + while extra_inputs are only available to the first stage. + """ + inputs = input_dict["input"] + extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} + # For arguments, like attention_masks, we have to put them in a separate + # dict as extra_inputs are not forwarded to other stages in PP, but + # extra_kwargs are. + extra_kwargs: dict[str, Any] = {} + + try: + # pyrefly: ignore [not-callable] + extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks( + input_batch=inputs, + tokenizer=self.tokenizer, + extra_inputs=extra_inputs, + ) + except TypeError: + pass + + if self.parallel_dims.cp_enabled: + inputs, labels, extra_kwargs = prepare_context_parallel_input( + inputs, + labels, + extra_kwargs, + self.parallel_dims.get_mesh("cp"), + inputs.device, + ) + + return inputs, labels, extra_inputs, extra_kwargs + @torch.no_grad() # pyrefly: ignore [bad-override] def validate( @@ -117,37 +180,36 @@ def validate( self.metrics_processor.ntokens_since_last_log += labels.numel() for k, v in input_dict.items(): input_dict[k] = v.to(device_type) - inputs = input_dict["input"] labels = labels.to(device_type) - optional_context_parallel_ctx = None - if parallel_dims.cp_enabled: - cp_mesh = parallel_dims.get_mesh("cp") - optional_context_parallel_ctx = dist_utils.create_context_parallel_ctx( - cp_mesh=cp_mesh, - cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], - cp_seq_dims=[1, 1] + [0 for _ in model_parts], - cp_no_restore_buffers={inputs, labels}, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, - ) + # Process data (extract inputs, handle attention masks, CP sharding) + inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( + input_dict, labels, model_parts + ) if parallel_dims.pp_enabled: assert self.pp_schedule is not None assert self.pp_has_first_stage is not None assert self.pp_has_last_stage is not None # Pipeline Parallel forward inside eval() call - with self.validation_context(optional_context_parallel_ctx): + with self.validation_context(): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) ) if self.pp_has_first_stage: self.pp_schedule.eval( inputs, + **extra_inputs, + **extra_kwargs, target=targets, losses=losses, ) else: - self.pp_schedule.eval(target=targets, losses=losses) + self.pp_schedule.eval( + **extra_kwargs, + target=targets, + losses=losses, + ) # accumulate losses across pipeline microbatches # TODO: PP+FSDP unexpectedly puts the loss back to the CPU @@ -160,10 +222,12 @@ def validate( else torch.tensor([-1.0], device=device_type) ) else: - with self.validation_context(optional_context_parallel_ctx): + with self.validation_context(): assert len(model_parts) == 1 with self.maybe_enable_amp: - predictions = model_parts[0](inputs) + predictions = model_parts[0]( + inputs, **extra_inputs, **extra_kwargs + ) loss = self.loss_fn(predictions, labels) accumulated_losses.append(loss.detach()) diff --git a/torchtitan/distributed/context_parallel.py b/torchtitan/distributed/context_parallel.py new file mode 100644 index 0000000000..b921831e25 --- /dev/null +++ b/torchtitan/distributed/context_parallel.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Sequence +from typing import Any, cast + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.experimental._attention import ( + _context_parallel_shard, + _ContextParallel, + _enable_context_parallel_dispatcher, + _HeadTailLoadBalancer, +) +from torch.distributed.tensor.parallel import parallelize_module + +from torchtitan.protocols.model import AttentionMasksType +from torchtitan.tools.logging import logger + + +def apply_cp_to_attention_module( + attention_modules: Sequence[nn.Module], + cp_mesh: DeviceMesh, + attention_type: str, +) -> None: + """ + Apply context parallelism to attention modules. + + CP splits the sequence dimension across devices to enable training with + longer sequences. This function applies CP to the provided attention + modules. + + Args: + attention_modules: Sequence of attention modules to apply CP to + cp_mesh: Device mesh for context parallel dimension + attention_type: Type of attention mechanism. Must be one of: + - "sdpa": scaled_dot_product_attention() + - "flex": flex_attention() + - "varlen": varlen_attn() (not yet implemented) + + Raises: + NotImplementedError: If attention_type is "varlen" + """ + # Apply context parallelism to every attention module + # TODO: make seq_dim configurable once the implementation doesn't assume 2 + # internally. + match attention_type: + case "flex": + cp_plan = _ContextParallel( + seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX + ) + case "sdpa": + # Enable the DTensor dispatcher to route SDPA operations to the + # Context Parallel implementation. This is required for CP to work + # with SDPA (but not FlexAttention). + # Note: Use _disable_context_parallel_dispatcher() if you need to + # turn this off. In TorchTitan, we currently don't disable the CP + # dispatcher. + _enable_context_parallel_dispatcher() + cp_plan = _ContextParallel( + seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA + ) + case "varlen": + raise NotImplementedError( + "Variable-length attention CP is not yet supported" + ) + case _: + raise ValueError( + f"Invalid attention_type '{attention_type}'. " + f"Must be one of: 'sdpa', 'flex', 'varlen'" + ) + + for attention_module in attention_modules: + parallelize_module( + module=attention_module, + device_mesh=cp_mesh, + parallelize_plan=cp_plan, + ) + + logger.info("Applied Context Parallel to the model") + + +def prepare_context_parallel_input( + inputs: torch.Tensor, + labels: torch.Tensor, + extra_kwargs: dict[str, Any], + cp_mesh: DeviceMesh, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: + """ + Prepare inputs, labels, and attention masks for Context Parallel forward pass. + + This function prepares tensors for context parallel by: + 1. Creating position indices based on input sequence length + 2. Sharding inputs, labels, and positions across the CP mesh + 3. Sharding attention masks if present + + Args: + inputs: Input tensor of shape [batch_size, seq_len] + labels: Label tensor of shape [batch_size, seq_len] + extra_kwargs: Dictionary that may contain 'attention_masks' to be sharded + cp_mesh: Device mesh for context parallel dimension + device: Device to create position tensor on + + Returns: + Tuple of (sharded_inputs, sharded_labels, updated_extra_kwargs) where: + - sharded_inputs: Inputs sharded along sequence dimension + - sharded_labels: Labels sharded along sequence dimension + - updated_extra_kwargs: Dict with sharded 'positions' and optionally + sharded 'attention_masks' + """ + attention_masks = extra_kwargs.get("attention_masks", None) + positions = torch.arange( + 0, inputs.shape[1], dtype=torch.int32, device=device + ).expand(inputs.shape) + (inputs, labels, positions), attention_masks = cp_shard( + cp_mesh, + (inputs, labels, positions), + attention_masks, + ) + extra_kwargs["positions"] = positions + if attention_masks is not None: + extra_kwargs["attention_masks"] = attention_masks + + return inputs, labels, extra_kwargs + + +def cp_shard( + cp_mesh: DeviceMesh, + inputs: tuple[torch.Tensor, ...], + attention_masks: AttentionMasksType | None, + disable_load_balancer: bool = False, + input_seq_dim: int = 1, +) -> tuple[tuple[torch.Tensor, ...], AttentionMasksType | None]: + """ + Shard inputs and attention masks across the context parallel mesh. + + This function distributes input tensors across devices in the CP mesh + along the sequence dimension. It optionally uses a load balancer to + handle uneven computation workload. Currently, HeadTailLoadBalancer is + used for SDPA + CP, which is the only supported configuration. + + Args: + cp_mesh: Device mesh for context parallel dimension + inputs: Tuple of input tensors to be sharded along the sequence + dimension + attention_masks: Attention masks to be sharded (currently raises + error as FlexAttention CP is not yet supported) + disable_load_balancer: If True, disables load balancing. If False + (default), uses HeadTailLoadBalancer for SDPA to handle uneven + computation workload. + input_seq_dim: Sequence dimension index for sharding. Defaults to 1, + which covers most use cases where tensors have shape + [batch_size, seq_len, ...]. Can be changed by passing a + different value if your tensors use a different sequence + dimension layout. + + Returns: + Tuple of (sharded_inputs, attention_masks) where: + - sharded_inputs: Tuple of input tensors sharded along the + sequence dimension + - attention_masks: Attention masks (currently unchanged/None) + """ + seq_len = inputs[0].size(input_seq_dim) + cp_world_size = cp_mesh.size(0) + if attention_masks is not None: + raise ValueError( + "FlexAttention CP is not supported yet. Will come in the next PR." + ) + else: + # For SDPA, we use the _HeadTailLoadBalancer. + load_balancer = ( + None + if disable_load_balancer + else _HeadTailLoadBalancer(seq_len, cp_world_size, cp_mesh.device_type) + ) + + inputs = cast( + tuple[torch.Tensor, ...], + _context_parallel_shard( + mesh=cp_mesh, + buffers=inputs, + seq_dims=tuple(input_seq_dim for _ in inputs), + load_balancer=load_balancer, + ), + ) + + return inputs, attention_masks diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index cd2d797c3a..2ba9c08422 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -224,23 +224,17 @@ def create_context_parallel_ctx( class TrainContext(Protocol): @abstractmethod - def __call__( - self, - cp_context: contextlib.AbstractContextManager[None] | None = None, - ) -> contextlib.AbstractContextManager[None]: + def __call__(self) -> contextlib.AbstractContextManager[None]: pass def get_train_context(enable_loss_parallel: bool) -> TrainContext: @contextlib.contextmanager - def context(cp_context: contextlib.AbstractContextManager[None] | None = None): + def context(): with contextlib.ExitStack() as stack: if enable_loss_parallel: stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) - if cp_context: - stack.enter_context(cp_context) - yield return context diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index 6530eb18bb..be670865b8 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -19,6 +19,7 @@ from torchtitan.components.validate import build_validator from torchtitan.config import JobConfig from torchtitan.distributed import utils as dist_utils +from torchtitan.distributed.context_parallel import prepare_context_parallel_input from torchtitan.hf_datasets.text_datasets import build_text_dataloader from torchtitan.tools import utils from torchtitan.tools.logging import logger @@ -152,42 +153,57 @@ def batch_generator( yield input_dict, labels - def forward_backward_step( + def post_dataloading_process( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor - ) -> torch.Tensor: - model_parts = self.model_parts - parallel_dims = self.parallel_dims - + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]: inputs = input_dict["input"] - extra_kwargs = {} - - if getattr(self.model_args, "attn_type", "sdpa") == "flex": - extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks( + extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} + # For arguments, like attention_masks, we have to put them in a separate + # dict as extra_inputs are not forwarded to other stages in PP, but + # extra_kwargs are. + extra_kwargs: dict[str, Any] = {} + + try: + # pyrefly: ignore [not-callable] + extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, + extra_inputs=extra_inputs, ) - - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.get_mesh("cp"), - cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], - cp_seq_dims=[1, 1] + [0 for _ in model_parts], - cp_no_restore_buffers={inputs, labels}, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, + except TypeError: + pass + + if self.parallel_dims.cp_enabled: + inputs, labels, extra_kwargs = prepare_context_parallel_input( + inputs, + labels, + extra_kwargs, + self.parallel_dims.get_mesh("cp"), + self.device, ) - if parallel_dims.cp_enabled - else None + + return inputs, labels, extra_inputs, extra_kwargs + + def forward_backward_step( + self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor + ) -> torch.Tensor: + model_parts = self.model_parts + parallel_dims = self.parallel_dims + + inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( + input_dict, labels ) if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call - with self.train_context(optional_context_parallel_ctx): + with self.train_context(): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) ) if self.pp_has_first_stage: self.pp_schedule.step( inputs, + **extra_inputs, **extra_kwargs, target=targets, losses=losses, @@ -211,10 +227,10 @@ def forward_backward_step( ) else: # Non-PP forward / backward - with self.train_context(optional_context_parallel_ctx): + with self.train_context(): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs, **extra_kwargs) + pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 3e2209dd6a..5e62bd4891 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -89,6 +89,7 @@ def parallelize_deepseekv3( parallel_dims.get_mesh("tp"), loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, + cp_enabled=parallel_dims.cp_enabled, ) maybe_enable_async_tp(job_config, parallel_dims.get_mesh("tp")) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 374df44157..19d9f946d2 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -18,6 +18,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import NoParallel, ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module from torchtitan.distributed.dual_pipe_v import get_dual_pipe_v_flag from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.models.llama3.infra.parallelize import apply_ddp @@ -65,7 +66,11 @@ def parallelize_deepseekv3( attn_type = getattr(model.model_args, "attn_type", "sdpa") if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": - raise NotImplementedError("CP support is only supported for SDPA.") + raise NotImplementedError( + f"Context Parallel only supports SDPA attention. " + f"Got attn_type='{attn_type}'. " + f"FlexAttention and varlen attention are not supported with CP." + ) if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters @@ -87,6 +92,7 @@ def parallelize_deepseekv3( tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, + cp_enabled=parallel_dims.cp_enabled, ) maybe_enable_async_tp(job_config, tp_mesh) @@ -127,6 +133,14 @@ def parallelize_deepseekv3( use_deepep=use_deepep, ) + if parallel_dims.cp_enabled: + apply_cp_to_attention_module( + # pyrefly: ignore [missing-attribute, not-callable] + [block.attention.inner_attention for block in model.layers.values()], + parallel_dims.get_mesh("cp"), + attn_type, + ) + model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) @@ -178,9 +192,6 @@ def parallelize_deepseekv3( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: @@ -201,6 +212,7 @@ def apply_non_moe_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, + cp_enabled: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -239,15 +251,19 @@ def apply_non_moe_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + positions_sharding = Replicate() if cp_enabled else None # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), - # NOTE: when the fourth argument (positions) is not None, its input layout - # and desired input layout should be Replicate() "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate(), None, None), - desired_input_layouts=(Replicate(), Replicate(), None, None), + input_layouts=(Shard(1), Replicate(), None, positions_sharding), + desired_input_layouts=( + Replicate(), + Replicate(), + None, + positions_sharding, + ), ), # NOTE: use_local_output=False make the output to be a DTensor instead of a plain Tensor # so that the intermedidate results k is generated as a DTensor and its gradient is diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index f880a53384..fc63b12210 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -102,7 +102,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: job_config.parallelism.context_parallel_degree > 1 and self.attn_type != "sdpa" ): - raise NotImplementedError("CP support is only supported for SDPA.") + raise NotImplementedError( + f"Context Parallel only supports SDPA attention. " + f"Got attn_type='{self.attn_type}'. " + f"FlexAttention and varlen attention are not supported with CP." + ) self.moe_args._debug_force_load_balance = ( job_config.debug.moe_force_load_balance diff --git a/torchtitan/models/flux/infra/parallelize.py b/torchtitan/models/flux/infra/parallelize.py index c12e6b0c78..321a73dcc9 100644 --- a/torchtitan/models/flux/infra/parallelize.py +++ b/torchtitan/models/flux/infra/parallelize.py @@ -17,6 +17,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module from torchtitan.tools.logging import logger @@ -28,6 +29,9 @@ def parallelize_flux( if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) + if parallel_dims.cp_enabled: + apply_cp(model, parallel_dims.get_mesh("cp")) + if parallel_dims.fsdp_enabled: names = ( ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] @@ -47,16 +51,6 @@ def parallelize_flux( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - # The attention in Flux does not use causal mask. - # Currently, load_balance must be disabled in order to support Context Parallelism - # in Pytorch's experimental ring attention module - # https://github.com/pytorch/pytorch/blob/v2.9.0/torch/distributed/tensor/experimental/_attention.py#L395 - from torch.distributed.tensor.experimental._attention import _cp_options - - _cp_options.enable_load_balance = False - logger.info("Applied Context Parallel to the model") - return model @@ -134,6 +128,42 @@ def apply_ac(model: nn.Module, ac_config): logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") +def apply_cp(model: nn.Module, cp_mesh: DeviceMesh) -> None: + """ + Apply context parallelism to the Flux model. + + Args: + model: The Flux model with double_blocks and single_blocks containing + inner attention modules. + cp_mesh: Device mesh for context parallel dimension + + Note: + - Uses SDPA attention type + - Applies to all inner_attention modules in double_blocks and single_blocks + """ + # Collect all inner_attention modules from the Flux model + attention_modules = [] + + # pyrefly: ignore [not-iterable] + for double_block in model.double_blocks: + # pyrefly: ignore [missing-attribute] + attention_modules.append(double_block.img_attn.inner_attention) + # pyrefly: ignore [missing-attribute] + attention_modules.append(double_block.txt_attn.inner_attention) + # pyrefly: ignore [missing-attribute] + attention_modules.append(double_block.inner_attention) + + # pyrefly: ignore [not-iterable] + for single_block in model.single_blocks: + # pyrefly: ignore [missing-attribute] + attention_modules.append(single_block.inner_attention) + + # Apply CP using the shared implementation (always uses SDPA for Flux) + apply_cp_to_attention_module(attention_modules, cp_mesh, "sdpa") + + logger.info("Applied Context Parallel to the Flux model") + + def parallelize_encoders( t5_model: nn.Module, clip_model: nn.Module, diff --git a/torchtitan/models/flux/model/layers.py b/torchtitan/models/flux/model/layers.py index 30ba52d3a3..6d0e696dd9 100644 --- a/torchtitan/models/flux/model/layers.py +++ b/torchtitan/models/flux/model/layers.py @@ -13,6 +13,8 @@ from einops import rearrange from torch import nn, Tensor +from torchtitan.models.attention import ScaledDotProductAttentionWrapper + def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 @@ -124,6 +126,7 @@ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.norm = QKNorm(head_dim) self.proj = nn.Linear(dim, dim) + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self): for layer in (self.qkv, self.proj): @@ -136,7 +139,7 @@ def forward(self, x: Tensor, pe: Tensor) -> Tensor: q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = self.norm(q, k, v) q, k = apply_rope(q, k, pe) - x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = self.inner_attention(q, k, v) x = rearrange(x, "B H L D -> B L (H D)") x = self.proj(x) return x @@ -206,6 +209,8 @@ def __init__( nn.Linear(mlp_hidden_dim, hidden_size, bias=True), ) + self.inner_attention = ScaledDotProductAttentionWrapper() + def init_weights(self): # initialize all the nn.Linear submodules for layer in ( @@ -257,7 +262,7 @@ def forward( v = torch.cat((txt_v, img_v), dim=2) q, k = apply_rope(q, k, pe) - attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) + attn = self.inner_attention(q, k, v) attn = rearrange(attn, "B H L D -> B L (H D)") txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] @@ -308,6 +313,7 @@ def __init__( self.mlp_act = nn.GELU(approximate="tanh") self.modulation = Modulation(hidden_size, double=False) + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self): for layer in (self.linear1, self.linear2): @@ -329,7 +335,7 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: # compute attention q, k = apply_rope(q, k, pe) - attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) + attn = self.inner_attention(q, k, v) attn = rearrange(attn, "B H L D -> B L (H D)") # compute activation in mlp stream, cat again and run second linear layer diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py index 91f2c9b5c2..3348cfe385 100644 --- a/torchtitan/models/flux/train.py +++ b/torchtitan/models/flux/train.py @@ -136,29 +136,24 @@ def forward_backward_step( latents = pack_latents(latents) target = pack_latents(noise - labels) - optional_context_parallel_ctx = None + # Apply CP sharding if enabled if self.parallel_dims.cp_enabled: - cp_mesh = self.parallel_dims.get_mesh("cp") - optional_context_parallel_ctx = dist_utils.create_context_parallel_ctx( - cp_mesh=cp_mesh, - cp_buffers=[ - latents, - latent_pos_enc, - t5_encodings, - text_pos_enc, - target, - ], - cp_seq_dims=[1, 1, 1, 1, 1], - cp_no_restore_buffers={ - latents, - latent_pos_enc, - t5_encodings, - text_pos_enc, - target, - }, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, + from torchtitan.distributed.context_parallel import cp_shard + + ( + latents, + latent_pos_enc, + t5_encodings, + text_pos_enc, + target, + ), _ = cp_shard( + self.parallel_dims.get_mesh("cp"), + (latents, latent_pos_enc, t5_encodings, text_pos_enc, target), + None, # No attention masks for Flux + disable_load_balancer=True, ) - with self.train_context(optional_context_parallel_ctx): + + with self.train_context(): with self.maybe_enable_amp: latent_noise_pred = model( img=latents, diff --git a/torchtitan/models/flux/validate.py b/torchtitan/models/flux/validate.py index 9deb12a195..077134f81d 100644 --- a/torchtitan/models/flux/validate.py +++ b/torchtitan/models/flux/validate.py @@ -61,6 +61,7 @@ def __init__( pp_has_last_stage: bool | None = None, ): self.job_config = job_config + self.tokenizer = tokenizer self.parallel_dims = parallel_dims self.loss_fn = loss_fn # pyrefly: ignore [missing-attribute] @@ -220,41 +221,35 @@ def validate( latents = pack_latents(latents) target = pack_latents(noise - labels) - optional_context_parallel_ctx = None - if parallel_dims.cp_enabled: - cp_mesh = parallel_dims.get_mesh("cp") - optional_context_parallel_ctx = dist_utils.create_context_parallel_ctx( - cp_mesh=cp_mesh, - cp_buffers=[ - latents, - latent_pos_enc, - t5_encodings, - text_pos_enc, - target, - ], - cp_seq_dims=[1, 1, 1, 1, 1], - cp_no_restore_buffers={ - latents, - latent_pos_enc, - t5_encodings, - text_pos_enc, - target, - }, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, + # Apply CP sharding if enabled + if parallel_dims.cp_enabled: + from torchtitan.distributed.context_parallel import cp_shard + + ( + latents, + latent_pos_enc, + t5_encodings, + text_pos_enc, + target, + ), _ = cp_shard( + parallel_dims.get_mesh("cp"), + (latents, latent_pos_enc, t5_encodings, text_pos_enc, target), + None, # No attention masks for Flux + disable_load_balancer=True, + ) + + with self.validation_context(): + with self.maybe_enable_amp: + latent_noise_pred = model( + img=latents, + img_ids=latent_pos_enc, + txt=t5_encodings, + txt_ids=text_pos_enc, + y=clip_encodings, + timesteps=timesteps, ) - with self.validation_context(optional_context_parallel_ctx): - with self.maybe_enable_amp: - latent_noise_pred = model( - img=latents, - img_ids=latent_pos_enc, - txt=t5_encodings, - txt_ids=text_pos_enc, - y=clip_encodings, - timesteps=timesteps, - ) - - loss = self.loss_fn(latent_noise_pred, target) + loss = self.loss_fn(latent_noise_pred, target) del noise, target, latent_noise_pred, latents diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index f504cbcb63..fde8c40ed3 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -26,6 +26,7 @@ from torchtitan.config.job_config import Compile as CompileConfig from torchtitan.distributed import ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.tools.logging import logger @@ -89,9 +90,19 @@ def parallelize_llama( tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, + cp_enabled=parallel_dims.cp_enabled, ) maybe_enable_async_tp(job_config, tp_mesh) + attn_type = getattr(model.model_args, "attn_type", "sdpa") + if parallel_dims.cp_enabled: + apply_cp_to_attention_module( + # pyrefly: ignore [missing-attribute, not-callable] + [block.attention.inner_attention for block in model.layers.values()], + parallel_dims.get_mesh("cp"), + attn_type, + ) + model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) @@ -131,9 +142,6 @@ def parallelize_llama( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: @@ -154,6 +162,7 @@ def apply_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, + cp_enabled: bool = False, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -208,7 +217,10 @@ def apply_tp( layer_plan = { "attention_norm": SequenceParallel(), # NOTE: when the fourth argument (positions) is not None, its input layout - # and desired input layout should be Replicate() + # and desired input layout is still None as we don't convert freqs_cis to + # a DTensor for llama3. + # TODO: https://github.com/pytorch/torchtitan/pull/2149 would fix this + # inconsistency. "attention": prepare_module_input( input_layouts=(Shard(1), None, None, None), desired_input_layouts=(Replicate(), None, None, None), diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index 79e97dab4c..38b8bc4df5 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -59,7 +59,9 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: and self.attn_type != "sdpa" ): raise NotImplementedError( - "CP support for FlexAttention is still in progress." + f"Context Parallel only supports SDPA attention. " + f"Got attn_type='{self.attn_type}'. " + f"FlexAttention and varlen attention are not supported with CP." ) def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 317bd1eae8..0cb02e12a5 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -499,9 +499,6 @@ def init_weights( def _precompute_freqs_cis(self) -> torch.Tensor: return precompute_freqs_cis( self.model_args.dim // self.model_args.n_heads, - # Need to compute until at least the max token limit for generation - # TODO: explain in docs/composability.md why we removed the 2x - # relaxing in our CP enablement PR self.model_args.max_seq_len, self.model_args.rope_theta, self.model_args.rope_scaling_args, @@ -551,9 +548,7 @@ def get_attention_masks( input_batch, tokenizer.eos_id ) case _: - raise NotImplementedError( - "Only varlen and flex attn masks are supported" - ) + raise TypeError("Only varlen and flex attn masks are supported") def forward( self, diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 085300f220..f3b5a0db20 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -26,6 +26,7 @@ from torchtitan.config.job_config import Compile as CompileConfig from torchtitan.distributed import NoParallel, ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module from torchtitan.distributed.dual_pipe_v import ( DualPipeExpertParallel, get_dual_pipe_v_flag, @@ -143,6 +144,15 @@ def parallelize_llama( use_deepep=use_deepep, ) + attn_type = getattr(model.model_args, "attn_type", "sdpa") + if parallel_dims.cp_enabled: + apply_cp_to_attention_module( + # pyrefly: ignore [missing-attribute, not-callable] + [block.attention.inner_attention for block in model.layers.values()], + parallel_dims.get_mesh("cp"), + attn_type, + ) + model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) @@ -198,9 +208,6 @@ def parallelize_llama( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: @@ -272,7 +279,8 @@ def apply_non_moe_tp( layer_plan = { "attention_norm": SequenceParallel(), # NOTE: when the fourth argument (positions) is not None, its input layout - # and desired input layout should be Replicate() + # and desired input layout is still None as we don't convert freqs_cis to + # a DTensor for llama4. "attention": prepare_module_input( input_layouts=(Shard(1), None, None, None), desired_input_layouts=(Replicate(), None, None, None), diff --git a/torchtitan/models/llama4/model/args.py b/torchtitan/models/llama4/model/args.py index 3520e7e519..a93030d82f 100644 --- a/torchtitan/models/llama4/model/args.py +++ b/torchtitan/models/llama4/model/args.py @@ -80,7 +80,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: job_config.parallelism.context_parallel_degree > 1 and self.attn_type != "sdpa" ): - raise NotImplementedError("CP support is only supported for SDPA.") + raise NotImplementedError( + f"Context Parallel only supports SDPA attention. " + f"Got attn_type='{self.attn_type}'. " + f"FlexAttention and varlen attention are not supported with CP." + ) self.moe_args._debug_force_load_balance = ( job_config.debug.moe_force_load_balance diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index 2ab55637d3..d96d66aa2f 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -536,9 +536,6 @@ def init_weights( def _precompute_freqs_cis(self) -> torch.Tensor: return precompute_freqs_cis( self.model_args.dim // self.model_args.n_heads, - # Need to compute until at least the max token limit for generation - # TODO: explain in docs/composability.md why we removed the 2x - # relaxing in our CP enablement PR self.model_args.max_seq_len, self.model_args.rope_theta, self.model_args.rope_scaling_args, diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 63e7a0ba7c..be963beecd 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -24,6 +24,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module from torchtitan.distributed.dual_pipe_v import get_dual_pipe_v_flag from torchtitan.models.llama3.infra.parallelize import apply_ddp from torchtitan.models.llama4.infra.parallelize import ( @@ -67,7 +68,11 @@ def parallelize_qwen3( attn_type = getattr(model.model_args, "attn_type", "sdpa") if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": - raise NotImplementedError("CP support is only supported for SDPA.") + raise NotImplementedError( + f"Context Parallel only supports SDPA attention. " + f"Got attn_type='{attn_type}'. " + f"FlexAttention and varlen attention are not supported with CP." + ) model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components @@ -97,6 +102,7 @@ def parallelize_qwen3( loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, + cp_enabled=parallel_dims.cp_enabled, ) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: @@ -111,6 +117,14 @@ def parallelize_qwen3( dual_pipe_v=dual_pipe_v, ) + if parallel_dims.cp_enabled: + apply_cp_to_attention_module( + # pyrefly: ignore [missing-attribute, not-callable] + [block.attention.inner_attention for block in model.layers.values()], + parallel_dims.get_mesh("cp"), + attn_type, + ) + if job_config.activation_checkpoint.mode != "none": apply_ac( model, @@ -158,9 +172,6 @@ def parallelize_qwen3( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: @@ -188,6 +199,7 @@ def apply_non_moe_tp( loss_parallel: bool, enable_float8_tensorwise_tp: bool, enable_async_tp: bool, + cp_enabled: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -237,15 +249,19 @@ def apply_non_moe_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + positions_sharding = Replicate() if cp_enabled else None # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), - # NOTE: when the fourth argument (positions) is not None, its input layout - # and desired input layout should be Replicate() "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate(), None, None), - desired_input_layouts=(Replicate(), Replicate(), None, None), + input_layouts=(Shard(1), Replicate(), None, positions_sharding), + desired_input_layouts=( + Replicate(), + Replicate(), + None, + positions_sharding, + ), ), "attention.wq": colwise_parallel(use_local_output=False), "attention.wk": colwise_parallel(use_local_output=False), diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 4aecf0f52a..1d769ade59 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -553,9 +553,7 @@ def get_attention_masks( input_batch, tokenizer.eos_id ) case _: - raise NotImplementedError( - "Only varlen and flex attn masks are supported" - ) + raise TypeError("Only varlen and flex attn masks are supported") def forward( self, diff --git a/torchtitan/train.py b/torchtitan/train.py index 3d77eeb425..10ce82ca08 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -27,6 +27,7 @@ ) from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.distributed.context_parallel import prepare_context_parallel_input from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger @@ -469,6 +470,15 @@ def post_dataloading_process( extra_inputs=extra_inputs, ) + if self.parallel_dims.cp_enabled: + inputs, labels, extra_kwargs = prepare_context_parallel_input( + inputs, + labels, + extra_kwargs, + self.parallel_dims.get_mesh("cp"), + self.device, + ) + return inputs, labels, extra_inputs, extra_kwargs def forward_backward_step( @@ -480,30 +490,10 @@ def forward_backward_step( inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( input_dict, labels ) - # apply context parallelism if cp is enabled - # ensure CP handles the separate freqs_cis buffer for each pp stage - cp_buffers: list[torch.Tensor] = [inputs, labels] - cp_seq_dims = [1, 1] - if hasattr(model_parts[0], "freqs_cis"): - for m in model_parts: - assert isinstance(m.freqs_cis, torch.Tensor) - cp_buffers.append(m.freqs_cis) - cp_seq_dims += [0 for _ in model_parts] - - optional_context_parallel_ctx = None - if parallel_dims.cp_enabled: - cp_mesh = parallel_dims.get_mesh("cp") - optional_context_parallel_ctx = dist_utils.create_context_parallel_ctx( - cp_mesh=cp_mesh, - cp_buffers=cp_buffers, - cp_seq_dims=cp_seq_dims, - cp_no_restore_buffers={inputs, labels}, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, - ) if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call - with self.train_context(optional_context_parallel_ctx): + with self.train_context(): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) ) @@ -536,8 +526,8 @@ def forward_backward_step( ) else: # Non-PP forward / backward - with self.train_context(optional_context_parallel_ctx): - assert len(model_parts) == 1 + assert len(model_parts) == 1 + with self.train_context(): with self.maybe_enable_amp: pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) loss = self.loss_fn(pred, labels) From 0a2107f984639e23a0e5b07fc278785345f03b73 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 20 Jan 2026 23:33:26 -0800 Subject: [PATCH 109/127] [CP] Enable FlexCP for llama3 (#2145) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom): * __->__ #2145 Summary: Continue the previous PR, this PR enable FlexAttention + CP for llama3. FlexCP will use PTRRLoadBalancer. Note that this PR requires https://github.com/pytorch/pytorch/pull/170201 --- torchtitan/components/validate.py | 1 + torchtitan/config/job_config.py | 15 +++ torchtitan/distributed/context_parallel.py | 98 ++++++++++++++----- torchtitan/experiments/forge/example_train.py | 1 + torchtitan/models/attention.py | 4 +- torchtitan/models/flux/train.py | 2 +- torchtitan/models/flux/validate.py | 2 +- torchtitan/models/llama3/model/args.py | 6 +- torchtitan/train.py | 1 + 9 files changed, 102 insertions(+), 28 deletions(-) diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index be10d4c6fe..7917ff2fc2 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -148,6 +148,7 @@ def post_dataloading_process( extra_kwargs, self.parallel_dims.get_mesh("cp"), inputs.device, + self.job_config.parallelism.context_parallel_load_balancer, ) return inputs, labels, extra_inputs, extra_kwargs diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 108c38efba..b3a24c7847 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -420,6 +420,21 @@ class Parallelism: context_parallel_degree: int = 1 """Context parallelism degree. 1 means disabled.""" + context_parallel_load_balancer: str | None = "headtail" + """ + Load balancer type for context parallelism. Options: + - "headtail": Use HeadTailLoadBalancer for SDPA + - "ptrr": Use PTRRLoadBalancer for FlexAttention + - None: Disable load balancing + """ + + def __post_init__(self): + if self.context_parallel_load_balancer == "": + raise ValueError( + "context_parallel_load_balancer cannot be an empty string. " + "Use None to disable load balancing." + ) + context_parallel_rotate_method: Literal["allgather", "alltoall"] = "allgather" """ The collective to use in context parallel SDPA for kv shards exchange. diff --git a/torchtitan/distributed/context_parallel.py b/torchtitan/distributed/context_parallel.py index b921831e25..c9eb897cda 100644 --- a/torchtitan/distributed/context_parallel.py +++ b/torchtitan/distributed/context_parallel.py @@ -15,8 +15,10 @@ _ContextParallel, _enable_context_parallel_dispatcher, _HeadTailLoadBalancer, + _PTRRLoadBalancer, ) from torch.distributed.tensor.parallel import parallelize_module +from torch.nn.attention.flex_attention import BlockMask from torchtitan.protocols.model import AttentionMasksType from torchtitan.tools.logging import logger @@ -90,6 +92,7 @@ def prepare_context_parallel_input( extra_kwargs: dict[str, Any], cp_mesh: DeviceMesh, device: torch.device, + load_balancer_type: str | None = "headtail", ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: """ Prepare inputs, labels, and attention masks for Context Parallel forward pass. @@ -105,6 +108,8 @@ def prepare_context_parallel_input( extra_kwargs: Dictionary that may contain 'attention_masks' to be sharded cp_mesh: Device mesh for context parallel dimension device: Device to create position tensor on + load_balancer_type: Type of load balancer to use for sharding. + Options: "headtail", "ptrr", or None. Defaults to "headtail". Returns: Tuple of (sharded_inputs, sharded_labels, updated_extra_kwargs) where: @@ -121,6 +126,7 @@ def prepare_context_parallel_input( cp_mesh, (inputs, labels, positions), attention_masks, + load_balancer_type, ) extra_kwargs["positions"] = positions if attention_masks is not None: @@ -133,29 +139,30 @@ def cp_shard( cp_mesh: DeviceMesh, inputs: tuple[torch.Tensor, ...], attention_masks: AttentionMasksType | None, - disable_load_balancer: bool = False, + load_balancer_type: str | None = "headtail", input_seq_dim: int = 1, ) -> tuple[tuple[torch.Tensor, ...], AttentionMasksType | None]: """ Shard inputs and attention masks across the context parallel mesh. This function distributes input tensors across devices in the CP mesh - along the sequence dimension. It optionally uses a load balancer to - handle uneven computation workload. Currently, HeadTailLoadBalancer is - used for SDPA + CP, which is the only supported configuration. + along the sequence dimension, enabling efficient processing. It optionally + uses a load balancer to handle uneven computation workload. Args: cp_mesh: Device mesh for context parallel dimension inputs: Tuple of input tensors to be sharded along the sequence dimension - attention_masks: Attention masks to be sharded (currently raises - error as FlexAttention CP is not yet supported) - disable_load_balancer: If True, disables load balancing. If False - (default), uses HeadTailLoadBalancer for SDPA to handle uneven - computation workload. + attention_masks: Attention masks to be sharded. Supports None, + BlockMask, or dict[str, BlockMask] + load_balancer_type: Type of load balancer to use. Options: + - "headtail": Use HeadTailLoadBalancer (for SDPA) + - "ptrr": Use PTRRLoadBalancer (for FlexAttention) + - None: Disable load balancing + Defaults to "headtail". input_seq_dim: Sequence dimension index for sharding. Defaults to 1, which covers most use cases where tensors have shape - [batch_size, seq_len, ...]. Can be changed by passing a + [batch_size, seq_len]. Can be changed by passing a different value if your tensors use a different sequence dimension layout. @@ -163,21 +170,45 @@ def cp_shard( Tuple of (sharded_inputs, attention_masks) where: - sharded_inputs: Tuple of input tensors sharded along the sequence dimension - - attention_masks: Attention masks (currently unchanged/None) + - attention_masks: Sharded attention masks (BlockMask or + dict[str, BlockMask]) or None + + Raises: + ValueError: If load_balancer_type is "ptrr" and attention_masks + is None or a dict """ seq_len = inputs[0].size(input_seq_dim) cp_world_size = cp_mesh.size(0) - if attention_masks is not None: - raise ValueError( - "FlexAttention CP is not supported yet. Will come in the next PR." - ) - else: - # For SDPA, we use the _HeadTailLoadBalancer. - load_balancer = ( - None - if disable_load_balancer - else _HeadTailLoadBalancer(seq_len, cp_world_size, cp_mesh.device_type) - ) + + load_balancer = None + if load_balancer_type: + match load_balancer_type: + case "headtail": + # For SDPA, we use the _HeadTailLoadBalancer. + load_balancer = _HeadTailLoadBalancer( + seq_len, cp_world_size, cp_mesh.device_type + ) + case "ptrr": + # For FlexAttention, we use _PTRRLoadBalancer. + # _PTRRLoadBalancer requires attention_masks to be a BlockMask. + # For dict[str, BlockMask], _PTRRLoadBalancer currently doesn't + # support the case where there are multiple masks. + if attention_masks is None or isinstance(attention_masks, dict): + raise ValueError( + "PTRRLoadBalancer requires attention_masks to be a " + "BlockMask, but got None or dict[str, BlockMask]" + ) + if not isinstance(attention_masks, BlockMask): + raise ValueError( + f"PTRRLoadBalancer requires attention_masks to be a " + f"BlockMask, but got {type(attention_masks)}" + ) + load_balancer = _PTRRLoadBalancer(attention_masks, cp_world_size) + case _: + raise ValueError( + f"Invalid load_balancer_type '{load_balancer_type}'. " + f"Must be one of: 'headtail', 'ptrr', or None" + ) inputs = cast( tuple[torch.Tensor, ...], @@ -189,4 +220,27 @@ def cp_shard( ), ) + # BlockMask, has shape, [B, H, Q, KV], and we can only shard + # on the Q seq dimension, not KV. + MASK_Q_SEQ_DIM = 2 + if attention_masks is not None: + assert isinstance(attention_masks, (BlockMask, dict[str, BlockMask])) + masks = ( + [attention_masks] + if isinstance(attention_masks, BlockMask) + else list(attention_masks.values()) + ) + masks = _context_parallel_shard( + mesh=cp_mesh, + buffers=masks, + seq_dims=(MASK_Q_SEQ_DIM,) * len(masks), + load_balancer=load_balancer, + ) + attention_masks = cast( + (BlockMask | dict[str, BlockMask]), + masks[0] + if isinstance(attention_masks, BlockMask) + else {k: v for k, v in zip(attention_masks.keys(), masks)}, + ) + return inputs, attention_masks diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index be670865b8..b00ec58a2a 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -180,6 +180,7 @@ def post_dataloading_process( extra_kwargs, self.parallel_dims.get_mesh("cp"), self.device, + self.job_config.parallelism.context_parallel_load_balancer, ) return inputs, labels, extra_inputs, extra_kwargs diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index e1255f1e94..78f317526e 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -14,6 +14,7 @@ from torch.nn.attention import sdpa_kernel, SDPBackend from torch.nn.attention.flex_attention import ( _mask_mod_signature, + _score_mod_signature, BlockMask, create_block_mask, flex_attention, @@ -118,7 +119,8 @@ def forward( k: torch.Tensor, v: torch.Tensor, *, - block_mask: BlockMask, + score_mod: _score_mod_signature | None = None, + block_mask: BlockMask | None = None, scale: float | None = None, return_lse: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py index 3348cfe385..7d85d2b3a1 100644 --- a/torchtitan/models/flux/train.py +++ b/torchtitan/models/flux/train.py @@ -150,7 +150,7 @@ def forward_backward_step( self.parallel_dims.get_mesh("cp"), (latents, latent_pos_enc, t5_encodings, text_pos_enc, target), None, # No attention masks for Flux - disable_load_balancer=True, + load_balancer_type=None, ) with self.train_context(): diff --git a/torchtitan/models/flux/validate.py b/torchtitan/models/flux/validate.py index 077134f81d..70dfff4bb3 100644 --- a/torchtitan/models/flux/validate.py +++ b/torchtitan/models/flux/validate.py @@ -235,7 +235,7 @@ def validate( parallel_dims.get_mesh("cp"), (latents, latent_pos_enc, t5_encodings, text_pos_enc, target), None, # No attention masks for Flux - disable_load_balancer=True, + load_balancer_type=None, ) with self.validation_context(): diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index 38b8bc4df5..1ff134a163 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -56,12 +56,12 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: if ( job_config.parallelism.context_parallel_degree > 1 - and self.attn_type != "sdpa" + and self.attn_type == "varlen" ): raise NotImplementedError( - f"Context Parallel only supports SDPA attention. " + f"Context Parallel only supports SDPA and FlexAttention." f"Got attn_type='{self.attn_type}'. " - f"FlexAttention and varlen attention are not supported with CP." + f"Varlen attention is not supported with CP." ) def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: diff --git a/torchtitan/train.py b/torchtitan/train.py index 10ce82ca08..fc7514cad7 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -477,6 +477,7 @@ def post_dataloading_process( extra_kwargs, self.parallel_dims.get_mesh("cp"), self.device, + self.job_config.parallelism.context_parallel_load_balancer, ) return inputs, labels, extra_inputs, extra_kwargs From 8ff9e42a3448d5a5c013ffe86a8fe1b90d7aea68 Mon Sep 17 00:00:00 2001 From: Shuhua Yu <18108279+shuhuayu@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:08:40 -0800 Subject: [PATCH 110/127] [MoE] Fix experts DTensor metadata bug for dcp (#2227) Previously, individual experts are marked as `Replicate` in EP dimension in global `global_device_mesh`. Local experts are first created on `global_device_mesh` and are turned into a 2d tensor using `squeeze(0)`, which only removes the extra dimension, but the remaining metadata `Replicate` is still there. The wrong metadata results in bug when DCP saves `DTensor`. This PR fixes this bug by: 1. Use a sub-mesh that excludes expert dimension, i.e., dim 0. 2. When sub-mesh is empty, use plain tensor instead of `DTensor`. --- .../deepseek_v3/model/state_dict_adapter.py | 6 +- .../models/qwen3/model/state_dict_adapter.py | 6 +- torchtitan/models/utils.py | 86 ++++++++++++------- 3 files changed, 61 insertions(+), 37 deletions(-) diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 7fd6743600..1970c7a161 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -116,6 +116,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key ] = value.placements self.grouped_expert_weight_shape[abstract_key] = value.shape + self.grouped_expert_weight_mesh[abstract_key] = value.device_mesh # Split GroupedExperts weight to local individual expert weights local_expert_fqn = self._get_local_experts_weights( @@ -179,12 +180,13 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: int(expert_num) ] = value - if isinstance(value, DTensor): + # Use stored metadata to decide path (online vs offline) + # Online mode: local_experts_indices was populated during to_hf() + if titan_abstract_key in self.local_experts_indices: stacked_value = self._concatenate_expert_weights_dtensor( expert_weights_by_layer, titan_abstract_key, layer_num, - value.device_mesh, ) else: # keep this path to be compatible with offline conversion stacked_value = self._concatenate_expert_weights( diff --git a/torchtitan/models/qwen3/model/state_dict_adapter.py b/torchtitan/models/qwen3/model/state_dict_adapter.py index 8dfe4d5aa7..1fcd51081c 100644 --- a/torchtitan/models/qwen3/model/state_dict_adapter.py +++ b/torchtitan/models/qwen3/model/state_dict_adapter.py @@ -73,6 +73,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key ] = value.placements self.grouped_expert_weight_shape[abstract_key] = value.shape + self.grouped_expert_weight_mesh[abstract_key] = value.device_mesh # Split GroupedExperts weight to local individual expert weights local_expert_fqn = self._get_local_experts_weights( @@ -151,12 +152,13 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: int(expert_num) ] = value - if isinstance(value, DTensor): + # Use stored metadata to decide path (online vs offline) + # Online mode: local_experts_indices was populated during to_hf() + if titan_abstract_key in self.local_experts_indices: stacked_value = self._concatenate_expert_weights_dtensor( expert_weights_by_layer, titan_abstract_key, layer_num, - value.device_mesh, ) else: # keep this path to be compatible with offline conversion stacked_value = self._concatenate_expert_weights( diff --git a/torchtitan/models/utils.py b/torchtitan/models/utils.py index 5bf73fbb7e..1e5befbf70 100644 --- a/torchtitan/models/utils.py +++ b/torchtitan/models/utils.py @@ -37,6 +37,7 @@ def __init__( # Store metadata for GroupedExperts <-> individual experts conversion self.grouped_expert_weight_placements = {} # {titan_abstract_key: placements} self.grouped_expert_weight_shape = {} # {titan_abstract_key: shape} + self.grouped_expert_weight_mesh = {} # {titan_abstract_key: device_mesh} self.local_experts_indices = {} # {titan_abstract_key: (start_idx, end_idx)} def _calculate_strided_shard_shard_indices( @@ -96,7 +97,7 @@ def _caculate_indices_from_placements( dim_size: int, dtensor_placements: tuple, device_mesh: DeviceMesh, - ) -> tuple[int | None, int | None]: + ) -> tuple[int, int]: mesh_names = [] dim_i_placements = [] @@ -110,7 +111,7 @@ def _caculate_indices_from_placements( dim_i_placements.append(placement) # Calculate local expert indices based on sharding strategy - start_index, end_index = None, None + start_index, end_index = 0, dim_size if len(dim_i_placements) == 2: # Handle StridedShard(i) + Shard(i) case assert isinstance( @@ -149,8 +150,8 @@ def _caculate_indices_from_placements( end_index = start_index + block_size elif len(dim_i_placements) == 0: - # No need to split on this dimension - return start_index, end_index + # No sharding on this dimension means all elements are local + pass else: raise NotImplementedError( @@ -180,7 +181,7 @@ def _get_local_experts_weights( grouped_expert_weight: DTensor containing all experts' weights Returns: - Dictionary mapping individual expert keys to their DTensor weights + Dictionary mapping individual expert keys to their DTensor or plain tensor weights """ # pyrefly: ignore [missing-attribute] device_mesh = grouped_expert_weight.device_mesh @@ -195,33 +196,44 @@ def _get_local_experts_weights( dtensor_placements=dtensor_placements, device_mesh=device_mesh, ) - assert ( - start_index is not None and end_index is not None - ), "Start index and end index can not be None on dim-0!" # Step 2: Store indices for potential future use in from_hf() self.local_experts_indices[titan_abstract_key] = (start_index, end_index) - # Step 3: Create new placements for individual expert weights - new_placements = [] + # Step 3: Identify mesh dimensions that shard on dim-0 (expert dimension) + # exclude expert dimension + # and build new sub-mesh/placements for individual expert weights + sub_mesh_names = [] + sub_placements = [] + for i, name in enumerate(device_mesh.mesh_dim_names): placement = dtensor_placements[i] - if placement.dim == 0: - # Convert dim-0 sharding to replication for individual experts - new_placements.append(Replicate()) + if isinstance(placement, Replicate): + # Replicate (hybrid) doesn't shard any dim, keep in sub-mesh + sub_mesh_names.append(name) + sub_placements.append(Replicate()) + elif isinstance(placement, (Shard, _StridedShard)) and placement.dim == 0: + # Shards on expert dim, exclude from sub-mesh + pass elif isinstance(placement, Shard): - # Keep other shard dimensions (individual expert weight has 2D) - new_placements.append(Shard(placement.dim)) + # Shards on non-expert dim, keep in sub-mesh + sub_mesh_names.append(name) + sub_placements.append(Shard(placement.dim)) elif isinstance(placement, _StridedShard): - # Keep strided shard with same parameters - new_placements.append( + # Strided shard on non-expert dim, keep in sub-mesh + sub_mesh_names.append(name) + sub_placements.append( # pyrefly: ignore [unexpected-positional-argument] _StridedShard(placement.dim, placement.split_factor) ) else: raise ValueError(f"Unsupported placement type: {type(placement)}") - # Step 4: Create individual expert DTensors + # Step 4: Create sub-mesh excluding dim-0 sharding dimensions + # If all mesh dimensions were sharding on dim-0, sub_mesh will be None (use plain tensors) + sub_mesh = device_mesh[tuple(sub_mesh_names)] if sub_mesh_names else None + + # Step 5: Create individual expert tensors assert isinstance( grouped_expert_weight, DTensor ), "Expected DTensor for grouped expert weight" @@ -240,15 +252,21 @@ def _get_local_experts_weights( expert_key = abstract_key.format(layer_id, expert_id) local_expert_index = expert_id - start_index - # Extract individual expert weight and add batch dimension temporarily - expert_weight = local_grouped_weights[local_expert_index, :, :].unsqueeze(0) - - # Create DTensor and remove batch dimension (experts dimension is removed) - expert_dtensor = DTensor.from_local( - expert_weight, device_mesh, new_placements, run_check=False - ).squeeze(0) - - local_expert_tensors[expert_key] = expert_dtensor + if sub_mesh is None: + # Extract individual expert weight (2D) as plain tensor + expert_weight = local_grouped_weights[local_expert_index, :, :] + else: + # Use slicing and unsqueeze get a 3D tensor, then create DTensor and squeeze + expert_weight_3d = local_grouped_weights[ + local_expert_index, :, : + ].unsqueeze(0) + expert_weight = DTensor.from_local( + expert_weight_3d, + sub_mesh, + sub_placements, + run_check=False, + ).squeeze(0) + local_expert_tensors[expert_key] = expert_weight return local_expert_tensors @@ -257,7 +275,6 @@ def _concatenate_expert_weights_dtensor( expert_weights_by_layer: dict[str, dict[str, dict[int, torch.Tensor]]], abstract_key: str, layer_num: str, - device_mesh: DeviceMesh, ) -> torch.Tensor | None: """ Args: @@ -272,7 +289,6 @@ def _concatenate_expert_weights_dtensor( Used to collect individual expert weights before concatenating them into GroupedExperts. abstract_key: TorchTitan templage key with {} placeholders for layer and expert IDs layer_num: Layer identifier - device_mesh: DeviceMesh for the target GroupedExperts weight DTensor Returns: Concatenated GroupedExperts weight DTensor if all experts are available, otherwise None @@ -288,17 +304,21 @@ def _concatenate_expert_weights_dtensor( sorted_expert_ids = sorted(experts.keys()) sorted_experts = [experts[i] for i in sorted_expert_ids] - # pyrefly: ignore [missing-attribute] - local_tensor = torch.stack(sorted_experts, dim=0)._local_tensor + + # Stack experts - result may be DTensor or plain tensor depending on sub_mesh + local_tensor = torch.stack(sorted_experts, dim=0) + if isinstance(local_tensor, DTensor): + local_tensor = local_tensor._local_tensor assert ( abstract_key in self.grouped_expert_weight_placements and abstract_key in self.grouped_expert_weight_shape - ), "GroupedExperts weight metadata (placements, shape) can not be None!" + and abstract_key in self.grouped_expert_weight_mesh + ), "GroupedExperts weight metadata (placements, shape, mesh) can not be None!" stacked_dtensor = DTensor.from_local( local_tensor, - device_mesh, + self.grouped_expert_weight_mesh[abstract_key], self.grouped_expert_weight_placements[abstract_key], run_check=False, ) From 3263b153bd70ca979cdb7f4ba3991d72c55f19a5 Mon Sep 17 00:00:00 2001 From: dmahan93 <44207705+dmahan93@users.noreply.github.com> Date: Fri, 23 Jan 2026 15:50:23 -0600 Subject: [PATCH 111/127] Update GRPO.md Added installation of transformers package and updated sbatch script instructions. --- GRPO.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/GRPO.md b/GRPO.md index 93653052fa..9c3e5a9c1b 100644 --- a/GRPO.md +++ b/GRPO.md @@ -4,6 +4,8 @@ GRPO instructions ## Installation instructions ```shell +mkdir logs +chmod g+rw ./logs pip install uv uv pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu129 uv pip install -r requirements.txt @@ -12,8 +14,23 @@ export VLLM_COMMIT=2918c1b49c88c29783c86f78d2c4221cb9622379 uv pip install vllm torch==2.9.0 --torch-backend=cu129 --prerelease=allow --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT} --extra-index-url https://download.pytorch.org/whl/cu129 pip install flashinfer-python==0.4.1 flashinfer-cubin==0.4.1 pip install flashinfer-jit-cache==0.4.1 --index-url https://flashinfer.ai/whl/cu129 +pip install transformers==4.57.1 ``` ## Configuration instructions see `torchtitan/grop/configs/qwen25-7b-math.toml` for good initial values + +## sbatch script + +`online_multinode_vllm.slurm` contains some paths to edit, +- TRAIN_PATH - where this is installed on the cluster +- TRAIN_ENV - if you don't init the venv to .venv, this needs to be changed to that venv +- VLLM_ENV - same as TRAIN_ENV unless you're doing something different +- API_ENV - atropos venv + +One that's done, you can do something like +```bash +sbatch --export=ALL,CONFIG_FILE=/home/dakota/github/torchtitan/torchtitan/grpo/configs/qwen25-7b-math.toml,MODEL_NAME=Qwen/Qwen2.5-7B,PYTHON_SCRIPT=/home/dakota/github/atropos/environments/math_server_zero.py,WANDB_PROJECT=qwen7b_debug online_multinode_vllm.slurm +``` +to launch a run From 5621112d056937ff233d6ae7d104db3a4f33abc7 Mon Sep 17 00:00:00 2001 From: liangel-02 Date: Fri, 23 Jan 2026 18:23:12 -0500 Subject: [PATCH 112/127] [varlen_attn] change is_causal to window_size (#2267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `is_causal` flag has been deprecated in `varlen_attn`, use `window_size = [-1, 0]` instead, see [this PR](https://github.com/pytorch/pytorch/pull/172245) *Test* Screenshot 2026-01-21 at 11 55
59 AM --- torchtitan/models/attention.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 78f317526e..83eef080fd 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -84,8 +84,18 @@ def forward( cu_seq_k, max_q, max_k, - is_causal=True, scale=scale, + # window_size=(left, right) controls the attention window relative to each + # query position. 'left' is how many tokens before the query to attend to, + # and 'right' is how many tokens after. A value of -1 means unlimited. + # + # This replaces the is_causal flag: + # - (-1, 0): Causal attention - each token attends to all previous tokens + # and itself, but no future tokens. Equivalent to is_causal=True. + # - (-1, -1): Full bidirectional attention (no masking). Equivalent to + # is_causal=False. + # - (W, 0): Sliding window causal - attend to at most W previous tokens. + window_size=(-1, 0), ) From 81f5a5a9f9a59f0bbf47d015166a8d24e98e1c46 Mon Sep 17 00:00:00 2001 From: akashveramd Date: Sat, 24 Jan 2026 14:35:44 -0800 Subject: [PATCH 113/127] Add ROCm CI support for simple fsdp experiments test (#2220) In this PR we added ROCm CI support for simple fsdp experiments test. --- .../integration_test_8gpu_simple_fsdp.yaml | 38 +++++++++++++------ 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/.github/workflows/integration_test_8gpu_simple_fsdp.yaml b/.github/workflows/integration_test_8gpu_simple_fsdp.yaml index 9a1a0a2866..302cd46555 100644 --- a/.github/workflows/integration_test_8gpu_simple_fsdp.yaml +++ b/.github/workflows/integration_test_8gpu_simple_fsdp.yaml @@ -3,6 +3,8 @@ name: SimpleFSDP 8 GPU Integration Tests on: push: branches: [ main ] + tags: + - ciflow/8gpu/* paths: - 'torchtitan/experiments/simple_fsdp/**' - '.github/workflows/integration_test_8gpu_simple_fsdp.yaml' @@ -22,18 +24,30 @@ defaults: run: shell: bash -l -eo pipefail {0} +permissions: + id-token: write + contents: read + jobs: + # Step 1: Dynamically compute the matrix based on conditions + set-matrix: + uses: ./.github/workflows/set-matrix.yaml + + # Step 2: Use the dynamic matrix in the build-test job build-test: - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + needs: set-matrix + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + strategy: + fail-fast: false + matrix: ${{ fromJSON(needs.set-matrix.outputs.matrix) }} with: - runner: linux.g5.48xlarge.nvidia.gpu - gpu-arch-type: cuda - gpu-arch-version: "12.6" - # This image is faster to clone than the default, but it lacks CC needed by triton - # (1m25s vs 2m37s). - docker-image: torchtitan-ubuntu-20.04-clang12 + runner: ${{ matrix.runner }} + gpu-arch-type: ${{ matrix.gpu-arch-type }} + gpu-arch-version: ${{ matrix.gpu-arch-version }} + docker-image: ${{ matrix.docker-image }} repository: pytorch/torchtitan upload-artifact: outputs + timeout: 45 script: | set -eux @@ -47,11 +61,13 @@ jobs: pip config --user set global.progress_bar off - python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + python -m pip install --force-reinstall --pre torch --index-url ${{ matrix.index-url }} + + sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded" + sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded" - mkdir artifacts-to-be-uploaded - python -m torchtitan.experiments.simple_fsdp.tests.integration_tests artifacts-to-be-uploaded --ngpu 8 + python -m torchtitan.experiments.simple_fsdp.tests.integration_tests $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 8 # Run the numerics unit tests of SimpleFSDP torchrun --nproc-per-node=8 -m pytest torchtitan/experiments/simple_fsdp/tests/test_numerics.py -v - rm -rf artifacts-to-be-uploaded/*/checkpoint + rm -rf $RUNNER_TEMP/artifacts-to-be-uploaded/*/checkpoint From 865ebb87a3ae543d95ad0174187e7dd2b19a192d Mon Sep 17 00:00:00 2001 From: emozilla Date: Wed, 28 Jan 2026 05:59:02 +0000 Subject: [PATCH 114/127] context parallel support in dsv3 and qwen3 --- torchtitan/models/deepseek_v3/infra/parallelize.py | 6 +++--- torchtitan/models/qwen3/infra/parallelize.py | 6 +++--- torchtitan/train.py | 9 ++------- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 45561314d5..da81d60c0c 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -65,11 +65,11 @@ def parallelize_deepseekv3( """ attn_type = getattr(model.model_args, "attn_type", "sdpa") - if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + if job_config.parallelism.context_parallel_degree > 1 and attn_type == "varlen": raise NotImplementedError( - f"Context Parallel only supports SDPA attention. " + f"Context Parallel only supports SDPA and FlexAttention." f"Got attn_type='{attn_type}'. " - f"FlexAttention and varlen attention are not supported with CP." + f"Varlen attention is not supported with CP." ) if parallel_dims.tp_enabled: diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 24a942b83e..fd621e9883 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -65,11 +65,11 @@ def parallelize_qwen3( """ attn_type = getattr(model.model_args, "attn_type", "sdpa") - if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + if job_config.parallelism.context_parallel_degree > 1 and attn_type == "varlen": raise NotImplementedError( - f"Context Parallel only supports SDPA attention. " + f"Context Parallel only supports SDPA and FlexAttention." f"Got attn_type='{attn_type}'. " - f"FlexAttention and varlen attention are not supported with CP." + f"Varlen attention is not supported with CP." ) model_compile_enabled = ( diff --git a/torchtitan/train.py b/torchtitan/train.py index 1f2e3068f9..15e81cd018 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -511,14 +511,9 @@ def post_dataloading_process( inputs = input_dict["input"] extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} - # Map position_ids from dataloader to positions expected by model forward - if "position_ids" in extra_inputs: - extra_inputs["positions"] = extra_inputs.pop("position_ids") - - # For arguments, like attention_masks, we have to put them in a separate - # dict as extra_inputs are not forwarded to other stages in PP, but - # extra_kwargs are. extra_kwargs: dict[str, Any] = {} + if "position_ids" in extra_inputs: + extra_kwargs["positions"] = extra_inputs.pop("position_ids") attn_type = getattr(self.model_args, "attn_type", "sdpa") if attn_type in ["flex", "varlen"]: From 2ad47cb1409fbb567a13b53cc8ede726435239ce Mon Sep 17 00:00:00 2001 From: emozilla Date: Wed, 21 Jan 2026 17:54:47 +0000 Subject: [PATCH 115/127] fast path for initing bfloat16 params on cpu --- torchtitan/models/deepseek_v3/model/model.py | 10 +-- torchtitan/models/moe/__init__.py | 2 +- torchtitan/models/moe/moe.py | 65 ++++++++++++++++---- 3 files changed, 58 insertions(+), 19 deletions(-) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 3d639d3911..d99b4b3247 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -21,7 +21,7 @@ get_block_causal_mask_mod_by_seq_lens, ScaledDotProductAttentionWrapper, ) -from torchtitan.models.moe import build_moe, FeedForward +from torchtitan.models.moe import build_moe, fast_init_trunc_normal_, fast_init_normal_, FeedForward from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol @@ -379,8 +379,8 @@ def init_weights(self, init_std: float): linear_list.append(self.wq) for linear in linear_list: - nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + fast_init_trunc_normal_(linear.weight, mean=0.0, std=0.02) + fast_init_trunc_normal_(self.wo.weight, mean=0.0, std=init_std) self.kv_norm.reset_parameters() if self.q_lora_rank > 0: @@ -510,7 +510,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: with torch.device(buffer_device): self.freqs_cis = precompute_freqs_cis(self.model_args) if self.tok_embeddings is not None: - nn.init.normal_(self.tok_embeddings.weight) + fast_init_normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: # pyrefly: ignore [not-callable] @@ -520,7 +520,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: final_out_std = self.model_args.dim**-0.5 cutoff_factor = 3 if self.output is not None: - nn.init.trunc_normal_( + fast_init_trunc_normal_( self.output.weight, mean=0.0, std=final_out_std, diff --git a/torchtitan/models/moe/__init__.py b/torchtitan/models/moe/__init__.py index 4acfbc21a8..dff7afd537 100644 --- a/torchtitan/models/moe/__init__.py +++ b/torchtitan/models/moe/__init__.py @@ -6,4 +6,4 @@ from .moe import build_moe, ExpertRoutingHistogram, FeedForward, MoE, MoEArgs -__all__ = ["FeedForward", "MoE", "MoEArgs", "build_moe", "ExpertRoutingHistogram"] +__all__ = ["FeedForward", "MoE", "MoEArgs", "build_moe", "ExpertRoutingHistogram", "fast_init_trunc_normal_", "fast_init_normal_"] diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 75af36f44f..8f0be41be1 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -27,6 +27,45 @@ def moe_init_std(dim_in: int, n_layers: int) -> float: return (2 / (dim_in * n_layers)) ** 0.5 +def fast_init_trunc_normal_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> None: + """ + Fast truncated normal initialization that handles bfloat16 tensors on CPU. + + When tensors are bfloat16 on CPU, nn.init.trunc_normal_ is extremely slow + because CPUs don't have native bfloat16 support. This function temporarily + converts to float32 for the initialization, then converts back. + """ + if tensor.device.type == "cpu" and tensor.dtype == torch.bfloat16: + with torch.no_grad(): + # Initialize in float32 for CPU performance + temp = torch.empty_like(tensor, dtype=torch.float32) + nn.init.trunc_normal_(temp, mean=mean, std=std, a=a, b=b) + tensor.copy_(temp.to(torch.bfloat16)) + else: + nn.init.trunc_normal_(tensor, mean=mean, std=std, a=a, b=b) + + +def fast_init_normal_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0 +) -> None: + """ + Fast normal initialization that handles bfloat16 tensors on CPU. + """ + if tensor.device.type == "cpu" and tensor.dtype == torch.bfloat16: + with torch.no_grad(): + temp = torch.empty_like(tensor, dtype=torch.float32) + nn.init.normal_(temp, mean=mean, std=std) + tensor.copy_(temp.to(torch.bfloat16)) + else: + nn.init.normal_(tensor, mean=mean, std=std) + + @dataclass class MoEArgs: num_experts: int = 8 @@ -165,9 +204,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) def init_weights(self, init_std: float = 0.02): - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + fast_init_trunc_normal_(self.w1.weight, mean=0.0, std=0.02) for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + fast_init_trunc_normal_(linear.weight, mean=0.0, std=init_std) # NOTE: keeping this for-loop implementation for comparison @@ -381,9 +420,9 @@ def forward( def init_weights(self, init_std: float, n_layers: int): std_in = moe_init_std(self.w1.shape[-1], n_layers) std_out = moe_init_std(self.w2.shape[0], n_layers) - nn.init.trunc_normal_(self.w1, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3, mean=0.0, std=std_out) + fast_init_trunc_normal_(self.w1, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3, mean=0.0, std=std_out) def _groupmm(x, w, offs): @@ -527,12 +566,12 @@ def forward( def init_weights(self, init_std: float, n_layers: int): std_in = moe_init_std(self.w1.shape[-1], n_layers) std_out = moe_init_std(self.w2.shape[0], n_layers) - nn.init.trunc_normal_(self.w1, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3, mean=0.0, std=std_out) - nn.init.trunc_normal_(self.w1_lora_a, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2_lora_a, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w1, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3, mean=0.0, std=std_out) + fast_init_trunc_normal_(self.w1_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3_lora_a, mean=0.0, std=std_in) nn.init.zeros_(self.w1_lora_b) nn.init.zeros_(self.w2_lora_b) nn.init.zeros_(self.w3_lora_b) @@ -711,7 +750,7 @@ def init_weights(self, init_std: float, n_layers: int): # DTensor, direct .data assignment (e.g., self.gate.weight.data = x) is # silently ignored, leaving weights uninitialized. This causes NaN loss # when CPU offload is enabled with 3+ GPUs. - nn.init.normal_(self.gate.weight, mean=0.0, std=1.0) + fast_init_normal_(self.gate.weight, mean=0.0, std=1.0) # Normalize rows in-place with torch.no_grad(): @@ -990,7 +1029,7 @@ def init_weights(self, init_std: float, buffer_device: torch.device, n_layers: i if self.shared_experts is not None: self.shared_experts.init_weights(init_std) if self.shared_gate is not None: - nn.init.trunc_normal_( + fast_init_trunc_normal_( self.shared_gate.weight, mean=0.0, std=moe_init_std(self.shared_gate.weight.shape[1], n_layers), From 81e54a413943adbaaf7d75ad137a56cb9ba93957 Mon Sep 17 00:00:00 2001 From: emozilla Date: Thu, 22 Jan 2026 20:45:54 +0000 Subject: [PATCH 116/127] add reference for init scheme --- torchtitan/models/moe/moe.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 8f0be41be1..2fc040cbe2 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -23,6 +23,7 @@ class ExpertRoutingHistogram: counts: list[float] +# see https://arxiv.org/pdf/2310.10837 def moe_init_std(dim_in: int, n_layers: int) -> float: return (2 / (dim_in * n_layers)) ** 0.5 @@ -746,6 +747,10 @@ def forward( return top_scores, selected_experts_indices, num_tokens_per_expert def init_weights(self, init_std: float, n_layers: int): + # Init gate with each row normalized + # From "Approximating Two-Layer Feedforward Networks for Efficient Transformers" + # https://arxiv.org/pdf/2310.10837 + # NOTE: Must use in-place operations here. When FSDP wraps parameters as # DTensor, direct .data assignment (e.g., self.gate.weight.data = x) is # silently ignored, leaving weights uninitialized. This causes NaN loss From f04236dbac8a67770370d43911ba06b4928ea100 Mon Sep 17 00:00:00 2001 From: emozilla Date: Fri, 23 Jan 2026 07:24:53 +0000 Subject: [PATCH 117/127] overlapped cpu offload muon --- torchtitan/experiments/dion_optimizer/muon.py | 575 +++++++++++++----- 1 file changed, 434 insertions(+), 141 deletions(-) diff --git a/torchtitan/experiments/dion_optimizer/muon.py b/torchtitan/experiments/dion_optimizer/muon.py index 432ab1399f..0ac1602f71 100644 --- a/torchtitan/experiments/dion_optimizer/muon.py +++ b/torchtitan/experiments/dion_optimizer/muon.py @@ -609,27 +609,17 @@ def muon_update_batch_dim_sharded_async( - This is mathematically equivalent to orthogonalizing each expert's weights independently This function processes all params locally without all-to-all or all-gather. + + Optimized for CPU offloading with: + - Double-buffered CUDA streams to overlap transfer and compute + - Batched Newton-Schulz for fewer kernel launches + - Single sync point at end (no intermediate cuda.synchronize()) """ - U = muon_update_pre_orthogonalize( - G=G, - M=M, - momentum=momentum, - nesterov=nesterov, - ) - - # Orthogonalize each tensor locally - # Newton-Schulz treats dim 0 as batch, processing each slice independently - U = [ - muon_update_newton_schulz( - u, - newton_schulz_func=newton_schulz_func, - flatten=flatten, - epsilon=epsilon, - ) - for u in U - ] + # Check if we need CPU offloading (tensors are on CPU) + original_device = G[0].device + needs_gpu_transfer = original_device.type != "cuda" - # Compute scaled learning rate + # Compute scaled learning rate upfront # Use the first tensor's shape (they should all be the same shape within a batch) if adjust_lr is None: adjusted_lr = lr @@ -640,16 +630,132 @@ def muon_update_batch_dim_sharded_async( else: raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") - # Update model parameters with orthogonalized output - muon_update_post_orthogonalize( - X=X, - U=U, - base_lr=lr, - adjusted_lr=adjusted_lr, - weight_decay=weight_decay, - ) + if needs_gpu_transfer: + # PIPELINED MODE: Double-buffered streams for maximum overlap + # Timeline: transfer[i+1] overlaps with compute[i] overlaps with writeback[i-1] + cuda_device = torch.device("cuda") + dtype = M[0].dtype + n_tensors = len(X) + + # Mini-batch size for batched Newton-Schulz (fewer kernel launches) + BATCH_SIZE = 4 + + # Create streams: one for H2D transfers, one for compute, one for D2H transfers + h2d_stream = torch.cuda.Stream() + compute_stream = torch.cuda.Stream() + d2h_stream = torch.cuda.Stream() + + # Double buffer: prefetch next batch while computing current + prefetch_data = None # Will hold (g_batch, m_batch, x_batch, indices) for next iteration + + def prefetch_batch(start_idx): + """Prefetch a batch of tensors to GPU (non-blocking).""" + end_idx = min(start_idx + BATCH_SIZE, n_tensors) + indices = list(range(start_idx, end_idx)) + with torch.cuda.stream(h2d_stream): + g_batch = [G[i].to(dtype=dtype).to(cuda_device, non_blocking=True) for i in indices] + m_batch = [M[i].to(cuda_device, non_blocking=True) for i in indices] + x_batch = [X[i].to(cuda_device, non_blocking=True) for i in indices] + return (g_batch, m_batch, x_batch, indices) + + def compute_batch(g_batch, m_batch, x_batch, indices): + """Compute momentum update and Newton-Schulz on GPU.""" + with torch.cuda.stream(compute_stream): + # Wait for H2D transfer to complete (lightweight stream sync) + compute_stream.wait_stream(h2d_stream) + + u_batch = [] + for j in range(len(indices)): + g_gpu, m_gpu = g_batch[j], m_batch[j] + # Update momentum: M = mu * M + G + m_gpu.mul_(momentum) + m_gpu.add_(g_gpu) + # Compute U + if nesterov: + u_gpu = m_gpu * momentum + g_gpu + else: + u_gpu = m_gpu.clone() + u_batch.append(u_gpu.to(dtype=torch.bfloat16)) + + # Batched Newton-Schulz: stack same-shape tensors for single kernel + if len(u_batch) > 1 and all(u.shape == u_batch[0].shape for u in u_batch): + u_stacked = torch.stack(u_batch, dim=0) + u_stacked = muon_update_newton_schulz(u_stacked, newton_schulz_func, flatten, epsilon) + u_batch = list(u_stacked.unbind(0)) + else: + u_batch = [muon_update_newton_schulz(u, newton_schulz_func, flatten, epsilon) for u in u_batch] + + # Apply weight decay and update + for j in range(len(indices)): + x_batch[j].mul_(1 - lr * weight_decay) + x_batch[j].sub_(u_batch[j] * adjusted_lr) + + return m_batch, x_batch + + def writeback_batch(m_batch, x_batch, indices): + """Write results back to CPU (non-blocking).""" + with torch.cuda.stream(d2h_stream): + # Wait for compute to complete + d2h_stream.wait_stream(compute_stream) + for j, i in enumerate(indices): + M[i].copy_(m_batch[j], non_blocking=True) + X[i].copy_(x_batch[j], non_blocking=True) + + # Pipeline: prefetch first batch + if n_tensors > 0: + prefetch_data = prefetch_batch(0) + + # Main loop with double buffering + for batch_start in range(0, n_tensors, BATCH_SIZE): + # Get current batch (already prefetched) + g_batch, m_batch, x_batch, indices = prefetch_data + + # Start prefetching NEXT batch (overlaps with current compute) + next_start = batch_start + BATCH_SIZE + if next_start < n_tensors: + prefetch_data = prefetch_batch(next_start) + + # Compute current batch + m_batch, x_batch = compute_batch(g_batch, m_batch, x_batch, indices) + + # Writeback current batch (overlaps with next iteration's prefetch/compute) + writeback_batch(m_batch, x_batch, indices) + + # Single sync at end to ensure all D2H transfers complete + torch.cuda.synchronize() + + yield # Single yield to make this a generator + else: + # STANDARD GPU MODE: Process all tensors together (original behavior) + U = muon_update_pre_orthogonalize( + G=G, + M=M, + momentum=momentum, + nesterov=nesterov, + ) + + # Orthogonalize each tensor locally + # Newton-Schulz treats dim 0 as batch, processing each slice independently + U = [ + muon_update_newton_schulz( + u, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + for u in U + ] - yield # Single yield to make this a generator + # Update model parameters with orthogonalized output + muon_update_post_orthogonalize( + X=X, + U=U, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + ) + + yield # Single yield to make this a generator def muon_update_batch_async( @@ -673,146 +779,333 @@ def muon_update_batch_async( Batched version of Muon update. Batch size should be equal to number of GPUs. All tensors in a batch should have identical shape, sharding, and dtype. Identical hyperparameters are used for all tensors in the batch. + + Memory-optimized for CPU offloading: when tensors are on CPU, moves ALL computation + to GPU (momentum update, all_to_all, Newton-Schulz, weight update) then copies back. """ assert len(X) == len(G) assert len(X) == len(M) assert len(X) == world_size - # Update momentum and compute the inputs for orthogonalization - U = muon_update_pre_orthogonalize( - G=to_local(G), - M=to_local(M), - momentum=momentum, - nesterov=nesterov, - ) - - # Get one whole matrix for each device to orthogonalize - if shard_dim is not None: - # Use all-to-all to transform from a batch of shards to a single whole matrix - # https://www.essential.ai/blog/infra - assert ( - process_group is not None - ), "process_group must be provided for sharded DTensors" - assert isinstance(X[0], DTensor), "X should contain DTensors" - assert not isinstance(U[0], DTensor), "U should contain local shards" - - # Debug: print full tensor info before the divisibility check - x0 = X[0] - x0_mesh = x0.device_mesh - x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} - - assert ( - X[0].size(shard_dim) % world_size == 0 - ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ - f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ - f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" - - # Allocate buffers to receive shards of one whole matrix from other devices - single_matrix_shards = [torch.empty_like(u) for u in U] - - # Redistribute the shards to form one unique full tensor on each device - # Sync CUDA before collective to ensure all prior GPU ops are complete - # This can prevent NCCL hangs due to async GPU operations + # Check early if we're in CPU offloading mode + G_local = to_local(G) + M_local = to_local(M) + X_local = to_local(X) + original_device = M_local[0].device + needs_gpu_transfer = original_device.type != "cuda" + + if needs_gpu_transfer: + # ====== CPU OFFLOADING PATH: Do ALL computation on GPU ====== + # This avoids slow CPU foreach operations for momentum and weight updates + cuda_device = torch.device("cuda") + dtype = M_local[0].dtype + + # Transfer G, M to GPU for momentum update + G_gpu = [g.to(dtype=dtype).to(cuda_device, non_blocking=True) for g in G_local] + M_gpu = [m.to(cuda_device, non_blocking=True) for m in M_local] torch.cuda.synchronize() - # N sequential all_gathers - only keep result for our assigned param - single_matrix_shards = None - for param_idx in range(world_size): - # Allocate output buffer for this all_gather - gathered = [torch.empty_like(U[param_idx]) for _ in range(world_size)] + # Momentum update on GPU (equivalent to muon_update_pre_orthogonalize) + torch._foreach_mul_(M_gpu, momentum) + torch._foreach_add_(M_gpu, G_gpu) - # All ranks send their shard of param_idx - dist.all_gather(gathered, U[param_idx].contiguous(), group=process_group) + if nesterov: + U_gpu = torch._foreach_mul(M_gpu, momentum) + torch._foreach_add_(U_gpu, G_gpu) + else: + # U shares memory with M when not using nesterov + U_gpu = M_gpu + + # Free G_gpu - no longer needed + del G_gpu + + # Convert to bfloat16 for communication + U_gpu = [u.to(dtype=torch.bfloat16) for u in U_gpu] + + # Get one whole matrix for each device to orthogonalize + if shard_dim is not None: + # Use all-to-all to transform from a batch of shards to a single whole matrix + assert process_group is not None, "process_group must be provided for sharded DTensors" + assert isinstance(X[0], DTensor), "X should contain DTensors" + + # Validation + x0 = X[0] + x0_mesh = x0.device_mesh + x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} + assert ( + X[0].size(shard_dim) % world_size == 0 + ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ + f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ + f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" + + # Make contiguous for all_to_all + U_gpu = [u.contiguous() for u in U_gpu] + + # First all_to_all: batch of shards -> single whole matrix + single_matrix_shards = [torch.empty_like(U_gpu[0]) for _ in range(world_size)] + dist.all_to_all(single_matrix_shards, U_gpu, group=process_group) + del U_gpu - # Only keep if this is our assigned parameter - if param_idx == device_rank: - single_matrix_shards = gathered - # Otherwise 'gathered' goes out of scope and memory can be freed + yield - yield + # Concatenate shards to form whole matrix + single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) + del single_matrix_shards - # Concatentate shards to form a whole matrix to orthogonalize - single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) - single_matrix = muon_update_newton_schulz( - single_matrix, - newton_schulz_func=newton_schulz_func, - flatten=flatten, - epsilon=epsilon, - ) + # Newton-Schulz orthogonalization (on GPU) + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) - # Split result back into shards - # Contiguous is needed for communication to work correctly - orth_shards = [ - x.contiguous() - for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) - ] + # Split result back into shards + orth_shards = [ + x.contiguous() + for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) + ] + del single_matrix - # N sequential all_gathers - collect results as we go - for shard_idx in range(world_size): - # Allocate output buffer for this all_gather - gathered = [torch.empty_like(orth_shards[shard_idx]) for _ in range(world_size)] + # Second all_to_all to redistribute orthogonalized shards + U_orth_gpu = [torch.empty_like(orth_shards[0]) for _ in range(world_size)] + dist.all_to_all(U_orth_gpu, orth_shards, group=process_group) + del orth_shards - # All ranks send their shard at index shard_idx - dist.all_gather(gathered, orth_shards[shard_idx].contiguous(), group=process_group) + yield - # gathered[r] = rank r's orth_shards[shard_idx] = O^r_{shard_idx} - # We need U[r] = O^r_{device_rank} - # So when shard_idx == device_rank: U[r] = gathered[r] for all r - if shard_idx == device_rank: - for r in range(world_size): - U[r].copy_(gathered[r]) + else: + # Matrices are not sharded, orthogonalize directly + single_matrix = U_gpu[device_rank] + + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + + if process_group is not None and process_group.size() > 1: + U_orth_gpu = [torch.empty_like(single_matrix) for _ in range(world_size)] + work = dist.all_gather( + U_orth_gpu, single_matrix.contiguous(), group=process_group, async_op=True + ) + yield + work.wait() + del single_matrix + else: + assert world_size == 1 + U_orth_gpu = [single_matrix] + + # Compute scaled learning rate (use full tensor shape from X[0]) + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) + else: + raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + + # Transfer X to GPU for weight update + X_gpu = [x.to(cuda_device, non_blocking=True) for x in X_local] + torch.cuda.synchronize() + + # Weight update on GPU (equivalent to muon_update_post_orthogonalize) + torch._foreach_mul_(X_gpu, 1 - lr * weight_decay) + U_scaled = torch._foreach_mul(U_orth_gpu, adjusted_lr) + torch._foreach_sub_(X_gpu, U_scaled) + del U_scaled, U_orth_gpu + + # Copy M and X back to CPU + for i in range(world_size): + M_local[i].copy_(M_gpu[i], non_blocking=True) + X_local[i].copy_(X_gpu[i], non_blocking=True) - yield + torch.cuda.synchronize() + del M_gpu, X_gpu else: - # Matrices are not sharded, so we can directly orthogonalize - # Get a single matrix corresponding to this device - single_matrix = U[device_rank] - assert not isinstance(single_matrix, DTensor) - - single_matrix = muon_update_newton_schulz( - single_matrix, - newton_schulz_func=newton_schulz_func, - flatten=flatten, - epsilon=epsilon, + # ====== STANDARD GPU PATH ====== + # Update momentum and compute the inputs for orthogonalization + U = muon_update_pre_orthogonalize( + G=G_local, + M=M_local, + momentum=momentum, + nesterov=nesterov, ) - if process_group is not None and process_group.size() > 1: - # Allocate empty tensors to receive updates from other devices - U = [torch.empty_like(u) for u in U] + # Get one whole matrix for each device to orthogonalize + # JQ: This is the N sequential gather version + # if shard_dim is not None: + # # Use all-to-all to transform from a batch of shards to a single whole matrix + # # https://www.essential.ai/blog/infra + # assert ( + # process_group is not None + # ), "process_group must be provided for sharded DTensors" + # assert isinstance(X[0], DTensor), "X should contain DTensors" + # assert not isinstance(U[0], DTensor), "U should contain local shards" + + # # Debug: print full tensor info before the divisibility check + # x0 = X[0] + # x0_mesh = x0.device_mesh + # x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} + + # assert ( + # X[0].size(shard_dim) % world_size == 0 + # ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ + # f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ + # f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" + + # # Allocate buffers to receive shards of one whole matrix from other devices + # single_matrix_shards = [torch.empty_like(u) for u in U] + + # # Redistribute the shards to form one unique full tensor on each device + # # Sync CUDA before collective to ensure all prior GPU ops are complete + # # This can prevent NCCL hangs due to async GPU operations + # torch.cuda.synchronize() + + # # N sequential all_gathers - only keep result for our assigned param + # single_matrix_shards = None + # for param_idx in range(world_size): + # # Allocate output buffer for this all_gather + # gathered = [torch.empty_like(U[param_idx]) for _ in range(world_size)] + + # # All ranks send their shard of param_idx + # dist.all_gather(gathered, U[param_idx].contiguous(), group=process_group) + + # # Only keep if this is our assigned parameter + # if param_idx == device_rank: + # single_matrix_shards = gathered + # # Otherwise 'gathered' goes out of scope and memory can be freed + + # yield + + # # Concatentate shards to form a whole matrix to orthogonalize + # single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) + # single_matrix = muon_update_newton_schulz( + # single_matrix, + # newton_schulz_func=newton_schulz_func, + # flatten=flatten, + # epsilon=epsilon, + # ) + + # # Split result back into shards + # # Contiguous is needed for communication to work correctly + # orth_shards = [ + # x.contiguous() + # for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) + # ] + + # # N sequential all_gathers - collect results as we go + # for shard_idx in range(world_size): + # # Allocate output buffer for this all_gather + # gathered = [torch.empty_like(orth_shards[shard_idx]) for _ in range(world_size)] + + # # All ranks send their shard at index shard_idx + # dist.all_gather(gathered, orth_shards[shard_idx].contiguous(), group=process_group) + + # # gathered[r] = rank r's orth_shards[shard_idx] = O^r_{shard_idx} + # # We need U[r] = O^r_{device_rank} + # # So when shard_idx == device_rank: U[r] = gathered[r] for all r + # if shard_idx == device_rank: + # for r in range(world_size): + # U[r].copy_(gathered[r]) + + # yield + + # Get one whole matrix for each device to orthogonalize + if shard_dim is not None: + assert process_group is not None, "process_group must be provided for sharded DTensors" + assert isinstance(X[0], DTensor), "X should contain DTensors" + assert not isinstance(U[0], DTensor), "U should contain local shards" + + x0 = X[0] + x0_mesh = x0.device_mesh + x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} + assert ( + X[0].size(shard_dim) % world_size == 0 + ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ + f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ + f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" + + # Sync CUDA before collective to prevent NCCL hangs from async GPU ops + torch.cuda.synchronize() + + single_matrix_shards = [torch.empty_like(U[0]) for _ in range(world_size)] + dist.all_to_all(single_matrix_shards, [u.contiguous() for u in U], group=process_group) - # All gather orthogonalized results from other devices into buffer - work = dist.all_gather( - U, single_matrix.contiguous(), group=process_group, async_op=True + yield + + single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) + del single_matrix_shards + + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, ) + + orth_shards = [ + x.contiguous() + for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) + ] + del single_matrix + + output_shards = [torch.empty_like(orth_shards[0]) for _ in range(world_size)] + dist.all_to_all(output_shards, orth_shards, group=process_group) + del orth_shards + + for i in range(world_size): + U[i].copy_(output_shards[i]) + del output_shards + yield - work.wait() else: - # Single GPU case, no need to gather - assert world_size == 1 - U = [single_matrix] - - # Compute scaled learning rate - # Do this before to_local(X) because we use the full tensor shape, not the shard shape - if adjust_lr is None: - adjusted_lr = lr - elif adjust_lr == "spectral_norm": - adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) - elif adjust_lr == "rms_norm": - adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) - else: - raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + single_matrix = U[device_rank] + assert not isinstance(single_matrix, DTensor) + + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) - # Update model parameters with orthogonalized output - muon_update_post_orthogonalize( - X=to_local(X), - U=U, - base_lr=lr, - adjusted_lr=adjusted_lr, - weight_decay=weight_decay, - ) + if process_group is not None and process_group.size() > 1: + U_gathered = [torch.empty_like(single_matrix) for _ in range(world_size)] + work = dist.all_gather( + U_gathered, single_matrix.contiguous(), group=process_group, async_op=True + ) + yield + work.wait() + del single_matrix + U = U_gathered + else: + assert world_size == 1 + U = [single_matrix] + + # Compute scaled learning rate + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) + else: + raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + + # Update model parameters with orthogonalized output + muon_update_post_orthogonalize( + X=X_local, + U=U, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + ) def adamw_update_foreach_async( From e7ccfdc38aa50177c53cf8b001632136d861085e Mon Sep 17 00:00:00 2001 From: emozilla Date: Thu, 29 Jan 2026 12:12:47 -0800 Subject: [PATCH 118/127] merge fixups --- torchtitan/models/deepseek_v3/__init__.py | 1 - torchtitan/models/moe/__init__.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index b85d016974..c9c7f5f755 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -173,7 +173,6 @@ route_scale=2.827, score_before_experts=False, ), - n_expert_groups=1, n_limited_groups=1, q_lora_rank=1536, kv_lora_rank=512, diff --git a/torchtitan/models/moe/__init__.py b/torchtitan/models/moe/__init__.py index dff7afd537..1fccdaa572 100644 --- a/torchtitan/models/moe/__init__.py +++ b/torchtitan/models/moe/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .moe import build_moe, ExpertRoutingHistogram, FeedForward, MoE, MoEArgs +from .moe import build_moe, fast_init_trunc_normal_, fast_init_normal_, ExpertRoutingHistogram, FeedForward, MoE, MoEArgs __all__ = ["FeedForward", "MoE", "MoEArgs", "build_moe", "ExpertRoutingHistogram", "fast_init_trunc_normal_", "fast_init_normal_"] From 98f53ee7edd0cdf1801f23a67546e66b07612650 Mon Sep 17 00:00:00 2001 From: emozilla Date: Thu, 29 Jan 2026 19:53:36 -0800 Subject: [PATCH 119/127] merge fixups --- torchtitan/models/deepseek_v3/__init__.py | 7 +++---- torchtitan/models/deepseek_v3/model/args.py | 10 ---------- torchtitan/models/moe/moe.py | 4 ++-- 3 files changed, 5 insertions(+), 16 deletions(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index c9c7f5f755..9142acc839 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -7,8 +7,8 @@ from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing -from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.experiments.kimi_linear.model.tokenizer import build_kimi_tokenizer from torchtitan.hf_datasets.dataloader import build_dataloader from torchtitan.models.moe import MoEArgs from torchtitan.protocols.train_spec import TrainSpec @@ -173,13 +173,12 @@ route_scale=2.827, score_before_experts=False, ), - n_limited_groups=1, q_lora_rank=1536, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", rope_theta=50000.0, rope_factor=32.0, @@ -197,7 +196,7 @@ def get_train_spec() -> TrainSpec: build_optimizers_fn=build_optimizers_with_moe_load_balancing, build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_dataloader, - build_tokenizer_fn=build_hf_tokenizer, + build_tokenizer_fn=build_kimi_tokenizer, # falls back to hf tokenizer if tiktoken.model not found build_loss_fn=build_cross_entropy_loss, state_dict_adapter=DeepSeekV3StateDictAdapter, ) diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index aba564927c..d7ed015300 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -99,16 +99,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args.use_grouped_mm = False - if ( - job_config.parallelism.context_parallel_degree > 1 - and self.attn_type != "sdpa" - ): - raise NotImplementedError( - f"Context Parallel only supports SDPA attention. " - f"Got attn_type='{self.attn_type}'. " - f"FlexAttention and varlen attention are not supported with CP." - ) - self.moe_args._debug_force_load_balance = ( job_config.debug.moe_force_load_balance ) diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 2fc040cbe2..9caa4a1d9f 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -1055,7 +1055,7 @@ def init_weights(self, init_std: float, buffer_device: torch.device, n_layers: i def build_moe( - args: MoEArgs, dim: int, hidden_dim: int, moe_impl: str = "standard" + args: MoEArgs, dim: int, hidden_dim: int, peft_config: PEFT, moe_impl: str = "standard", ) -> nn.Module: """Factory for MoE with different backends: 'standard' (all-to-all) or 'deepep' (DeepEP).""" if moe_impl == "deepep": @@ -1069,4 +1069,4 @@ def build_moe( logger.info( f"Standard MoE: num_experts={args.num_experts}, top_k={args.top_k}, dim={dim}, hidden_dim={hidden_dim}" ) - return MoE(args, dim=dim, hidden_dim=hidden_dim) + return MoE(args, dim=dim, hidden_dim=hidden_dim, peft_config=peft_config) From 668f23e3e17b21856b00a6ee7c4e4ebe8662e9a4 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sat, 31 Jan 2026 08:48:44 -0800 Subject: [PATCH 120/127] Add memory tracking and BF16 optimizer state features with Kimi K2 configs Memory Tracking Tools: - Add DetailedMemoryTracker for per-phase memory tracking (before_forward, after_forward_backward, after_optimizer, step_end) - Add CUDAMemoryTracker for PyTorch vs nvidia-smi memory comparison - Add AggressiveMemoryManager for CUDA fragmentation reduction with modes: minimal, balanced, aggressive, maximum BF16 Optimizer States: - Add BF16StateOptimizersContainer wrapper for pre-initializing optimizer states in bfloat16 before first step (50% memory savings) - Add preinit_optimizer_states_bf16() to allocate exp_avg/exp_avg_sq in param dtype from the start, avoiding fp32 allocation spike - Fix device mismatch bug: state["step"] tensor now created on param device New Config Options: - optimizer.state_dtype: "float32" | "bfloat16" - training.enable_detailed_memory_tracking: bool - training.clear_cache_between_steps: bool - training.skip_optimizer_step: bool - training.aggressive_memory_mode: "minimal" | "balanced" | "aggressive" | "maximum" - training.aggressive_memory_verbose: bool Train Loop Integration: - Initialize memory trackers in Trainer.__init__ - Call tracking at forward_backward_step and train_step phases - Call aggressive memory manager at post_backward, post_optimizer, step_end - Pre-initialize BF16 optimizer states before training loop Configs Added: - qwen3_30b_a3b_memory_test.toml: Test config for memory features - kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml: 12-node production config - kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml: 36-node HSDP config --- qwen3_30b_a3b_memory_test.toml | 73 +++ torchtitan/components/optimizer.py | 120 ++++- torchtitan/config/job_config.py | 42 ++ .../kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml | 72 +++ ..._32k_ctx_hsdp_replicate3_shard6_lbs10.toml | 73 +++ torchtitan/tools/aggressive_memory_manager.py | 414 ++++++++++++++++++ torchtitan/tools/cuda_memory_tracker.py | 123 ++++++ torchtitan/tools/detailed_memory_tracker.py | 160 +++++++ torchtitan/train.py | 121 ++++- 9 files changed, 1195 insertions(+), 3 deletions(-) create mode 100644 qwen3_30b_a3b_memory_test.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml create mode 100644 torchtitan/tools/aggressive_memory_manager.py create mode 100644 torchtitan/tools/cuda_memory_tracker.py create mode 100644 torchtitan/tools/detailed_memory_tracker.py diff --git a/qwen3_30b_a3b_memory_test.toml b/qwen3_30b_a3b_memory_test.toml new file mode 100644 index 0000000000..7114927040 --- /dev/null +++ b/qwen3_30b_a3b_memory_test.toml @@ -0,0 +1,73 @@ +# Qwen3 30B-A3B with memory tracking features enabled +# Tests: detailed memory tracking, aggressive memory manager, bf16 optimizer states +# Reduced settings to fit in memory + +[job] +dump_folder = "./outputs/qwen3_30b_a3b_memory_test" +description = "Qwen3 30B-A3B - memory tracking test" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "qwen3" +flavor = "30B-A3B" +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 +# Test BF16 optimizer states (50% memory savings) +state_dtype = "bfloat16" + +[lr_scheduler] +warmup_steps = 2 + +[training] +local_batch_size = 1 +seq_len = 2048 +max_norm = 1.0 +steps = 3 +dataset = "c4" +enable_cpu_offload = true +# Enable detailed memory tracking +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +# Enable aggressive memory management +aggressive_memory_mode = "maximum" +aggressive_memory_verbose = true + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" +tensor_parallel_degree = 1 +context_parallel_degree = 1 +enable_async_tensor_parallel = false +expert_parallel_degree = 8 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 500 +last_save_model_only = false +export_dtype = "float16" +async_mode = "disabled" + +[activation_checkpoint] +mode = "full" +selective_ac_option = "op" + +[compile] +enable = true +components = ["loss"] diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 2bdd55f0d5..c271923fdc 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -23,6 +23,7 @@ from torchtitan.components.ft import FTManager, has_torchft from torchtitan.config import Optimizer as OptimizerConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.tools.logging import logger # Dion optimizer availability will be checked lazily when needed DION_AVAILABLE = None @@ -78,6 +79,115 @@ def _check_muon_availability(): T = TypeVar("T", bound=Optimizer) +def preinit_optimizer_states_bf16(optimizers_container: "OptimizersContainer") -> None: + """ + Pre-initialize optimizer states (exp_avg, exp_avg_sq) directly in bfloat16. + This MUST be called BEFORE the first optimizer.step() to avoid fp32 allocation spike. + + This reduces optimizer state memory by ~50% (from fp32 to bf16). + States are allocated in bf16 from the start, avoiding the memory spike from fp32 allocation. + """ + total_params = 0 + total_bytes = 0 + dtype_device_samples = [] + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + + for opt_idx, optimizer in enumerate(optimizers_container.optimizers): + for pg_idx, param_group in enumerate(optimizer.param_groups): + for p_idx, p in enumerate(param_group["params"]): + if p.requires_grad: + if total_params < 5: + dtype_device_samples.append( + f"param[{opt_idx}][{pg_idx}][{p_idx}]: dtype={p.dtype}, device={p.device}, shape={list(p.shape)}" + ) + + state = optimizer.state[p] + if len(state) == 0: + state["step"] = torch.tensor(0, dtype=torch.float32, device=p.device) + state["exp_avg"] = torch.zeros_like( + p, dtype=p.dtype, device=p.device + ) + state["exp_avg_sq"] = torch.zeros_like( + p, dtype=p.dtype, device=p.device + ) + total_params += 1 + bytes_per_element = 2 if p.dtype == torch.bfloat16 else 4 + total_bytes += p.numel() * 2 * bytes_per_element + + if total_params <= 3: + logger.info( + f"[Rank {rank}] State init sample: param dtype={p.dtype}, device={p.device}, " + f"exp_avg dtype={state['exp_avg'].dtype}, device={state['exp_avg'].device}" + ) + + for sample in dtype_device_samples: + logger.info(f"[Rank {rank}] {sample}") + + logger.info( + f"[Rank {rank}] Pre-initialized {total_params} optimizer states matching param dtype, " + f"this rank: {total_bytes / 1e9:.2f} GB" + ) + + +class BF16StateOptimizersContainer(Generic[T]): + """ + Wrapper that pre-initializes optimizer states in bfloat16 BEFORE first step. + This prevents the memory spike from fp32 state allocation. + + IMPORTANT: Call init_bf16_states() BEFORE the first step() to avoid + rank skew during state allocation. This should be called after model + setup but before training starts, ideally with a barrier afterwards. + """ + + def __init__( + self, + base_container: "OptimizersContainer", + state_dtype: torch.dtype = torch.bfloat16, + ): + self._base = base_container + self._state_dtype = state_dtype + self._states_initialized = False + + def init_bf16_states(self): + """ + Pre-initialize optimizer states in bf16. + Call this BEFORE training starts, then call a distributed barrier. + This avoids rank skew during the first optimizer.step(). + """ + if not self._states_initialized: + logger.info("Pre-initializing optimizer states in bfloat16...") + preinit_optimizer_states_bf16(self._base) + self._states_initialized = True + logger.info("BF16 optimizer state pre-initialization complete.") + + def step(self, *args, **kwargs) -> None: + if not self._states_initialized: + logger.warning( + "BF16 optimizer states not pre-initialized! " + "Call init_bf16_states() before training to avoid rank skew." + ) + self.init_bf16_states() + self._base.step(*args, **kwargs) + + def zero_grad(self, *args, **kwargs) -> None: + self._base.zero_grad(*args, **kwargs) + + def state_dict(self) -> dict[str, Any]: + return self._base.state_dict() + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + self._base.load_state_dict(state_dict) + + def __iter__(self): + return iter(self._base) + + def __len__(self): + return len(self._base) + + def __getattr__(self, name): + return getattr(self._base, name) + + class OptimizersContainer(Optimizer, Stateful, Generic[T]): """A container for multiple optimizers. @@ -509,7 +619,15 @@ def build_optimizers( use_ft_optimizer=ft_manager.use_async_quorum, ) - return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) + container = OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) + + # Wrap with BF16 state container if configured + state_dtype = getattr(optimizer_config, "state_dtype", "float32") + if state_dtype == "bfloat16": + logger.info("Using bfloat16 optimizer states (will pre-init before first step)") + return BF16StateOptimizersContainer(container, torch.bfloat16) + + return container def build_optimizers_with_moe_load_balancing( diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 99a0130f43..7a6a9ffb33 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -286,6 +286,13 @@ class Optimizer: use_triton: bool = False """Whether to use Triton kernel for Newton-Schulz in Muon optimizer.""" + state_dtype: Literal["float32", "bfloat16"] = "float32" + """ + Dtype for optimizer states (exp_avg, exp_avg_sq for Adam/AdamW). + Using bfloat16 reduces memory by ~50% but may affect training stability. + Only applies to Adam/AdamW optimizers. + """ + @dataclass class LRScheduler: @@ -438,6 +445,41 @@ class Training: dataloader: DataLoader = field(default_factory=DataLoader) """DataLoader configuration""" + enable_detailed_memory_tracking: bool = False + """ + Whether to enable detailed memory tracking at every training phase + """ + + clear_cache_between_steps: bool = False + """ + Whether to clear CUDA cache between training steps to measure minimum memory requirements + """ + + skip_optimizer_step: bool = False + """ + Whether to skip the optimizer step (for memory profiling purposes only) + """ + + aggressive_memory_mode: Literal[ + "minimal", "balanced", "aggressive", "maximum" + ] | None = None + """ + Enable aggressive memory management to reduce CUDA memory fragmentation. + This clears CUDA cache and Python GC at strategic points (post-backward, post-optimizer). + Modes: + - None: Disabled (default) + - "minimal": Only clear on high fragmentation (<1% overhead) + - "balanced": Clear after backward and optimizer (2-3% overhead) + - "aggressive": Clear frequently with sync (5-8% overhead) + - "maximum": Clear after every operation (10-15% overhead, for debugging) + """ + + aggressive_memory_verbose: bool = False + """ + Enable verbose logging for aggressive memory manager. + Logs detailed memory stats after each clear operation. + """ + @dataclass class Parallelism: diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml new file mode 100644 index 0000000000..e3e1c9d497 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml @@ -0,0 +1,72 @@ +# Kimi K2 - 12 nodes - EP=96, CP=16, LBS=11 +# +# Original config: /home/phuc/worklogs/2026-01-30/cp16_sweep/configs/exp1acd_12n_ep96_cp16_lbs11.toml +# Job ID: 2307 +# +# Expected Performance: +# - TPS: 402 +# - Memory: 67.55 GiB (85.2%) +# - MFU: 17.72% +# - TFLOPS: ~175 +# +# Parallelism: EP=96, CP=16, DP=1 (dp_replicate=1, dp_shard=1) +# Nodes: 12 (96 GPUs) +# Seq Length: 32768 +# Local Batch Size: 11 +# + +[job] +dump_folder = "./outputs/kimi_k2/12n_ep96_cp16_lbs11" +description = "Kimi K2 - 12n EP=96 CP=16 LBS=11" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "/home/phuc/kimi_1t/torchtitan/assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +state_dtype = "bfloat16" + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +dtype = "bfloat16" +local_batch_size = 11 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = false +# Aggressive memory management to reduce CUDA fragmentation +aggressive_memory_mode = "maximum" +aggressive_memory_verbose = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 16 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = true +components = ["loss"] + +[debug] +moe_force_load_balance = true + +[comm] +init_timeout_seconds = 1800 +train_timeout_seconds = 1800 diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml new file mode 100644 index 0000000000..7ee95edec6 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml @@ -0,0 +1,73 @@ +# Kimi K2 - 36 nodes - EP=96, CP=16, HSDP (dp_replicate=3, dp_shard=6), LBS=10 +# +# Original config: /home/phuc/worklogs/2026-01-30/cp16_sweep_dp/configs/exp1aj_HSDP_r3_s6_lbs10.toml +# Job ID: 2485 +# +# Expected Performance: +# - TPS: 378 +# - Memory: 69.45 GiB (87.6%) +# - MFU: 16.64% +# +# Parallelism: EP=96, CP=16, dp_replicate=3, dp_shard=6 +# HSDP: Shard within 12 nodes, all-reduce between 3 replica groups +# Nodes: 36 (288 GPUs) +# Seq Length: 32768 +# Local Batch Size: 10 +# + +[job] +dump_folder = "./outputs/kimi_k2/36n_ep96_cp16_hsdp_replicate3_shard6_lbs10" +description = "Kimi K2 - 36n HSDP dp_replicate=3 dp_shard=6 LBS=10" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "/home/phuc/kimi_1t/torchtitan/assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +state_dtype = "bfloat16" + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +dtype = "bfloat16" +local_batch_size = 10 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = false +# Aggressive memory management to reduce CUDA fragmentation +aggressive_memory_mode = "maximum" +aggressive_memory_verbose = true + +[parallelism] +data_parallel_replicate_degree = 3 +data_parallel_shard_degree = 6 +expert_parallel_degree = 96 +context_parallel_degree = 16 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = true +components = ["loss"] + +[debug] +moe_force_load_balance = true + +[comm] +init_timeout_seconds = 1800 +train_timeout_seconds = 1800 diff --git a/torchtitan/tools/aggressive_memory_manager.py b/torchtitan/tools/aggressive_memory_manager.py new file mode 100644 index 0000000000..1c4861cb74 --- /dev/null +++ b/torchtitan/tools/aggressive_memory_manager.py @@ -0,0 +1,414 @@ +""" +Aggressive Memory Manager for reducing CUDA memory fragmentation. + +This module provides aggressive memory clearing strategies to minimize +fragmentation and allocation retries during distributed training. + +Usage: + from torchtitan.tools.aggressive_memory_manager import AggressiveMemoryManager + + # Initialize at start of training + mem_manager = AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + sync_before_clear=True, + defrag_threshold_mb=1000, # Defrag if fragmentation > 1GB + ) + + # In training loop: + loss.backward() + mem_manager.post_backward() + + optimizer.step() + mem_manager.post_optimizer() + + mem_manager.step_complete() +""" + +import gc +import os +import time +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.distributed as dist + +from torchtitan.tools.logging import logger + + +@dataclass +class MemoryStats: + """Current memory statistics""" + + allocated: int + reserved: int + active: int + fragmentation: int + fragmentation_pct: float + num_alloc_retries: int + + +class AggressiveMemoryManager: + """ + Aggressive memory management to minimize CUDA memory fragmentation. + + Key strategies: + 1. Clear cache at strategic points (post-backward, post-optimizer) + 2. Synchronize before clearing to ensure all async ops complete + 3. Force garbage collection to release Python references + 4. Monitor fragmentation and trigger defrag when threshold exceeded + 5. Set optimal allocator configuration + + Args: + clear_after_backward: Clear cache after backward pass + clear_after_optimizer: Clear cache after optimizer step + clear_every_n_steps: Only clear every N steps (1 = every step) + sync_before_clear: Synchronize CUDA before clearing cache + defrag_threshold_mb: Trigger defrag if fragmentation exceeds this (MB) + gc_generation: Python GC generation to collect (0-2, higher = more thorough) + verbose: Log detailed memory stats + rank: Distributed rank (auto-detected if None) + """ + + def __init__( + self, + clear_after_backward: bool = True, + clear_after_optimizer: bool = True, + clear_every_n_steps: int = 1, + sync_before_clear: bool = True, + defrag_threshold_mb: float = 500.0, + gc_generation: int = 1, + verbose: bool = False, + rank: Optional[int] = None, + ): + self.clear_after_backward = clear_after_backward + self.clear_after_optimizer = clear_after_optimizer + self.clear_every_n_steps = clear_every_n_steps + self.sync_before_clear = sync_before_clear + self.defrag_threshold_mb = defrag_threshold_mb + self.gc_generation = gc_generation + self.verbose = verbose + + self.rank = ( + rank + if rank is not None + else (dist.get_rank() if dist.is_initialized() else 0) + ) + + self.step_count = 0 + self.total_clears = 0 + self.total_defrag_time_ms = 0.0 + + # Disable automatic GC - we'll control it manually + gc.disable() + + # Initial cleanup + self._aggressive_clear("initialization") + + if self.rank == 0: + logger.info( + f"[AggressiveMemoryManager] Initialized: " + f"clear_backward={clear_after_backward}, " + f"clear_optimizer={clear_after_optimizer}, " + f"every_n_steps={clear_every_n_steps}, " + f"sync={sync_before_clear}, " + f"defrag_threshold={defrag_threshold_mb}MB" + ) + + @staticmethod + def configure_allocator( + expandable_segments: bool = True, + max_split_size_mb: int = 128, + garbage_collection_threshold: float = 0.8, + roundup_power2_divisions: int = 4, + ) -> str: + """ + Configure PyTorch CUDA allocator for minimal fragmentation. + + Call this BEFORE any CUDA operations (before model creation). + + Args: + expandable_segments: Enable expandable memory segments + max_split_size_mb: Max size of memory splits (smaller = less fragmentation) + garbage_collection_threshold: Trigger GC when this fraction of memory is fragmented + roundup_power2_divisions: Memory rounding granularity + + Returns: + The PYTORCH_CUDA_ALLOC_CONF string that was set + """ + config_parts = [] + + if expandable_segments: + config_parts.append("expandable_segments:True") + + config_parts.append(f"max_split_size_mb:{max_split_size_mb}") + config_parts.append( + f"garbage_collection_threshold:{garbage_collection_threshold}" + ) + config_parts.append(f"roundup_power2_divisions:{roundup_power2_divisions}") + + config_str = ",".join(config_parts) + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_str + + return config_str + + def get_memory_stats(self) -> MemoryStats: + """Get current memory statistics""" + if not torch.cuda.is_available(): + return MemoryStats(0, 0, 0, 0, 0.0, 0) + + stats = torch.cuda.memory_stats() + allocated = torch.cuda.memory_allocated() + reserved = torch.cuda.memory_reserved() + active = stats.get("active_bytes.all.current", 0) + fragmentation = reserved - allocated + fragmentation_pct = (fragmentation / reserved * 100) if reserved > 0 else 0.0 + num_retries = stats.get("num_alloc_retries", 0) + + return MemoryStats( + allocated=allocated, + reserved=reserved, + active=active, + fragmentation=fragmentation, + fragmentation_pct=fragmentation_pct, + num_alloc_retries=num_retries, + ) + + def _should_clear(self) -> bool: + """Check if we should clear cache this step""" + return self.step_count % self.clear_every_n_steps == 0 + + def _aggressive_clear(self, reason: str) -> float: + """ + Perform aggressive memory clearing. + + Returns: + Time taken in milliseconds + """ + if not torch.cuda.is_available(): + return 0.0 + + start = time.perf_counter() + + # 1. Synchronize all CUDA streams to ensure ops complete + if self.sync_before_clear: + torch.cuda.synchronize() + + # 2. Python garbage collection (releases tensor references) + gc.collect(self.gc_generation) + + # 3. Clear CUDA cache (releases unused cached memory) + torch.cuda.empty_cache() + + # 4. Optional: Force synchronization after clear + if self.sync_before_clear: + torch.cuda.synchronize() + + elapsed_ms = (time.perf_counter() - start) * 1000 + self.total_clears += 1 + self.total_defrag_time_ms += elapsed_ms + + if self.verbose and self.rank == 0: + stats = self.get_memory_stats() + logger.info( + f"[AggressiveMemoryManager] {reason}: " + f"cleared in {elapsed_ms:.1f}ms, " + f"frag={stats.fragmentation_pct:.1f}%, " + f"reserved={stats.reserved/1e9:.2f}GB" + ) + + return elapsed_ms + + def _check_and_defrag(self, phase: str) -> bool: + """ + Check fragmentation and defrag if needed. + + Returns: + True if defrag was triggered + """ + stats = self.get_memory_stats() + fragmentation_mb = stats.fragmentation / (1024 * 1024) + + if fragmentation_mb > self.defrag_threshold_mb: + self._aggressive_clear(f"defrag_{phase}_frag={fragmentation_mb:.0f}MB") + return True + + return False + + def post_backward(self): + """Call after backward pass completes""" + if self.clear_after_backward and self._should_clear(): + self._check_and_defrag("post_backward") + self._aggressive_clear("post_backward") + + def post_optimizer(self): + """Call after optimizer step completes""" + if self.clear_after_optimizer and self._should_clear(): + self._check_and_defrag("post_optimizer") + self._aggressive_clear("post_optimizer") + + def step_complete(self): + """Call at the end of each training step""" + self.step_count += 1 + + # Always check for high fragmentation + self._check_and_defrag("step_end") + + def get_summary(self) -> str: + """Get summary of memory management activity""" + avg_time = self.total_defrag_time_ms / max(1, self.total_clears) + return ( + f"AggressiveMemoryManager Summary:\n" + f" Total clears: {self.total_clears}\n" + f" Total defrag time: {self.total_defrag_time_ms:.1f}ms\n" + f" Avg time per clear: {avg_time:.2f}ms\n" + f" Steps processed: {self.step_count}" + ) + + +class BackwardMemoryHook: + """ + Register hooks on model parameters to clear memory during backward pass. + + This clears memory incrementally as gradients are computed, rather than + waiting until the end of backward. + + Args: + clear_every_n_params: Clear cache after every N parameter gradients + sync_on_clear: Synchronize before clearing (slower but more thorough) + """ + + def __init__( + self, + clear_every_n_params: int = 10, + sync_on_clear: bool = False, + ): + self.clear_every_n_params = clear_every_n_params + self.sync_on_clear = sync_on_clear + self.param_count = 0 + self.handles = [] + + def _backward_hook(self, grad): + """Hook called when gradient is computed for a parameter""" + self.param_count += 1 + + if self.param_count % self.clear_every_n_params == 0: + if self.sync_on_clear: + torch.cuda.synchronize() + gc.collect(0) # Fast GC (generation 0 only) + torch.cuda.empty_cache() + + return grad + + def register(self, model: torch.nn.Module): + """Register hooks on all model parameters""" + for name, param in model.named_parameters(): + if param.requires_grad: + handle = param.register_post_accumulate_grad_hook( + lambda p, name=name: self._backward_hook(p.grad) + ) + self.handles.append(handle) + + logger.info( + f"[BackwardMemoryHook] Registered on {len(self.handles)} parameters, " + f"clearing every {self.clear_every_n_params} params" + ) + + def remove(self): + """Remove all registered hooks""" + for handle in self.handles: + handle.remove() + self.handles.clear() + + def reset_count(self): + """Reset parameter count (call at start of each backward)""" + self.param_count = 0 + + +def setup_aggressive_memory_environment(): + """ + Set up environment variables for aggressive memory management. + + Call this BEFORE importing torch or creating any CUDA tensors. + """ + # Optimal allocator settings for minimal fragmentation + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ( + "expandable_segments:True," + "max_split_size_mb:128," + "garbage_collection_threshold:0.8," + "roundup_power2_divisions:4" + ) + + # Disable NCCL async error handling (can cause memory issues) + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + + # Force synchronous CUDA operations for debugging + # os.environ["CUDA_LAUNCH_BLOCKING"] = "1" # Uncomment for debugging + + return os.environ.get("PYTORCH_CUDA_ALLOC_CONF") + + +# Convenience function for quick setup +def create_aggressive_memory_manager( + mode: str = "balanced", + verbose: bool = False, +) -> AggressiveMemoryManager: + """ + Create an AggressiveMemoryManager with preset configurations. + + Args: + mode: One of: + - "minimal": Only clear on high fragmentation + - "balanced": Clear after backward and optimizer + - "aggressive": Clear frequently with sync + - "maximum": Clear after every operation + verbose: Enable verbose logging + + Returns: + Configured AggressiveMemoryManager + """ + if mode == "minimal": + return AggressiveMemoryManager( + clear_after_backward=False, + clear_after_optimizer=False, + clear_every_n_steps=10, + sync_before_clear=False, + defrag_threshold_mb=2000, + gc_generation=0, + verbose=verbose, + ) + elif mode == "balanced": + return AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + clear_every_n_steps=1, + sync_before_clear=False, + defrag_threshold_mb=500, + gc_generation=1, + verbose=verbose, + ) + elif mode == "aggressive": + return AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + clear_every_n_steps=1, + sync_before_clear=True, + defrag_threshold_mb=200, + gc_generation=2, + verbose=verbose, + ) + elif mode == "maximum": + return AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + clear_every_n_steps=1, + sync_before_clear=True, + defrag_threshold_mb=100, + gc_generation=2, + verbose=verbose, + ) + else: + raise ValueError( + f"Unknown mode: {mode}. Use minimal/balanced/aggressive/maximum" + ) diff --git a/torchtitan/tools/cuda_memory_tracker.py b/torchtitan/tools/cuda_memory_tracker.py new file mode 100644 index 0000000000..0f7d7af5f4 --- /dev/null +++ b/torchtitan/tools/cuda_memory_tracker.py @@ -0,0 +1,123 @@ +"""Track CUDA memory directly from nvidia-smi and PyTorch""" +import logging +import subprocess +from typing import Dict, Optional + +import torch + +logger = logging.getLogger(__name__) + + +class CUDAMemoryTracker: + """Track memory from both PyTorch and CUDA/nvidia-smi""" + + def __init__(self, enabled: bool = True): + self.enabled = enabled + self.device = torch.cuda.current_device() + self.device_name = torch.cuda.get_device_name(self.device) + + if self.enabled: + logger.info( + f"CUDAMemoryTracker enabled for device {self.device}: {self.device_name}" + ) + + def get_nvidia_smi_memory(self) -> Optional[Dict[str, int]]: + """Get memory from nvidia-smi""" + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=memory.used,memory.free,memory.total", + "--format=csv,noheader,nounits", + "-i", + str(self.device), + ], + capture_output=True, + text=True, + timeout=2, + ) + + if result.returncode == 0: + used, free, total = map(int, result.stdout.strip().split(",")) + return {"used_mb": used, "free_mb": free, "total_mb": total} + except Exception as e: + logger.warning(f"Failed to get nvidia-smi memory: {e}") + + return None + + def get_pytorch_memory(self) -> Dict[str, int]: + """Get memory from PyTorch""" + stats = torch.cuda.memory_stats(self.device) + + return { + "reserved_bytes": torch.cuda.memory_reserved(self.device), + "allocated_bytes": torch.cuda.memory_allocated(self.device), + "active_bytes": stats.get("active_bytes.all.current", 0), + "inactive_bytes": stats.get("inactive_split_bytes.all.current", 0), + "peak_active_bytes": stats.get("active_bytes.all.peak", 0), + "num_alloc_retries": stats.get("num_alloc_retries.all.current", 0), + "num_ooms": stats.get("num_ooms.all.current", 0), + } + + def get_cuda_device_memory(self) -> Dict[str, int]: + """Get memory directly from CUDA device properties""" + props = torch.cuda.get_device_properties(self.device) + + return { + "total_memory": props.total_memory, + "reserved_memory": torch.cuda.memory_reserved(self.device), + "allocated_memory": torch.cuda.memory_allocated(self.device), + } + + def measure_all(self, phase: str, step: int): + """Comprehensive memory measurement""" + if not self.enabled: + return + + # PyTorch memory + pytorch_mem = self.get_pytorch_memory() + + # CUDA device memory + cuda_mem = self.get_cuda_device_memory() + + # nvidia-smi memory (if available) + smi_mem = self.get_nvidia_smi_memory() + + # Calculate fragmentation + reserved = pytorch_mem["reserved_bytes"] + allocated = pytorch_mem["allocated_bytes"] + active = pytorch_mem["active_bytes"] + + fragmentation = reserved - allocated + frag_pct = (fragmentation / reserved * 100) if reserved > 0 else 0 + + # Log PyTorch view + logger.info( + f"[PyTorch] Step {step:2d} | {phase:25s} | " + f"Reserved: {reserved/1e9:6.2f} GB | " + f"Allocated: {allocated/1e6:8.2f} MB | " + f"Active: {active/1e6:8.2f} MB | " + f"Frag: {frag_pct:5.1f}%" + ) + + # Log CUDA/nvidia-smi view + if smi_mem: + logger.info( + f"[CUDA-SMI] Step {step:2d} | {phase:25s} | " + f"Used: {smi_mem['used_mb']/1024:6.2f} GB | " + f"Free: {smi_mem['free_mb']/1024:6.2f} GB | " + f"Total: {smi_mem['total_mb']/1024:6.2f} GB" + ) + + # Log comparison + if smi_mem: + pytorch_used_gb = reserved / 1e9 + smi_used_gb = smi_mem["used_mb"] / 1024 + diff_gb = smi_used_gb - pytorch_used_gb + + logger.info( + f"[Compare] Step {step:2d} | {phase:25s} | " + f"PyTorch reports: {pytorch_used_gb:6.2f} GB | " + f"nvidia-smi reports: {smi_used_gb:6.2f} GB | " + f"Diff: {diff_gb:+6.2f} GB" + ) diff --git a/torchtitan/tools/detailed_memory_tracker.py b/torchtitan/tools/detailed_memory_tracker.py new file mode 100644 index 0000000000..7b513b3e20 --- /dev/null +++ b/torchtitan/tools/detailed_memory_tracker.py @@ -0,0 +1,160 @@ +"""Detailed memory tracking throughout training step""" +import logging +from typing import Dict, List + +import torch + +logger = logging.getLogger(__name__) + + +class DetailedMemoryTracker: + """Track memory at every phase of training with cache clearing""" + + def __init__(self, enabled: bool = True, clear_cache: bool = True): + self.enabled = enabled + self.clear_cache_between_steps = clear_cache + self.measurements: List[Dict] = [] + self.device = torch.cuda.current_device() + + if self.enabled: + logger.info(f"DetailedMemoryTracker enabled (clear_cache={clear_cache})") + + def measure(self, phase: str, step: int): + """Capture memory state at a specific phase""" + if not self.enabled: + return + + stats = torch.cuda.memory_stats(self.device) + + measurement = { + "step": step, + "phase": phase, + "reserved": torch.cuda.memory_reserved(self.device), + "allocated": torch.cuda.memory_allocated(self.device), + "active": stats.get("active_bytes.all.current", 0), + "peak_active": stats.get("active_bytes.all.peak", 0), + "num_allocs": stats.get("num_alloc_retries.all.current", 0), + } + + self.measurements.append(measurement) + + # Calculate fragmentation + fragmentation = measurement["reserved"] - measurement["allocated"] + frag_pct = ( + (fragmentation / measurement["reserved"] * 100) + if measurement["reserved"] > 0 + else 0 + ) + + logger.info( + f"[MemTrack] Step {step} | {phase:20s} | " + f"Reserved: {measurement['reserved']/1e9:6.2f} GB | " + f"Allocated: {measurement['allocated']/1e6:7.2f} MB | " + f"Active: {measurement['active']/1e6:7.2f} MB | " + f"Frag: {frag_pct:5.1f}%" + ) + + def clear_cache_and_measure(self, phase: str, step: int): + """Clear cache and measure to see minimum memory""" + if not self.enabled: + return + + # Measure before clearing + self.measure(f"{phase}_before_clear", step) + + # Clear cache + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Measure after clearing + self.measure(f"{phase}_after_clear", step) + + def step_complete(self, step: int): + """Called after each training step""" + if not self.enabled: + return + + if self.clear_cache_between_steps: + self.clear_cache_and_measure("step_end", step) + + def get_summary(self) -> str: + """Get summary of all measurements""" + if not self.measurements: + return "No measurements recorded" + + summary = ["", "=" * 100, "DETAILED MEMORY TRACKING SUMMARY", "=" * 100, ""] + + # Group by step + steps = {} + for m in self.measurements: + step = m["step"] + if step not in steps: + steps[step] = [] + steps[step].append(m) + + for step, measures in sorted(steps.items()): + summary.append(f"\nStep {step}:") + summary.append( + f"{'Phase':<30} {'Reserved':>12} {'Allocated':>12} {'Active':>12} {'Frag%':>8}" + ) + summary.append("-" * 80) + + for m in measures: + frag_pct = ( + ((m["reserved"] - m["allocated"]) / m["reserved"] * 100) + if m["reserved"] > 0 + else 0 + ) + summary.append( + f"{m['phase']:<30} " + f"{m['reserved']/1e9:10.2f} GB " + f"{m['allocated']/1e6:10.2f} MB " + f"{m['active']/1e6:10.2f} MB " + f"{frag_pct:7.1f}%" + ) + + # Peak measurements + summary.append("\n" + "=" * 100) + summary.append("PEAK MEASUREMENTS ACROSS ALL STEPS:") + summary.append("=" * 100) + + peak_reserved = max(m["reserved"] for m in self.measurements) + peak_allocated = max(m["allocated"] for m in self.measurements) + peak_active = max(m["active"] for m in self.measurements) + + peak_reserved_phase = [ + m for m in self.measurements if m["reserved"] == peak_reserved + ][0] + peak_allocated_phase = [ + m for m in self.measurements if m["allocated"] == peak_allocated + ][0] + peak_active_phase = [ + m for m in self.measurements if m["active"] == peak_active + ][0] + + summary.append( + f"Peak Reserved: {peak_reserved/1e9:7.2f} GB at Step {peak_reserved_phase['step']} ({peak_reserved_phase['phase']})" + ) + step = peak_allocated_phase["step"] + phase = peak_allocated_phase["phase"] + summary.append( + f"Peak Allocated: {peak_allocated/1e6:7.2f} MB at Step {step} ({phase})" + ) + summary.append( + f"Peak Active: {peak_active/1e6:7.2f} MB at Step {peak_active_phase['step']} ({peak_active_phase['phase']})" + ) + + # Minimum after cache clear + cleared_measures = [m for m in self.measurements if "after_clear" in m["phase"]] + if cleared_measures: + min_reserved_cleared = min(m["reserved"] for m in cleared_measures) + min_measure = [ + m for m in cleared_measures if m["reserved"] == min_reserved_cleared + ][0] + summary.append( + f"\nMinimum Reserved (after cache clear): {min_reserved_cleared/1e9:7.2f} GB at Step {min_measure['step']}" + ) + summary.append(f" Active at minimum: {min_measure['active']/1e6:7.2f} MB") + + summary.append("=" * 100) + return "\n".join(summary) diff --git a/torchtitan/train.py b/torchtitan/train.py index 15e81cd018..bb9060bcf0 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -30,6 +30,9 @@ from torchtitan.distributed.context_parallel import prepare_context_parallel_input from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils +from torchtitan.tools.aggressive_memory_manager import create_aggressive_memory_manager +from torchtitan.tools.cuda_memory_tracker import CUDAMemoryTracker +from torchtitan.tools.detailed_memory_tracker import DetailedMemoryTracker from torchtitan.tools.logging import init_logger, logger from torchtitan.tools.profiling import ( maybe_enable_memory_snapshot, @@ -106,6 +109,40 @@ def __init__(self, job_config: JobConfig): gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) + # Initialize detailed memory tracker + self.detailed_memory_tracker = DetailedMemoryTracker( + enabled=getattr( + job_config.training, "enable_detailed_memory_tracking", False + ), + clear_cache=getattr( + job_config.training, "clear_cache_between_steps", False + ), + ) + + # Initialize CUDA memory tracker + self.cuda_memory_tracker = CUDAMemoryTracker( + enabled=getattr( + job_config.training, "enable_detailed_memory_tracking", False + ), + ) + + # Initialize aggressive memory manager to reduce CUDA fragmentation + aggressive_mem_mode = getattr( + job_config.training, "aggressive_memory_mode", None + ) + if aggressive_mem_mode: + self.aggressive_mem_manager = create_aggressive_memory_manager( + mode=aggressive_mem_mode, + verbose=getattr( + job_config.training, "aggressive_memory_verbose", False + ), + ) + logger.info( + f"Aggressive memory manager enabled (mode={aggressive_mem_mode})" + ) + else: + self.aggressive_mem_manager = None + # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( @@ -690,11 +727,30 @@ def forward_backward_step( del pred loss.backward() + # Aggressive memory clearing after backward to reduce fragmentation + if self.aggressive_mem_manager is not None: + self.aggressive_mem_manager.post_backward() + return loss def train_step( self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] ): + # AGGRESSIVE cache clearing before step for accurate memory measurements + if self.job_config.training.aggressive_memory_mode: + import gc + + torch.cuda.synchronize() + gc.collect(0) + gc.collect(1) + gc.collect(2) + torch.cuda.empty_cache() + torch.cuda.synchronize() + gc.collect(2) + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + self.metrics_processor.device_memory_monitor.reset_peak_stats() + self.optimizers.zero_grad() # Save the current step learning rate for logging lr = self.lr_schedulers.schedulers[0].get_last_lr()[0] @@ -703,6 +759,10 @@ def train_step( # the major variables that are used in the training loop. parallel_dims = self.parallel_dims + # Track memory before forward pass + self.detailed_memory_tracker.measure("before_forward", self.step) + self.cuda_memory_tracker.measure_all("before_forward", self.step) + accumulated_losses = [] # If data runs out during gradient accumulation, that # entire step will not be executed. @@ -712,6 +772,10 @@ def train_step( loss = self.forward_backward_step(input_dict, labels) accumulated_losses.append(loss.detach()) + # Track memory after forward/backward + self.detailed_memory_tracker.measure("after_forward_backward", self.step) + self.cuda_memory_tracker.measure_all("after_forward_backward", self.step) + grad_norm = dist_utils.clip_grad_norm_( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, @@ -720,8 +784,39 @@ def train_step( ep_enabled=parallel_dims.ep_enabled, ) self.checkpointer.maybe_wait_for_staging() - self.optimizers.step() - self.lr_schedulers.step() + + # Skip optimizer step if configured (for memory profiling) + if not self.job_config.training.skip_optimizer_step: + import datetime + import time as _time + + # Log step start with timestamp for correlation + if self.device.index == 0: + _ts = datetime.datetime.now().strftime("%H:%M:%S") + logger.info(f"[STEP {self.step}] optimizer.step() START @ {_ts}") + + _optim_start = _time.time() + self.optimizers.step() + _optim_elapsed = _time.time() - _optim_start + + # Aggressive memory clearing after optimizer to reduce fragmentation + if self.aggressive_mem_manager is not None: + self.aggressive_mem_manager.post_optimizer() + + # Log step end with timing + if self.device.index == 0: + _ts = datetime.datetime.now().strftime("%H:%M:%S") + logger.info( + f"[STEP {self.step}] optimizer.step() END @ {_ts} | Duration: {_optim_elapsed:.2f}s" + ) + + self.lr_schedulers.step() + else: + logger.info("Skipping optimizer step (skip_optimizer_step=True)") + + # Track memory after optimizer step + self.detailed_memory_tracker.measure("after_optimizer", self.step) + self.cuda_memory_tracker.measure_all("after_optimizer", self.step) # Reduce the data collected over gradient accumulation steps. loss = torch.sum(torch.stack(accumulated_losses)) @@ -762,11 +857,25 @@ def train_step( extra_metrics=extra_metrics, ) + # Signal step complete to aggressive memory manager (triggers defrag check) + if self.aggressive_mem_manager is not None: + self.aggressive_mem_manager.step_complete() + @record def train(self): job_config = self.job_config self.checkpointer.load(step=job_config.checkpoint.load_step) + + # Pre-initialize bf16 optimizer states if configured + # This must happen BEFORE training to avoid rank skew during first step + if hasattr(self.optimizers, "init_bf16_states"): + self.optimizers.init_bf16_states() + # Barrier to ensure all ranks finish before training starts + if torch.distributed.is_initialized(): + torch.distributed.barrier() + logger.info("All ranks synchronized after bf16 optimizer state init") + logger.info(f"Training starts at step {self.step + 1}") leaf_folder = ( @@ -842,6 +951,10 @@ def train(self): if memory_profiler: memory_profiler.step() + # Track memory at step end and optionally clear cache + self.detailed_memory_tracker.step_complete(self.step) + self.cuda_memory_tracker.measure_all("step_end", self.step) + # reduce timeout after first train step for faster signal # (assuming lazy init and compilation are finished) if self.step == 1: @@ -856,6 +969,10 @@ def train(self): logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) + # Log detailed memory tracking summary + if torch.distributed.get_rank() == 0: + logger.info(self.detailed_memory_tracker.get_summary()) + logger.info("Training completed") def should_continue_training(self) -> bool: From 40714548492658fdfa918d181f42442cdcb82004 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sat, 31 Jan 2026 08:56:02 -0800 Subject: [PATCH 121/127] Add NaN tracker config, FSDP prefetch control, and nvidia-smi memory reporting Config Options Added: [parallelism] - fsdp_disable_prefetch: Disable FSDP forward/backward prefetching (reduces memory at cost of less communication overlap) [debug] - enable_nan_tracker: Enable lightweight NaN/Inf tracking to find where NaN first appears in the model - nan_tracker_verbose: Print stats for every layer (very verbose) Enhanced Metrics: - Add nvidia-smi memory reporting to DeviceMemStats for verification - Add _get_nvidia_smi_memory() method to DeviceMemoryMonitor - Handles CUDA_VISIBLE_DEVICES remapping for SLURM environments FSDP Prefetch Control: - Add disable_prefetch parameter to apply_fsdp() in llama4 - Wire up fsdp_disable_prefetch config to apply_fsdp calls in: llama4, deepseek_v3, qwen3, gpt_oss Test Configs: - qwen3_30b_a3b_memory_test.toml: Added [debug] section with nan_tracker Note: fsdp_bucket_cap_mb not added as it's not supported by FSDP2 API --- qwen3_30b_a3b_memory_test.toml | 5 ++ torchtitan/components/metrics.py | 49 +++++++++++++++++++ torchtitan/config/job_config.py | 12 +++++ .../models/deepseek_v3/infra/parallelize.py | 1 + .../models/gpt_oss/infra/parallelize.py | 1 + torchtitan/models/llama4/infra/parallelize.py | 8 +++ torchtitan/models/qwen3/infra/parallelize.py | 1 + 7 files changed, 77 insertions(+) diff --git a/qwen3_30b_a3b_memory_test.toml b/qwen3_30b_a3b_memory_test.toml index 7114927040..aba127bbd9 100644 --- a/qwen3_30b_a3b_memory_test.toml +++ b/qwen3_30b_a3b_memory_test.toml @@ -71,3 +71,8 @@ selective_ac_option = "op" [compile] enable = true components = ["loss"] + +[debug] +# Test NaN tracker +enable_nan_tracker = true +nan_tracker_verbose = false diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index d01d98f847..b3e4499329 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -36,6 +36,8 @@ "max_reserved_pct", "num_alloc_retries", "num_ooms", + "nvidia_smi_used_gib", # nvidia-smi reported memory for verification + "nvidia_smi_used_pct", ], ) @@ -63,6 +65,48 @@ def _to_gib(self, memory_in_bytes): def _to_pct(self, memory): return 100 * memory / self.device_capacity + def _get_nvidia_smi_memory(self): + """Get GPU memory usage from nvidia-smi for verification.""" + try: + import subprocess + + # In SLURM with CUDA_VISIBLE_DEVICES, PyTorch device index 0-7 maps to + # physical GPUs listed in CUDA_VISIBLE_DEVICES. We need the physical GPU index. + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if cuda_visible: + # CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" means device 0 is physical GPU 0 + # But it could also be "4,5,6,7,0,1,2,3" meaning device 0 is physical GPU 4 + visible_gpus = [ + int(x.strip()) for x in cuda_visible.split(",") if x.strip() + ] + if self.device_index < len(visible_gpus): + physical_gpu_index = visible_gpus[self.device_index] + else: + physical_gpu_index = self.device_index + else: + physical_gpu_index = self.device_index + + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=memory.used", + "--format=csv,noheader,nounits", + f"--id={physical_gpu_index}", + ], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + # nvidia-smi reports in MiB + used_mib = float(result.stdout.strip()) + used_gib = used_mib / 1024 + used_pct = (used_mib * 1024 * 1024) / self.device_capacity * 100 + return used_gib, used_pct + except Exception: + pass + return -1.0, -1.0 + def get_peak_stats(self): device_info = device_module.memory_stats(self.device) @@ -84,6 +128,9 @@ def get_peak_stats(self): if num_ooms > 0: logger.warning(f"{num_ooms} {device_type.upper()} OOM errors thrown.") + # Get nvidia-smi memory for verification + nvidia_smi_gib, nvidia_smi_pct = self._get_nvidia_smi_memory() + return DeviceMemStats( max_active_gib, max_active_pct, @@ -91,6 +138,8 @@ def get_peak_stats(self): max_reserved_pct, num_retries, num_ooms, + nvidia_smi_gib, + nvidia_smi_pct, ) def reset_peak_stats(self): diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 7a6a9ffb33..cc1e587598 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -521,6 +521,12 @@ class Parallelism: - "never" will disable `reshard_after_forward` for all forward passes. """ + fsdp_disable_prefetch: bool = False + """ + Whether to disable FSDP forward/backward prefetching. Disabling prefetch can reduce memory + at the cost of performance (less overlap of communication and computation). + """ + tensor_parallel_degree: int = 1 """Tensor Parallelism degree. 1 means disabled.""" @@ -1337,6 +1343,12 @@ class Debug: moe_force_load_balance: bool = False """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only.""" + enable_nan_tracker: bool = False + """If True, enable lightweight NaN/Inf tracking to find where NaN first appears in the model.""" + + nan_tracker_verbose: bool = False + """If True, print stats for every layer (very verbose output).""" + @dataclass class JobConfig: diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index da81d60c0c..05c345adab 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -185,6 +185,7 @@ def parallelize_deepseekv3( ep_degree=parallel_dims.ep, edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, + disable_prefetch=job_config.parallelism.fsdp_disable_prefetch, ) if parallel_dims.dp_replicate_enabled: diff --git a/torchtitan/models/gpt_oss/infra/parallelize.py b/torchtitan/models/gpt_oss/infra/parallelize.py index 338092fb7a..80a8bc8bc2 100644 --- a/torchtitan/models/gpt_oss/infra/parallelize.py +++ b/torchtitan/models/gpt_oss/infra/parallelize.py @@ -147,6 +147,7 @@ def parallelize_gptoss( reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, edp_mesh=edp_mesh, + disable_prefetch=job_config.parallelism.fsdp_disable_prefetch, ) if parallel_dims.dp_replicate_enabled: diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 5634070a62..b56e173ba8 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -215,6 +215,7 @@ def parallelize_llama( ep_degree=parallel_dims.ep, edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, + disable_prefetch=job_config.parallelism.fsdp_disable_prefetch, ) if parallel_dims.dp_replicate_enabled: @@ -343,6 +344,7 @@ def apply_fsdp( ep_degree: int = 1, edp_mesh: DeviceMesh | None = None, gradient_divide_factor: int | None = None, + disable_prefetch: bool = False, ): """ Apply data parallelism (via FSDP2) to the model. @@ -359,6 +361,7 @@ def apply_fsdp( - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. - "always" will enable `reshard_after_forward` for all forward passes. - "never" will disable `reshard_after_forward` for all forward passes. + disable_prefetch (bool, optional): Whether to disable FSDP forward/backward prefetching. Defaults to False. """ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) @@ -505,6 +508,11 @@ def apply_fsdp( if ep_degree == 1: return + # Skip prefetch setup if disabled + if disable_prefetch: + logger.info("FSDP prefetching is disabled") + return + # forward # pyrefly: ignore [not-callable] transformer_blocks = list(model.layers.values()) diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index fd621e9883..3dad54d46a 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -163,6 +163,7 @@ def parallelize_qwen3( ep_degree=parallel_dims.ep, edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, + disable_prefetch=job_config.parallelism.fsdp_disable_prefetch, ) if parallel_dims.dp_replicate_enabled: From 375762b932dfac54d1d98928bc1c041fdcb7edf2 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sat, 31 Jan 2026 09:05:24 -0800 Subject: [PATCH 122/127] Add partial resharding support (fsdp_reshard_after_forward accepts int) Allow fsdp_reshard_after_forward to accept an integer N for partial resharding to N-GPU groups. This reduces peak memory by limiting all-gather buffer size to N GPUs instead of full DP world. Use N=8 for intra-node resharding (fast NVLink communication). N must be a factor of the FSDP shard world size. Example: fsdp_reshard_after_forward = 8 Changes: - config/job_config.py: Update type to Literal[...] | int - llama4/parallelize.py: Handle int values in apply_fsdp() --- torchtitan/config/job_config.py | 7 ++- torchtitan/models/llama4/infra/parallelize.py | 43 ++++++++++++------- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index cc1e587598..3eb2c20238 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -506,19 +506,22 @@ class Parallelism: only `data_parallel_shard_degree` can be negative. 1 means disabled. """ - fsdp_reshard_after_forward: Literal["default", "always", "never"] = "default" + fsdp_reshard_after_forward: Literal["default", "always", "never"] | int = "default" """ `reshard_after_forward` specifies the policy for applying `reshard_after_forward` within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward, trading off memory and communication. See torch's `fully_shard` API for more documentation on `reshard_after_forward`. - The supported policies include "default", "always" and "never": + The supported policies include "default", "always", "never", or an integer: - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. - "always" will enable `reshard_after_forward` for all forward passes. - "never" will disable `reshard_after_forward` for all forward passes. + - integer N: Partially reshard to groups of N GPUs after forward. Must be a factor of + the FSDP shard world size. Use N=8 for intra-node resharding (reduces memory while + keeping communication fast via NVLink). This trades memory for communication. """ fsdp_disable_prefetch: bool = False diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index b56e173ba8..31b49af3fc 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -340,7 +340,7 @@ def apply_fsdp( reduce_dtype: torch.dtype, pp_enabled: bool, cpu_offload: bool = False, - reshard_after_forward_policy: str = "default", + reshard_after_forward_policy: str | int = "default", ep_degree: int = 1, edp_mesh: DeviceMesh | None = None, gradient_divide_factor: int | None = None, @@ -356,11 +356,15 @@ def apply_fsdp( reduce_dtype (torch.dtype): The data type to use for reduction operations. pp_enabled (bool): Whether pipeline parallelism is enabled. cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. - reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". - Other options: "never", "always". + reshard_after_forward_policy (str | int, optional): The policy to use for resharding after forward pass. Defaults to "default". + String options: "never", "always", "default". - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. - "always" will enable `reshard_after_forward` for all forward passes. - "never" will disable `reshard_after_forward` for all forward passes. + Integer option: N (e.g., 8) for partial resharding to N-GPU groups. + - Reduces peak memory by limiting all-gather buffer size to N GPUs instead of full DP world. + - Use N=8 for intra-node resharding (fast NVLink communication). + - N must be a factor of the FSDP shard world size. disable_prefetch (bool, optional): Whether to disable FSDP forward/backward prefetching. Defaults to False. """ @@ -369,19 +373,26 @@ def apply_fsdp( if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() - match reshard_after_forward_policy: - case "always": - reshard_after_forward = True - case "never": - reshard_after_forward = False - case "default": - # For PP, by default do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = not pp_enabled - case _: - raise ValueError( - f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." - ) + # Handle integer reshard_after_forward (partial resharding to N-GPU groups) + if isinstance(reshard_after_forward_policy, int): + reshard_after_forward = reshard_after_forward_policy + logger.info( + f"Using partial reshard_after_forward={reshard_after_forward} (resharding to {reshard_after_forward}-GPU groups)" + ) + else: + match reshard_after_forward_policy: + case "always": + reshard_after_forward = True + case "never": + reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = not pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) if model.tok_embeddings is not None: # pyrefly: ignore [no-matching-overload] From 0a064294718612513dcce4ca7c402cfa905514fb Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sat, 31 Jan 2026 09:29:59 -0800 Subject: [PATCH 123/127] Add device mesh visualizer for distributed training debugging Standalone diagnostic tool that visualizes GPU allocation across parallelism dimensions (DP, PP, TP, CP, EP). Only runs on rank 0 at initialization - no impact on training performance. Features: - Mesh structure visualization - GPU allocation grid - Expert parallel group allocation - Context parallel group allocation - FSDP sharding visualization Integrated into train.py to automatically log at startup. --- torchtitan/tools/mesh_visualizer.py | 415 ++++++++++++++++++++++++++++ torchtitan/train.py | 4 + 2 files changed, 419 insertions(+) create mode 100644 torchtitan/tools/mesh_visualizer.py diff --git a/torchtitan/tools/mesh_visualizer.py b/torchtitan/tools/mesh_visualizer.py new file mode 100644 index 0000000000..0ba8fecb03 --- /dev/null +++ b/torchtitan/tools/mesh_visualizer.py @@ -0,0 +1,415 @@ +""" +Device Mesh Visualizer for Distributed Training + +Creates comprehensive visualization of how GPUs are allocated across +all parallelism dimensions: DP, PP, TP, CP, EP. +""" + +import os +from typing import Dict + +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh + +from torchtitan.tools.logging import logger + + +def get_rank_info() -> Dict: + """Get current rank's information across all process groups.""" + info = { + "global_rank": dist.get_rank() if dist.is_initialized() else 0, + "world_size": dist.get_world_size() if dist.is_initialized() else 1, + "local_rank": int(os.environ.get("LOCAL_RANK", 0)), + "node_rank": int(os.environ.get("GROUP_RANK", os.environ.get("NODE_RANK", 0))), + } + return info + + +def visualize_mesh_structure( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Create a detailed text visualization of the device mesh structure. + + Args: + mesh: The DeviceMesh object + parallel_dims: ParallelDims object with all parallelism settings + rank: Current rank (only rank 0 prints full visualization) + + Returns: + String visualization of the mesh + """ + lines = [] + lines.append("=" * 100) + lines.append("DEVICE MESH VISUALIZATION") + lines.append("=" * 100) + + # Basic info + lines.append("\n[CLUSTER INFO]") + lines.append(f" Total GPUs: {parallel_dims.world_size}") + lines.append(f" Nodes: {parallel_dims.world_size // 8} (assuming 8 GPUs/node)") + + # Parallelism dimensions + lines.append("\n[PARALLELISM DIMENSIONS]") + lines.append(f" DP Replicate (HSDP): {parallel_dims.dp_replicate}") + lines.append(f" DP Shard (FSDP): {parallel_dims.dp_shard}") + lines.append(f" Context Parallel: {parallel_dims.cp}") + lines.append(f" Tensor Parallel: {parallel_dims.tp}") + lines.append(f" Pipeline Parallel: {parallel_dims.pp}") + lines.append(f" Expert Parallel: {parallel_dims.ep}") + lines.append(f" Expert TP: {parallel_dims.etp}") + + # Mesh structure + lines.append("\n[MESH STRUCTURE]") + lines.append(f" Mesh dim names: {mesh.mesh_dim_names}") + lines.append(f" Mesh shape: {mesh.mesh.shape}") + + # Log each dimension + for i, (name, size) in enumerate(zip(mesh.mesh_dim_names, mesh.mesh.shape)): + lines.append(f" Dim {i}: {name:20s} = {size}") + + # EP-specific derived dimensions + if parallel_dims.ep > 1: + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = ( + parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + lines.append("\n[EXPERT PARALLEL DERIVED DIMENSIONS]") + lines.append(f" dp_shard_mod_ep (DP for non-experts): {dp_shard_mod_ep}") + lines.append(f" dp_shard_in_ep (DP within EP group): {dp_shard_in_ep}") + lines.append(f" ep_group_size (EP degree): {parallel_dims.ep}") + lines.append("") + lines.append(" Formula: dp_shard = dp_shard_mod_ep * dp_shard_in_ep") + lines.append( + f" {parallel_dims.dp_shard} = {dp_shard_mod_ep} * {dp_shard_in_ep}" + ) + lines.append("") + lines.append(" Formula: ep = dp_shard_in_ep * cp") + lines.append( + f" {parallel_dims.ep} = {dp_shard_in_ep} * {parallel_dims.cp}" + ) + + # Submesh info + lines.append("\n[SUBMESHES]") + + # Try to get submesh info + submesh_names = ["dp", "dp_shard_cp", "dp_cp", "ep", "cp", "tp", "pp"] + for name in submesh_names: + try: + submesh = mesh[name] + lines.append( + f" {name:15s}: size={submesh.size():4d}, dim_names={submesh.mesh_dim_names}" + ) + except (KeyError, RuntimeError): + pass + + return "\n".join(lines) + + +def visualize_gpu_allocation( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Create a grid visualization showing GPU allocation. + + For 16 nodes (128 GPUs) with EP=64, CP=8: + - Shows how each GPU maps to (dp_shard_mod_ep, dp_shard_in_ep, cp) coordinates + """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("GPU ALLOCATION GRID") + lines.append("=" * 100) + + world_size = parallel_dims.world_size + num_nodes = world_size // 8 + + # For EP-enabled config + if parallel_dims.ep > 1: + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = ( + parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + lines.append( + f"\nMesh: [{dp_shard_mod_ep}] x [{dp_shard_in_ep}] x [{parallel_dims.cp}] = {world_size} GPUs" + ) + lines.append(" [dp_shard_mod_ep] x [dp_shard_in_ep] x [cp]") + lines.append("") + + # Create mapping from global rank to mesh coordinates + lines.append("GPU -> Mesh Coordinate Mapping:") + lines.append("-" * 80) + lines.append( + f"{'Node':>6} | {'GPU':>4} | {'Rank':>5} | {'dp_mod_ep':>10} | {'dp_in_ep':>10} | {'cp':>4} | {'EP Group':>10}" + ) + lines.append("-" * 80) + + # The mesh is laid out as: dp_shard_mod_ep (slowest) x dp_shard_in_ep x cp (fastest) + for node in range(num_nodes): + for local_gpu in range(8): + global_rank = node * 8 + local_gpu + + # Compute mesh coordinates (assuming row-major ordering) + # Total size = dp_shard_mod_ep * dp_shard_in_ep * cp + cp_coord = global_rank % parallel_dims.cp + dp_in_ep_coord = (global_rank // parallel_dims.cp) % dp_shard_in_ep + dp_mod_ep_coord = global_rank // (parallel_dims.cp * dp_shard_in_ep) + + # EP group = dp_in_ep_coord * cp + cp_coord (within each dp_shard_mod_ep group) + ep_group = dp_in_ep_coord * parallel_dims.cp + cp_coord + + row = ( + f"{node:>6} | {local_gpu:>4} | {global_rank:>5} | " + f"{dp_mod_ep_coord:>10} | {dp_in_ep_coord:>10} | " + f"{cp_coord:>4} | {ep_group:>10}" + ) + lines.append(row) + + if node < num_nodes - 1: + lines.append("-" * 80) + else: + lines.append( + f"\nMesh: [{parallel_dims.dp_shard}] x [{parallel_dims.cp}] = {world_size} GPUs" + ) + lines.append(" [dp_shard] x [cp]") + + return "\n".join(lines) + + +def visualize_expert_parallel_groups( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Visualize which GPUs belong to which Expert Parallel group. + """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("EXPERT PARALLEL GROUP ALLOCATION") + lines.append("=" * 100) + + if parallel_dims.ep <= 1: + lines.append("Expert Parallel is disabled (EP=1)") + return "\n".join(lines) + + world_size = parallel_dims.world_size + + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + lines.append(f"\nEP={parallel_dims.ep} experts distributed across GPUs") + lines.append( + f"Each EP group has {parallel_dims.ep} GPUs working on different experts" + ) + lines.append( + f"There are {dp_shard_mod_ep} such EP groups (for FSDP replication of experts)" + ) + lines.append("") + + # Group GPUs by their dp_shard_mod_ep coordinate + lines.append("EP Groups (GPUs that share the same set of experts):") + lines.append("-" * 80) + + for dp_mod_ep_idx in range(dp_shard_mod_ep): + # Find all ranks in this dp_shard_mod_ep group + ranks_in_group = [] + for global_rank in range(world_size): + dp_mod_ep_coord = global_rank // (parallel_dims.cp * dp_shard_in_ep) + if dp_mod_ep_coord == dp_mod_ep_idx: + ranks_in_group.append(global_rank) + + lines.append(f"\nDP_SHARD_MOD_EP group {dp_mod_ep_idx}:") + lines.append( + f" GPUs: {ranks_in_group[:16]}{'...' if len(ranks_in_group) > 16 else ''}" + ) + lines.append(f" Total: {len(ranks_in_group)} GPUs") + lines.append(" These GPUs have IDENTICAL expert parameters (FSDP sharded)") + + return "\n".join(lines) + + +def visualize_context_parallel_groups( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Visualize Context Parallel groups - GPUs that work on different parts of the sequence. + """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("CONTEXT PARALLEL GROUP ALLOCATION") + lines.append("=" * 100) + + if parallel_dims.cp <= 1: + lines.append("Context Parallel is disabled (CP=1)") + return "\n".join(lines) + + world_size = parallel_dims.world_size + cp = parallel_dims.cp + + lines.append(f"\nCP={cp} - Each sequence is split into {cp} chunks") + lines.append( + "GPUs with the same (dp_shard, ep) coordinates but different cp coordinates" + ) + lines.append("work on different parts of the same sequence.") + lines.append("") + + # Show a few example CP groups + lines.append("Example CP groups (first few):") + lines.append("-" * 80) + + num_cp_groups = world_size // cp + for cp_group_idx in range(min(4, num_cp_groups)): + ranks_in_group = [cp_group_idx * cp + i for i in range(cp)] + lines.append(f"\nCP group {cp_group_idx}:") + lines.append(f" GPUs: {ranks_in_group}") + lines.append(f" These {cp} GPUs process different chunks of the same sequence") + + if num_cp_groups > 4: + lines.append(f"\n... and {num_cp_groups - 4} more CP groups") + + return "\n".join(lines) + + +def visualize_fsdp_sharding( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Visualize FSDP sharding - which GPUs share which parameters. + """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("FSDP SHARDING VISUALIZATION") + lines.append("=" * 100) + + if parallel_dims.ep > 1: + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = ( + parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + dp_shard_cp_size = parallel_dims.dp_shard * parallel_dims.cp + + lines.append("\n[NON-EXPERT PARAMETERS (Attention, Embeddings, etc.)]") + lines.append(" FSDP mesh: dp_shard_cp") + lines.append(f" FSDP group size: {dp_shard_cp_size} GPUs") + lines.append(f" Each parameter is sharded across {dp_shard_cp_size} GPUs") + lines.append( + f" All-gather buffer size per param: original_size / {dp_shard_cp_size}" + ) + + lines.append("\n[EXPERT PARAMETERS (MoE experts)]") + lines.append(" FSDP mesh: dp_shard_mod_ep") + lines.append(f" FSDP group size: {dp_shard_mod_ep} GPUs") + lines.append( + f" Each expert's parameters are sharded across {dp_shard_mod_ep} GPUs" + ) + lines.append( + f" All-gather buffer size per expert param: original_size / {dp_shard_mod_ep}" + ) + + lines.append("\n[MEMORY IMPLICATIONS]") + lines.append( + f" Non-expert params: sharded {dp_shard_cp_size}x -> small per-GPU footprint" + ) + lines.append( + f" Expert params: sharded only {dp_shard_mod_ep}x -> larger per-GPU footprint" + ) + lines.append(" ") + lines.append(" As DP increases:") + lines.append( + " - dp_shard_cp increases -> non-expert params get more sharded" + ) + lines.append( + " - dp_shard_mod_ep increases -> expert params get more sharded" + ) + lines.append( + " - BUT: all-gather/reduce-scatter buffers scale with group size!" + ) + + else: + dp_shard_cp_size = parallel_dims.dp_shard * parallel_dims.cp + lines.append("\n[ALL PARAMETERS]") + lines.append(" FSDP mesh: dp_shard_cp") + lines.append(f" FSDP group size: {dp_shard_cp_size} GPUs") + lines.append(f" Each parameter is sharded across {dp_shard_cp_size} GPUs") + + return "\n".join(lines) + + +def create_full_visualization( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """Create a comprehensive visualization of the entire mesh structure.""" + parts = [ + visualize_mesh_structure(mesh, parallel_dims, rank), + visualize_gpu_allocation(mesh, parallel_dims, rank), + visualize_expert_parallel_groups(mesh, parallel_dims, rank), + visualize_context_parallel_groups(mesh, parallel_dims, rank), + visualize_fsdp_sharding(mesh, parallel_dims, rank), + ] + + full_viz = "\n".join(parts) + full_viz += "\n" + "=" * 100 + full_viz += "\nEND OF DEVICE MESH VISUALIZATION" + full_viz += "\n" + "=" * 100 + + return full_viz + + +def log_mesh_visualization(mesh: DeviceMesh, parallel_dims): + """Log the full mesh visualization (only on rank 0).""" + rank = dist.get_rank() if dist.is_initialized() else 0 + + if rank == 0: + viz = create_full_visualization(mesh, parallel_dims, rank) + # Log each line separately for better formatting + for line in viz.split("\n"): + logger.info(f"[MESH-VIZ] {line}") diff --git a/torchtitan/train.py b/torchtitan/train.py index bb9060bcf0..2ff6b3a49c 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -34,6 +34,7 @@ from torchtitan.tools.cuda_memory_tracker import CUDAMemoryTracker from torchtitan.tools.detailed_memory_tracker import DetailedMemoryTracker from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.mesh_visualizer import log_mesh_visualization from torchtitan.tools.profiling import ( maybe_enable_memory_snapshot, maybe_enable_profiling, @@ -94,6 +95,9 @@ def __init__(self, job_config: JobConfig): # init distributed and build meshes self.parallel_dims = parallel_dims = self.init_distributed() + # Log mesh visualization for debugging distributed setup (rank 0 only) + log_mesh_visualization(parallel_dims.world_mesh, parallel_dims) + if parallel_dims.dp_enabled: batch_mesh = parallel_dims.get_mesh("batch") batch_degree, batch_rank = batch_mesh.size(), batch_mesh.get_local_rank() From fe8d1f0c858956187c90500ff9ab7cb3fd9a114b Mon Sep 17 00:00:00 2001 From: emozilla Date: Sun, 1 Feb 2026 03:56:54 +0000 Subject: [PATCH 124/127] add option to filter data when preprocessing by a specific string --- scripts/preprocess_data.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/scripts/preprocess_data.py b/scripts/preprocess_data.py index 6184634a8a..7b85cb1efb 100644 --- a/scripts/preprocess_data.py +++ b/scripts/preprocess_data.py @@ -709,7 +709,29 @@ def _get_conversation_len(x): return len(x["messages"]) return 0 + len_before = len(dataset) dataset = dataset.filter(lambda x: _get_conversation_len(x) > 3) + print(f"Filtered by multiturn: {len_before} -> {len(dataset)} samples") + if args.required_text: + def _contains_required_text(x): + if args.chat: + if "conversations" in x: + messages = x["conversations"] + elif "messages" in x: + messages = x["messages"] + else: + return False + for message in messages: + content = message.get("content") or message.get("value") or "" + if args.required_text in content: + return True + return False + else: + return args.required_text in x.get("text", "") + + len_before = len(dataset) + dataset = dataset.filter(_contains_required_text) + print(f"Filtered by required_text '{args.required_text}': {len_before} -> {len(dataset)} samples") original_column_names = list(dataset.features.keys()) dataset = dataset.map( @@ -960,6 +982,7 @@ def _add_position_ids_and_seq_lengths(sample): parser.add_argument("--limit", type=int) parser.add_argument("--chat", action="store_true") parser.add_argument("--multiturn-only", action="store_true") + parser.add_argument("--required-text", type=str) parser.add_argument("--pack-to-sequence-length", type=int) parser.add_argument( "--epochs", From f50b8045a6c6ef4f2b4d3fcaea87371eb2c20749 Mon Sep 17 00:00:00 2001 From: emozilla Date: Sun, 1 Feb 2026 03:57:10 +0000 Subject: [PATCH 125/127] add kimi_k2_sft --- torchtitan/models/deepseek_v3/__init__.py | 30 +++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 9142acc839..d6ae7e2017 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -184,6 +184,36 @@ rope_factor=32.0, beta_fast=1, ), + "kimi_k2_sft": DeepSeekV3ModelArgs( + vocab_size=163840, + dim=7168, + inter_dim=18432, + moe_inter_dim=2048, + # n_layers=9, #smaller for testing + n_layers=61, + n_dense_layers=1, + n_heads=64, + norm_eps=1e-6, + moe_args=MoEArgs( + num_experts=384, + num_shared_experts=1, + top_k=8, + score_func="sigmoid", + route_norm=True, + route_scale=2.827, + score_before_experts=False, + ), + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + attn_type="flex", + attn_mask_type="block_causal_by_sequence_lengths", + rope_theta=50000.0, + rope_factor=32.0, + beta_fast=1, + ), } From ed6b7536c1bc3b5558516a44db23d4521ef6803b Mon Sep 17 00:00:00 2001 From: emozilla Date: Sun, 1 Feb 2026 05:16:12 +0000 Subject: [PATCH 126/127] fix wrong arg used for --push-to-hub --- scripts/preprocess_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/preprocess_data.py b/scripts/preprocess_data.py index 7b85cb1efb..aca7d0cd47 100644 --- a/scripts/preprocess_data.py +++ b/scripts/preprocess_data.py @@ -936,7 +936,7 @@ def _add_position_ids_and_seq_lengths(sample): dataset.save_to_disk(args.save_to_disk) if args.push_to_hub: print(f"Pushing to Hugging Face repo {args.push_to_hub}") - dataset.push_to_hub(args.save_to_disk, private=True) + dataset.push_to_hub(args.push_to_hub, private=True) example = dataset[0] From 7f6f3a3b70feeda7907fa737d216927d9da714f3 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 4 Feb 2026 10:53:48 -0800 Subject: [PATCH 127/127] fix attention args, add kimi_k2_ep64_cp1_seq24k_lbs1 160 tps config --- torchtitan/models/attention.py | 6 +- .../kimi_k2_ep64_cp1_seq24k_lbs1.toml | 74 +++++++++++++++++++ 2 files changed, 77 insertions(+), 3 deletions(-) create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_k2_ep64_cp1_seq24k_lbs1.toml diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index fd399a55f5..2493323ab7 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -110,8 +110,9 @@ class FlexAttentionWrapper(torch.nn.Module): 2) Being a wrapper allows us to apply _ContextParallel to it. Note: - The forward function must have q, k, v as the first three arguments, and - block_mask as a keyword argument to be compatible with _ContextParallel. + The forward function accepts q, k, v as the first three arguments, followed by + optional arguments (score_mod, block_mask, scale, return_lse) that can be passed + either positionally or as keywords to be compatible with _ContextParallel. """ _compiled_flex_attn: ClassVar[Callable] = torch.compile( @@ -130,7 +131,6 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - *, score_mod: _score_mod_signature | None = None, block_mask: BlockMask | None = None, scale: float | None = None, diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_k2_ep64_cp1_seq24k_lbs1.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_ep64_cp1_seq24k_lbs1.toml new file mode 100644 index 0000000000..f6328e98b5 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_ep64_cp1_seq24k_lbs1.toml @@ -0,0 +1,74 @@ +# ============================================================================= +# Kimi K2 - Best Configuration (D4) +# ============================================================================= +# Run: D4 | Job ID: 2889 +# Performance: 160 TPS (BEST OVERALL) | MFU: 6.04% | Memory: 76.99GiB (97.06%) +# +# Parameters: +# - EP=64, CP=1, SEQ=24K (24576), LBS=1 +# - Nodes: 8 (64 GPUs) +# - DP_replicate=1, DP_shard=1 (auto-calculated) +# +# Source Files: +# - Config: /home/phuc/worklogs/2026-02-03/ep_cp_sweep/configs/D4_ep64_seq24k_lbs1.toml +# - SLURM: /home/phuc/worklogs/2026-02-03/ep_cp_sweep/scripts/launch_D4_ep64_seq24k_lbs1.slurm +# - Log: /home/phuc/worklogs/2026-02-03/ep_cp_sweep/results/D4_ep64_seq24k_lbs1_2889.out +# - Err: /home/phuc/worklogs/2026-02-03/ep_cp_sweep/results/D4_ep64_seq24k_lbs1_2889.err +# +# Reference: /home/phuc/worklogs/2026-02-03/sweep_ep_cp_upstream_branch.md +# ============================================================================= + +[job] +dump_folder = "./outputs/ep_cp_sweep/D4_ep64_seq24k_lbs1" +description = "EP/CP Sweep D4: EP=64 CP=1 SEQ=24K LBS=1" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "/home/phuc/kimi_1t/torchtitan/assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +state_dtype = "bfloat16" + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +dtype = "bfloat16" +local_batch_size = 1 +seq_len = 24576 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = false +aggressive_memory_mode = "maximum" +aggressive_memory_verbose = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 1 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = true +components = ["loss"] + +[debug] +moe_force_load_balance = true + +[comm] +init_timeout_seconds = 1800 +train_timeout_seconds = 1800