Commit 83c87b9

feat: drafted pp e2e test for fwd/bwd pass

1 parent d9f63c1

File tree

3 files changed: +275 -0 lines changed

tests/fsdp2_parallelization/pipeline_parallelism/__init__.py

Whitespace-only changes.
tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml

Lines changed: 171 additions & 0 deletions
settings:
  experiment_id: ${modalities_env:experiment_id}
  config_file_path: ${modalities_env:config_file_path}
  referencing_keys:
    sample_key: input_ids
    target_key: target_ids
    prediction_key: logits
  cuda_env:
    local_rank: ${cuda_env:LOCAL_RANK}
    global_rank: ${cuda_env:RANK}
    world_size: ${cuda_env:WORLD_SIZE}
  step_profile:
    gradient_accumulation_steps: 1
    local_train_micro_batch_size: 2
    sequence_length: 256

loss_fn:
  component_key: loss
  variant_key: clm_cross_entropy_loss
  config:
    target_key: ${settings.referencing_keys.target_key}
    prediction_key: ${settings.referencing_keys.prediction_key}

device_mesh:
  component_key: device_mesh
  variant_key: default
  config:
    device_type: cuda
    data_parallel_replicate_degree: 1
    pipeline_parallel_degree: 2
    data_parallel_shard_degree: -1
    world_size: ${settings.cuda_env.world_size}

initialized_model:
  component_key: model
  variant_key: model_initialized
  config:
    model:
      component_key: pipeline
      variant_key: selector
      config:
        pipeline:
          instance_key: scheduled_pipeline
          pass_type: BY_REFERENCE
        selection_type: MODEL
    model_initializer:
      component_key: model_initialization
      variant_key: composed
      config:
        model_type: gpt2
        weight_init_type: scaled
        mean: 0.0
        std: 0.02
        num_layers: ${model_raw.config.n_layer}

scheduled_pipeline:
  component_key: pipeline
  variant_key: scheduled
  config:
    loss_fn:
      instance_key: loss_fn
      pass_type: BY_REFERENCE
    pp_schedule_name: gpipe
    batch_size: ${settings.step_profile.local_train_micro_batch_size}
    microbatch_size: 1
    pp_degree: ${device_mesh.config.pipeline_parallel_degree}
    pipeline:
      component_key: pipeline
      variant_key: builder
      config:
        stage:
          component_key: pipeline
          variant_key: selector
          config:
            pipeline:
              instance_key: staged_pipeline
              pass_type: BY_REFERENCE
            selection_type: STAGE
        model:
          instance_key: fsdp_model
          pass_type: BY_REFERENCE

fsdp_model:
  component_key: model
  variant_key: fsdp2_wrapped
  config:
    model:
      instance_key: model_part
      pass_type: BY_REFERENCE
    device_mesh:
      instance_key: device_mesh
      pass_type: BY_REFERENCE
    mixed_precision_settings:
      param_dtype: BF_16
      reduce_dtype: BF_16
    block_names: [GPT2Block]

model_part:
  component_key: pipeline
  variant_key: selector
  config:
    pipeline:
      instance_key: staged_pipeline
      pass_type: BY_REFERENCE
    selection_type: MODEL

staged_pipeline:
  component_key: pipeline
  variant_key: staged
  config:
    whole_model:
      instance_key: model_raw
      pass_type: BY_REFERENCE
    stages_generator:
      component_key: stages_generator
      variant_key: gpt2_stages_generator
      config:
        num_model_layers: ${model_raw.config.n_layer}
        input_layer_equivalence: 1
        output_layer_equivalence: 1
    device_mesh:
      instance_key: device_mesh
      pass_type: BY_REFERENCE
    local_rank: ${settings.cuda_env.local_rank}
    pp_schedule_name: gpipe
    num_layers_per_stage: 2

model_raw:
  component_key: model
  variant_key: gpt2
  config:
    use_meta_device: true
    use_weight_tying: false
    sample_key: ${settings.referencing_keys.sample_key}
    poe_type: NOPE
    sequence_length: ${settings.step_profile.sequence_length}
    prediction_key: ${loss_fn.config.prediction_key}
    vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: 2
    n_head_q: 8
    n_head_kv: 4
    ffn_hidden: 128
    n_embd: 128
    dropout: 0.0
    bias: true # true: bias in Linears and LayerNorms, like GPT-2; false: a bit better and faster
    attention_config:
      qkv_transforms:
        - type_hint: RotaryTransform
          config:
            n_embd: ${model_raw.config.n_embd}
            n_head: ${model_raw.config.n_head_q} # it has to be n_head_q here
            seq_length_dim: -2
            base_freq: 10000
    attention_implementation: manual
    activation_type: swiglu
    attention_norm_config:
      norm_type: layer_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1e-5
    ffn_norm_config:
      norm_type: layer_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1e-5
    lm_head_norm_config:
      norm_type: layer_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1e-5
Lines changed: 104 additions & 0 deletions
import os
import tempfile
from pathlib import Path

import pytest
import torch
import torch.multiprocessing as mp
import yaml
from pydantic import BaseModel

from modalities.__main__ import Main
from modalities.config.config import ProcessGroupBackendType
from modalities.config.pydantic_if_types import PydanticFSDP2ModuleType, PydanticPipelineType
from tests.end2end_tests.custom_components import MultiProcessingCudaEnv


@pytest.fixture
def temp_file_path() -> Path:
    # Create a NamedTemporaryFile that persists after closing (delete=False)
    with tempfile.NamedTemporaryFile(delete=False) as tf:
        file_path = tf.name
    try:
        yield Path(file_path)
    finally:
        # Clean up the file after the test
        if os.path.exists(file_path):
            os.remove(file_path)


class ComponentsInstantiationModel(BaseModel):
    initialized_model: PydanticFSDP2ModuleType
    scheduled_pipeline: PydanticPipelineType


@pytest.mark.skipif(
    torch.cuda.device_count() < 8,
    reason="This test requires 8 GPUs",
)
class TestPipelineParallelism:
    def _get_tmp_sharding_config_path(
        self, sharding_degree: int, tp_degree: int, pp_degree: int, temp_file_path: Path
    ) -> Path:
        working_dir = Path(os.path.dirname(__file__))
        config_file_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml"

        with open(config_file_path, "r") as file:
            config_string = file.read()
        config_dict = yaml.safe_load(config_string)
        config_dict["device_mesh"]["config"]["data_parallel_shard_degree"] = sharding_degree
        config_dict["device_mesh"]["config"]["tensor_parallel_degree"] = tp_degree
        config_dict["device_mesh"]["config"]["pipeline_parallel_degree"] = pp_degree

        # save the patched config to a temporary file
        with open(temp_file_path, "w") as file:
            yaml.dump(config_dict, file)

        return temp_file_path

    def _get_components(self, config_file_path: Path) -> ComponentsInstantiationModel:
        main_obj = Main(config_file_path)
        components: ComponentsInstantiationModel = main_obj.build_components(
            components_model_type=ComponentsInstantiationModel
        )
        return components

    @pytest.mark.parametrize(
        "sharding_degree, tp_degree, pp_degree, world_size",
        [
            (2, 1, 2, 4),
            # (2, 1, 4, 8),
            # (2, 2, 2, 8),  # TODO: need to support this case
        ],
    )
    def test_pp(self, sharding_degree: int, tp_degree: int, pp_degree: int, world_size: int, temp_file_path: Path):
        tmp_sharding_config_path = self._get_tmp_sharding_config_path(
            sharding_degree=sharding_degree,
            tp_degree=tp_degree,
            pp_degree=pp_degree,
            temp_file_path=temp_file_path,
        )
        mp.spawn(
            self._test_pp_impl,
            args=(world_size, sharding_degree, tmp_sharding_config_path),
            nprocs=world_size,
            join=True,
        )

    def _test_pp_impl(
        self,
        process_id: int,
        world_size: int,
        sharding_degree: int,
        gpt2_model_config_path: Path,
    ):
        # wraps the actual test function so it can run in a distributed multiprocessing setup
        with MultiProcessingCudaEnv(
            process_group_backend=ProcessGroupBackendType.nccl,
            global_rank=process_id,
            local_rank=process_id,
            world_size=world_size,
            rdvz_port=22356,
        ):
            self._get_components(gpt2_model_config_path)
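MultiProcessingCudaEnv is imported from the shared test helpers and is not part of this diff. As a rough stand-in, a context manager of this shape presumably wires up the rendezvous environment and the NCCL process group in each spawned process; the sketch below is an assumption about that behavior, not the actual implementation:

import os
from contextlib import contextmanager

import torch
import torch.distributed as dist


@contextmanager
def minimal_cuda_env(global_rank: int, local_rank: int, world_size: int, rdvz_port: int):
    # Hypothetical stand-in for MultiProcessingCudaEnv: set the env vars
    # torch.distributed expects, pin the device, and manage the process group.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(rdvz_port)
    os.environ["RANK"] = str(global_rank)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", rank=global_rank, world_size=world_size)
    try:
        yield
    finally:
        dist.destroy_process_group()

With the one active parametrization (2, 1, 2, 4), mp.spawn launches four processes that each build the components from the patched config; as the commit message says, this is a draft: the forward/backward pass itself is not yet exercised.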
