@@ -89,7 +89,7 @@ void Train(const nn::parallel::Rank &rank) {
 
     int ddp_world_size = global::GetDataParallelSize();
     int tp_world_size = global::GetTensorParallelSize();
-    int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 0;
+    int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 1;
     int pp_world_size = global::GetPipelineParallelSize();
 
     if (FLAGS_sequence_parallel) {
@@ -111,21 +111,21 @@ void Train(const nn::parallel::Rank &rank) {
     if (ddp_world_size > 1) {
         ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
                                                               GetDataParallelGroupRanks(rank.GlobalRank()));
-        ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
+        ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
     }
 
     if (tp_world_size > 1) {
         tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
                                                              GetTensorParallelGroupRanks(rank.GlobalRank()));
-        tp_rank = tp_pg->GetGroupRank(rank.thread_rank());
+        tp_rank = tp_pg->GetGroupRank(rank.GlobalRank());
         // NOTE(zbl): Reserved for VocabParallelEmbedding
         nn::parallel::tp_rank = tp_rank;
     }
 
     if (pp_world_size > 1) {
         pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
                                                              GetPipelineParallelGroupRanks(rank.GlobalRank()));
-        pp_rank = pp_pg->GetGroupRank(rank.thread_rank());
+        pp_rank = pp_pg->GetGroupRank(rank.GlobalRank());
 
         nn::parallel::pp_rank = pp_rank;
     }
@@ -156,15 +156,15 @@ void Train(const nn::parallel::Rank &rank) {
 
     model->To(device);
 
-    LOG(INFO) << "Rank " << rank.thread_rank() << ": Model loaded to device.";
+    LOG(INFO) << "Rank " << rank.GlobalRank() << ": Model loaded to device.";
 
     DataType dtype;
     if (FLAGS_dtype == kDtypeFP32) {
         dtype = DataType::kFLOAT32;
     } else if (FLAGS_dtype == kDtypeBF16) {
         dtype = DataType::kBFLOAT16;
     } else {
-        LOG(FATAL) << "Rank " << rank.thread_rank() << ": Datatype " << FLAGS_dtype << " not supported.";
+        LOG(FATAL) << "Rank " << rank.GlobalRank() << ": Datatype " << FLAGS_dtype << " not supported.";
     }
 
     // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
@@ -204,10 +204,13 @@ void Train(const nn::parallel::Rank &rank) {
         = (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>())
                               : std::static_pointer_cast<nn::Module>(std::make_shared<nn::CrossEntropyLoss>());
     loss_fn->To(device);
-    LOG(INFO) << "Rank " << rank.thread_rank() << ": start training";
+    LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";
 
     if (pp_world_size > 1) {
-        auto shapes = std::vector<std::vector<int64_t>>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}};
+        // NOTE(dcj): To keep the tensor shapes at the pipeline-stage boundaries correct when
+        // sequence parallelism (SP) is enabled, divide the sequence length by sp_world_size.
+        auto shapes = std::vector<std::vector<int64_t>>{
+            {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};
 
         model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
                                                                  pp_rank, std::make_shared<optimizers::Adam>(optimizer),
@@ -262,23 +265,23 @@ void Train(const nn::parallel::Rank &rank) {
             x = std::make_shared<Tensor>(x->To(device));
             y = std::make_shared<Tensor>(y->To(device));
 
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": start forward";
             // (bs, seq_len, vocab_size)
             auto logits = model->Forward({x, y})[0];
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward";
             auto loss = loss_fn->Forward({logits, y})[0];
             loss = loss / grad_accum_steps;
 
             // disable autocast for the current step (backward is not under autocast)
             autocast_guard.Disable();
 
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward";
 
             auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
             lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": start backward";
             loss->Backward();
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
+            LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish backward";
         }
 
         optimizer.Step();
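
A minimal sketch of the boundary-shape arithmetic the sp_world_size and pipeline-shape hunks rely on, assuming the same quantities as the diff (BoundaryShape is a hypothetical helper added only for illustration, not part of the change): with sequence parallelism enabled, each rank holds only seq_len / sp_world_size tokens, so the activation handed between pipeline stages shrinks by the same factor; with SP disabled, sp_world_size now defaults to 1 instead of 0, so the division leaves the shape unchanged rather than dividing by zero.

    // Hypothetical helper, not part of the diff: computes the (bs, seq, hidden) shape that one
    // pipeline stage sends to the next when the sequence dimension is sharded across
    // sp_world_size sequence-parallel ranks (sp_world_size == 1 means SP is disabled).
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> BoundaryShape(int64_t batch_size, int64_t seq_len, int64_t n_embd,
                                       int64_t sp_world_size) {
        return {batch_size, seq_len / sp_world_size, n_embd};
    }

    // Illustrative numbers: batch 8, sequence 1024, hidden 768.
    //   SP off (sp_world_size = 1)          -> {8, 1024, 768}
    //   SP on with tp_world_size = 4 ranks  -> {8,  256, 768}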