
Commit d1cc216

fix: fix gpt2 and llama multi-node training
1 parent d0d26ec commit d1cc216

File tree: 11 files changed, +53 additions, -93 deletions

example/gpt2/main.cc

Lines changed: 13 additions & 13 deletions

@@ -107,7 +107,7 @@ void Train(const nn::parallel::Rank &rank) {
 
     int ddp_world_size = global::GetDataParallelSize();
     int tp_world_size = global::GetTensorParallelSize();
-    int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 0;
+    int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 1;
    int pp_world_size = global::GetPipelineParallelSize();
 
     if (FLAGS_sequence_parallel) {
@@ -129,21 +129,21 @@ void Train(const nn::parallel::Rank &rank) {
     if (ddp_world_size > 1) {
         ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
                                                               GetDataParallelGroupRanks(rank.GlobalRank()));
-        ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
+        ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
     }
 
     if (tp_world_size > 1) {
         tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
                                                              GetTensorParallelGroupRanks(rank.GlobalRank()));
-        tp_rank = tp_pg->GetGroupRank(rank.thread_rank());
+        tp_rank = tp_pg->GetGroupRank(rank.GlobalRank());
         // NOTE(zbl): Reserved for VocabParallelEmbedding
         nn::parallel::tp_rank = tp_rank;
     }
 
     if (pp_world_size > 1) {
-        pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
-            GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
-        pp_rank = pp_pg->GetGroupRank(rank.thread_rank());
+        pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
+                                                             GetPipelineParallelGroupRanks(rank.GlobalRank()));
+        pp_rank = pp_pg->GetGroupRank(rank.GlobalRank());
 
         nn::parallel::pp_rank = pp_rank;
     }
@@ -184,7 +184,7 @@ void Train(const nn::parallel::Rank &rank) {
     } else if (FLAGS_dtype == kDtypeBF16) {
         dtype = DataType::kBFLOAT16;
     } else {
-        LOG(FATAL) << "Rank " << rank.thread_rank() << ": Datatype " << FLAGS_dtype << " not supported.";
+        LOG(FATAL) << "Rank " << rank.GlobalRank() << ": Datatype " << FLAGS_dtype << " not supported.";
     }
 
     // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
@@ -225,7 +225,7 @@ void Train(const nn::parallel::Rank &rank) {
                   std::make_shared<VocabParallelCrossEntropyLoss>(model_config.original_vocab_size))
             : std::static_pointer_cast<nn::Module>(std::make_shared<nn::CrossEntropyLoss>());
     loss_fn->To(device);
-    LOG(INFO) << "Rank " << rank.thread_rank() << ": start training";
+    LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";
 
     if (pp_world_size > 1) {
         auto shapes = std::vector<std::vector<int64_t>>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}};
@@ -285,23 +285,23 @@ void Train(const nn::parallel::Rank &rank) {
             x = std::make_shared<Tensor>(x->To(device));
             y = std::make_shared<Tensor>(y->To(device));
 
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward";
             // (bs, seq_len, vocab_size)
             auto logits = model->Forward({x, y})[0];
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward";
             auto loss = loss_fn->Forward({logits, y})[0];
             loss = loss / grad_accum_steps;
 
             // disable autocast for the current step (backward is not under autocast)
             autocast_guard.Disable();
 
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward";
 
             auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
             lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": start backward";
             loss->Backward();
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish backward";
         }
 
        optimizer.Step();
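
Note on the rank change: once training spans multiple nodes, the per-node thread rank restarts at 0 on every node, so keying a process-group rank lookup on it would collide across nodes; only the global rank is unique job-wide. A minimal standalone sketch of that distinction, assuming a hypothetical two-node layout with four ranks per node (the helper below is illustrative, not InfiniTrain's API):

#include <cassert>
#include <unordered_map>

// Hypothetical illustration: local/thread ranks repeat across nodes, so only the
// global rank can safely index a cross-node group-rank map.
int GlobalRank(int node_rank, int local_rank, int ranks_per_node) {
    return node_rank * ranks_per_node + local_rank;
}

int main() {
    constexpr int kRanksPerNode = 4;
    std::unordered_map<int, int> group_rank_by_global;  // global_rank -> group_rank

    // A group spanning both nodes: one GPU from each node, local rank 1 on both.
    group_rank_by_global[GlobalRank(0, 1, kRanksPerNode)] = 0;  // global rank 1
    group_rank_by_global[GlobalRank(1, 1, kRanksPerNode)] = 1;  // global rank 5

    // Keyed by local rank, both entries would have used key 1 and overwritten each other.
    assert(group_rank_by_global.size() == 2);
    assert(group_rank_by_global.at(5) == 1);
    return 0;
}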

example/gpt2/net.cc

Lines changed: 6 additions & 5 deletions

@@ -5,7 +5,6 @@
 #include <filesystem>
 #include <fstream>
 #include <random>
-#include <stdexcept>
 #include <string>
 #include <tuple>
 
@@ -239,7 +238,8 @@ GPT2::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
     auto x1 = x[0];
     const auto device = x1->GetDevice();
 
-    const auto t = x1->Dims()[1]; // T
+    const auto t
+        = x1->Dims()[1] * (is_first_stage ? 1 : nn::parallel::global::GetSequenceParallelSize()); // full_seq_len
     CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only "
                                     << config_.block_size;
     // forward the GPT2 model itself
@@ -252,8 +252,8 @@ GPT2::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
     int tp_rank = 0;
     if (tp_world_size > 1) {
         auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get(
-            nn::parallel::GetTensorParallelProcessGroupName(device->rank().thread_rank()));
-        tp_rank = tp_group->GetGroupRank(device->rank().thread_rank());
+            nn::parallel::GetTensorParallelProcessGroupName(device->rank().GlobalRank()));
+        tp_rank = tp_group->GetGroupRank(device->rank().GlobalRank());
     }
     int64_t t_local = sequence_parallel_enabled ? (t / tp_world_size) : t;
     int64_t start = sequence_parallel_enabled ? tp_rank * t_local : 0;
@@ -386,7 +386,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     } else if (pp_size > 1 && is_last_stage) {
         auto &lm_head_weight = state_dict[std::format("{}.{}", GPT2::kLMHeadLayerName,
                                                       nn::parallel::ColumnParallelLinear::kParamWeightName)];
-        ifs.read(reinterpret_cast<char *>(lm_head_weight->DataPtr()), lm_head_weight->SizeInBytes());
+        ReadMatrixRowShardFloat(ifs, static_cast<float *>(lm_head_weight->DataPtr()), model_vocab_size, n_embd, v_start,
+                                vpp);
     } else {
         size_t wte_bytes = vocab_size * n_embd * sizeof(float);
         ifs.seekg(wte_bytes, std::ios::cur);
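
The lm_head change above stops reading the full (model_vocab_size x n_embd) matrix and instead loads only the vocabulary rows owned by the last stage. The sketch below shows one way such a row-shard reader could behave; the name and argument order follow the call site in the diff, but the body is an assumption, not the repository's implementation:

#include <cstddef>
#include <fstream>
#include <ios>

// Hypothetical sketch: copy rows [row_start, row_start + shard_rows) of a row-major
// (total_rows x cols) float matrix, starting at the current stream position, into dst,
// and leave the stream positioned just past the full matrix.
void ReadMatrixRowShardFloat(std::ifstream &ifs, float *dst, std::size_t total_rows, std::size_t cols,
                             std::size_t row_start, std::size_t shard_rows) {
    const auto row_bytes = static_cast<std::streamoff>(cols * sizeof(float));
    const std::streamoff matrix_begin = ifs.tellg();

    // Skip rows owned by earlier shards, read only this shard's rows, then skip the rest.
    ifs.seekg(matrix_begin + static_cast<std::streamoff>(row_start) * row_bytes);
    ifs.read(reinterpret_cast<char *>(dst), static_cast<std::streamsize>(shard_rows) * row_bytes);
    ifs.seekg(matrix_begin + static_cast<std::streamoff>(total_rows) * row_bytes);
}

At the call site, v_start and vpp would then presumably be the first vocabulary row of this shard and the number of rows it owns.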

example/llama3/main.cc

Lines changed: 16 additions & 13 deletions

@@ -89,7 +89,7 @@ void Train(const nn::parallel::Rank &rank) {
 
     int ddp_world_size = global::GetDataParallelSize();
     int tp_world_size = global::GetTensorParallelSize();
-    int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 0;
+    int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 1;
     int pp_world_size = global::GetPipelineParallelSize();
 
     if (FLAGS_sequence_parallel) {
@@ -111,21 +111,21 @@ void Train(const nn::parallel::Rank &rank) {
     if (ddp_world_size > 1) {
         ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
                                                               GetDataParallelGroupRanks(rank.GlobalRank()));
-        ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
+        ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
     }
 
     if (tp_world_size > 1) {
         tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
                                                              GetTensorParallelGroupRanks(rank.GlobalRank()));
-        tp_rank = tp_pg->GetGroupRank(rank.thread_rank());
+        tp_rank = tp_pg->GetGroupRank(rank.GlobalRank());
         // NOTE(zbl): Reserved for VocabParallelEmbedding
         nn::parallel::tp_rank = tp_rank;
     }
 
     if (pp_world_size > 1) {
         pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
                                                              GetPipelineParallelGroupRanks(rank.GlobalRank()));
-        pp_rank = pp_pg->GetGroupRank(rank.thread_rank());
+        pp_rank = pp_pg->GetGroupRank(rank.GlobalRank());
 
         nn::parallel::pp_rank = pp_rank;
     }
@@ -156,15 +156,15 @@ void Train(const nn::parallel::Rank &rank) {
 
     model->To(device);
 
-    LOG(INFO) << "Rank " << rank.thread_rank() << ": Model loaded to device.";
+    LOG(INFO) << "Rank " << rank.GlobalRank() << ": Model loaded to device.";
 
     DataType dtype;
     if (FLAGS_dtype == kDtypeFP32) {
         dtype = DataType::kFLOAT32;
     } else if (FLAGS_dtype == kDtypeBF16) {
         dtype = DataType::kBFLOAT16;
     } else {
-        LOG(FATAL) << "Rank " << rank.thread_rank() << ": Datatype " << FLAGS_dtype << " not supported.";
+        LOG(FATAL) << "Rank " << rank.GlobalRank() << ": Datatype " << FLAGS_dtype << " not supported.";
     }
 
     // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
@@ -204,10 +204,13 @@ void Train(const nn::parallel::Rank &rank) {
         = (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>())
                               : std::static_pointer_cast<nn::Module>(std::make_shared<nn::CrossEntropyLoss>());
     loss_fn->To(device);
-    LOG(INFO) << "Rank " << rank.thread_rank() << ": start training";
+    LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";
 
     if (pp_world_size > 1) {
-        auto shapes = std::vector<std::vector<int64_t>>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}};
+        // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
+        // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
+        auto shapes = std::vector<std::vector<int64_t>>{
+            {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};
 
         model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
                                                                  pp_rank, std::make_shared<optimizers::Adam>(optimizer),
@@ -262,23 +265,23 @@ void Train(const nn::parallel::Rank &rank) {
             x = std::make_shared<Tensor>(x->To(device));
             y = std::make_shared<Tensor>(y->To(device));
 
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward";
             // (bs, seq_len, vocab_size)
             auto logits = model->Forward({x, y})[0];
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward";
            auto loss = loss_fn->Forward({logits, y})[0];
             loss = loss / grad_accum_steps;
 
             // disable autocast for the current step (backward is not under autocast)
             autocast_guard.Disable();
 
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward";
 
             auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
             lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": start backward";
             loss->Backward();
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish backward";
         }
 
        optimizer.Step();
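
Two of the changes above work together: sp_world_size now falls back to 1 rather than 0 when sequence parallelism is disabled, and the pipeline boundary shape divides the sequence length by it. With the old fallback of 0 that division would be undefined; with 1 it simply reproduces the full-sequence shape. A small self-contained sketch under those assumptions (illustrative, not the training binary):

#include <cstdint>
#include <iostream>
#include <vector>

// Shape of the activation handed across a pipeline-stage boundary:
// (batch, local_seq_len, n_embd), where the sequence dim is split across
// sp_world_size ranks when sequence parallelism is enabled.
std::vector<int64_t> BoundaryShape(int64_t batch, int64_t seq_len, int64_t n_embd,
                                   bool sequence_parallel, int64_t tp_world_size) {
    const int64_t sp_world_size = sequence_parallel ? tp_world_size : 1;  // 1, never 0
    return {batch, seq_len / sp_world_size, n_embd};
}

int main() {
    for (auto shape : {BoundaryShape(8, 1024, 768, /*sequence_parallel=*/false, 4),
                       BoundaryShape(8, 1024, 768, /*sequence_parallel=*/true, 4)}) {
        std::cout << shape[0] << " x " << shape[1] << " x " << shape[2] << "\n";
    }
    // Prints 8 x 1024 x 768 (SP off) and 8 x 256 x 768 (SP on, tp_world_size = 4).
    return 0;
}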

example/llama3/net.cc

Lines changed: 2 additions & 2 deletions

@@ -6,7 +6,6 @@
 #include <fstream>
 #include <memory>
 #include <random>
-#include <set>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -364,7 +363,8 @@ std::vector<std::shared_ptr<Tensor>> LLaMA3::Forward(const std::vector<std::shar
     // (bs, seq_len)
     auto x1 = x[0];
     const auto device = x1->GetDevice();
-    const auto t = x1->Dims()[1]; // seq_len
+    const auto t
+        = x1->Dims()[1] * (is_first_stage ? 1 : nn::parallel::global::GetSequenceParallelSize()); // full_seq_len
     CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only "
                                     << config_.block_size;
 
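
This mirrors the gpt2/net.cc change: with sequence parallelism, every stage after the first receives activations whose sequence dimension is already split across the tensor-parallel group, so the full sequence length must be recovered before checking it against block_size. A minimal illustration of that bookkeeping (hypothetical helper, not the model code):

#include <cassert>
#include <cstdint>

// The first pipeline stage sees token ids of shape (bs, full_seq_len); later stages
// see hidden states of shape (bs, full_seq_len / sp_size, n_embd). Recover the full
// sequence length so checks like CHECK_LE(t, block_size) stay meaningful everywhere.
int64_t FullSeqLen(int64_t dim1, bool is_first_stage, int64_t sp_size) {
    return dim1 * (is_first_stage ? 1 : sp_size);
}

int main() {
    const int64_t full = 1024, sp_size = 4;
    assert(FullSeqLen(full, /*is_first_stage=*/true, sp_size) == 1024);             // stage 0: already full
    assert(FullSeqLen(full / sp_size, /*is_first_stage=*/false, sp_size) == 1024);  // later stages: 256 * 4
    return 0;
}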

infini_train/include/nn/parallel/global.h

Lines changed: 3 additions & 0 deletions

@@ -43,6 +43,8 @@ class GlobalEnv {
 
     int tensor_parallel_size() const;
 
+    int sequence_parallel_size() const;
+
     bool sequence_parallel_enabled() const;
 
     int data_parallel_size() const;
@@ -94,6 +96,7 @@ inline int GetGlobalProcRank() { return GlobalEnv::Instance().global_proc_rank()
 inline int GetLocalProcRank() { return GlobalEnv::Instance().local_proc_rank(); }
 
 inline int GetTensorParallelSize() { return GlobalEnv::Instance().tensor_parallel_size(); }
+inline int GetSequenceParallelSize() { return GlobalEnv::Instance().sequence_parallel_size(); }
 inline bool GetSequenceParallelEnabled() { return GlobalEnv::Instance().sequence_parallel_enabled(); }
 inline int GetDataParallelSize() { return GlobalEnv::Instance().data_parallel_size(); }
 inline int GetPipelineParallelSize() { return GlobalEnv::Instance().pipeline_parallel_size(); }

infini_train/include/nn/parallel/process_group.h

Lines changed: 2 additions & 4 deletions

@@ -35,7 +35,7 @@ class ProcessGroup {
 
     ~ProcessGroup();
 
-    int GetGroupRank(int thread_rank) const;
+    int GetGroupRank(int global_rank) const;
 
     // Communication operations
     void AllReduce(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const;
@@ -63,8 +63,6 @@ class ProcessGroup {
     // Async communication functions
     std::shared_ptr<Work> AllReduceAsync(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const;
 
-    void Barrier() const;
-
 private:
     void InitSingleProcess(const std::vector<int> &ranks);
 
@@ -79,7 +77,7 @@ class ProcessGroup {
 
     std::unordered_map<const Device *, ncclComm_t> device_comm_map_;
     std::unordered_map<const Device *, cudaStream_t> device_stream_map_;
-    std::unordered_map<int, int> thread_group_rank_map_; // thread_rank : group_rank
+    std::unordered_map<int, int> global_group_rank_map_; // global_rank : group_rank
 
     int world_size_ = 0;

infini_train/src/nn/parallel/global.cc

Lines changed: 5 additions & 34 deletions

@@ -2,7 +2,6 @@
 
 #include <cstdlib>
 #include <format>
-#include <nccl.h>
 #include <string>
 
 #include "glog/logging.h"
@@ -152,6 +151,11 @@ int GlobalEnv::tensor_parallel_size() const {
     return tensor_parallel_size_;
 }
 
+int GlobalEnv::sequence_parallel_size() const {
+    CHECK(initialized_) << "GlobalEnv is not initialized!";
+    return sequence_parallel_enabled_ ? tensor_parallel_size_ : 1;
+}
+
 bool GlobalEnv::sequence_parallel_enabled() const {
     CHECK(initialized_) << "GlobalEnv is not initialized!";
     return sequence_parallel_enabled_;
@@ -186,39 +190,6 @@ inline int NumGroups(const Layout &L, Axis target) {
 }
 } // namespace
 
-inline void AppendAxisGroups(std::ostringstream &oss, const Layout &L, Axis target) {
-    const int num_groups = NumGroups(L, target);
-    const auto name = AxisName(target);
-    oss << std::format("[{}] size={}, num_groups={}\n", name, L.sizes[target], num_groups);
-
-    for (int dp = 0; dp < (target == DP ? 1 : L.sizes[DP]); ++dp) {
-        for (int tp = 0; tp < (target == TP ? 1 : L.sizes[TP]); ++tp) {
-            for (int pp = 0; pp < (target == PP ? 1 : L.sizes[PP]); ++pp) {
-                const int gid = L.GroupId(target, dp, tp, pp);
-                auto ranks = L.GroupRanks(target, dp, tp, pp);
-                std::sort(ranks.begin(), ranks.end());
-
-                auto dp_size_str = (target == DP) ? "-" : std::to_string(dp);
-                auto tp_size_str = (target == TP) ? "-" : std::to_string(tp);
-                auto pp_size_str = (target == PP) ? "-" : std::to_string(pp);
-
-                std::string ranks_str;
-                ranks_str.reserve(ranks.size() * 4);
-
-                for (size_t i = 0; i < ranks.size(); ++i) {
-                    if (i > 0) {
-                        ranks_str += ", ";
-                    }
-                    ranks_str += std::to_string(ranks[i]);
-                }
-
-                oss << std::format(" - {} {} (dp={}, tp={}, pp={}): [{}]\n", name, gid, dp_size_str, tp_size_str,
-                                   pp_size_str, ranks_str);
-            }
-        }
-    }
-}
-
 std::string ProcessGroupOverview(const Layout &L, bool skip_trivial_axes) {
     std::ostringstream oss;
     oss << std::format("\n=== Parallel Communication Groups ===\n"
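
The new accessor gives callers a divisor that is always valid: the tensor-parallel size when sequence parallelism is on, and 1 otherwise. A standalone sketch of that contract (the real accessor lives on GlobalEnv and additionally checks initialization):

#include <cassert>

// Same contract as GlobalEnv::sequence_parallel_size(): a divisor equal to the TP size
// under sequence parallelism and 1 otherwise, so expressions like seq_len / sp_size
// never divide by zero (a 0 fallback, as in the old main.cc code, would).
int SequenceParallelSize(bool sequence_parallel_enabled, int tensor_parallel_size) {
    return sequence_parallel_enabled ? tensor_parallel_size : 1;
}

int main() {
    assert(SequenceParallelSize(true, 8) == 8);
    assert(SequenceParallelSize(false, 8) == 1);
    return 0;
}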

infini_train/src/nn/parallel/pp/pipeline_schedule.cc

Lines changed: 1 addition & 0 deletions

@@ -2,6 +2,7 @@
 #include "infini_train/include/nn/parallel/pp/pipeline_schedule.h"
 
 #include <cstddef>
+#include <cstdint>
 #include <memory>
 #include <vector>
 

infini_train/src/nn/parallel/process_group.cc

Lines changed: 3 additions & 18 deletions

@@ -127,7 +127,7 @@ void ProcessGroup::InitSingleProcess(const std::vector<int> &ranks) {
         auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, ranks[i]);
         devices_.push_back(device);
         device_comm_map_[device] = comms_[i];
-        thread_group_rank_map_[device->rank().thread_rank()] = i;
+        global_group_rank_map_[device->rank().GlobalRank()] = i;
     }
 }
 
@@ -162,7 +162,7 @@ void ProcessGroup::InitMultiProcess(const std::vector<int> &ranks) {
         comms_.push_back(comm);
 
         auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, i);
-        thread_group_rank_map_[device->rank().thread_rank()] = group_rank;
+        global_group_rank_map_[device->rank().GlobalRank()] = group_rank;
         devices_.push_back(device);
         device_comm_map_[device] = comm;
     }
@@ -183,7 +183,7 @@ void ProcessGroup::InitStreams() {
     }
 }
 
-int ProcessGroup::GetGroupRank(int thread_rank) const { return thread_group_rank_map_.at(thread_rank); }
+int ProcessGroup::GetGroupRank(int global_rank) const { return global_group_rank_map_.at(global_rank); }
 
 void ProcessGroup::AllReduce(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const {
     void *buffer = tensor->DataPtr();
@@ -475,21 +475,6 @@ std::shared_ptr<Work> ProcessGroup::AllReduceAsync(const std::shared_ptr<Tensor>
     return std::move(work);
 }
 
-void ProcessGroup::Barrier() const {
-    // NOTE(dcj): use ncclAllreduce to barrier all processes before destroying the communicators
-    // FIXME(dcj): should only call by one rank
-    int dummy = 1;
-    std::vector<int> results(1, 0);
-
-    NCCL_CHECK(ncclGroupStart());
-    for (const auto &device : devices_) {
-        device->SetDevice();
-        auto comm = device_comm_map_.at(device);
-        auto cuda_dev = dynamic_cast<const CudaDevice *>(device);
-        NCCL_CHECK(ncclAllReduce(&dummy, &dummy, 1, ncclInt, ncclSum, comm, cuda_dev->Stream()));
-    }
-    NCCL_CHECK(ncclGroupEnd());
-}
 #endif
 
 ProcessGroupFactory *ProcessGroupFactory::Instance() {
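
The renamed map makes the keying explicit: a ProcessGroup now records each member's group rank against its globally unique rank, which is what GetGroupRank(global_rank) looks up. A reduced sketch of that bookkeeping, with hypothetical names standing in for the real initialization path:

#include <cassert>
#include <unordered_map>
#include <vector>

// Hypothetical reduced version of the group-rank bookkeeping: given the global ranks
// that form a process group, remember each rank's position in the group so collectives
// can be addressed by global rank from anywhere in the job.
class GroupRankMap {
public:
    explicit GroupRankMap(const std::vector<int> &global_ranks) {
        for (int group_rank = 0; group_rank < static_cast<int>(global_ranks.size()); ++group_rank) {
            global_group_rank_map_[global_ranks[group_rank]] = group_rank;
        }
    }
    int GetGroupRank(int global_rank) const { return global_group_rank_map_.at(global_rank); }

private:
    std::unordered_map<int, int> global_group_rank_map_;  // global_rank -> group_rank
};

int main() {
    // A pipeline-parallel group built from one GPU on each of two 4-GPU nodes.
    GroupRankMap pp_group({3, 7});
    assert(pp_group.GetGroupRank(3) == 0);
    assert(pp_group.GetGroupRank(7) == 1);
    return 0;
}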
