diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 463ab711d3cb..196d249a49c1 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -1291,7 +1291,7 @@ def _allocate_or_extend_buffers(self, idx, shape, dtype): self._grad_layer_buf[idx] = new_buf return self._grad_layer_buf[idx] else: - return self._grad_layer_buf[idx].flatten()[:numel].view(shape) + return self._grad_layer_buf[idx].flatten()[:numel].view(shape).to(dtype) def forward(self, *args, **kwargs): """Disabled for pipeline parallel training. See ``train_batch()``. """ diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index ff3d5cc953c2..65421e246498 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -19,8 +19,8 @@ from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes -from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter, - align_dense_tensors, all_gather_dp_groups, mask_nan_or_inf_with_val_inplace, +from deepspeed.runtime.utils import (empty_cache, see_memory_usage, is_model_parallel_parameter, align_dense_tensors, + all_gather_dp_groups, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward) from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum @@ -1420,8 +1420,11 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): self.clear_grad_attribute(param) #offload only def complete_grad_norm_calculation_for_cpu_offload(self, params): - total_norm = 0.0 - norm_type = 2.0 + """ + Compute local squared L2 norm of gradients for CPU-offloaded parameters. + No cross-rank communication is performed here. + """ + local_sq_norm = torch.zeros(1, device=self.device, dtype=self.gradient_accumulation_dtype) for p in params: # Pipeline parallelism may replicate parameters. Avoid multi-counting. if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: @@ -1434,7 +1437,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): # so they have no norm_for_param_grads if param_id in self.norm_for_param_grads: param_norm = self.norm_for_param_grads[param_id] - total_norm += param_norm.item()**2 + local_sq_norm += param_norm**2 else: # As unused parameters in modules may not be expected sometimes, # add an explicit error msg when it occurred and an option to @@ -1447,19 +1450,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): (2) making sure all trainable parameters and `forward` function outputs participate in calculating loss. """ - - # Sum across all model parallel GPUs. - total_dev_norm = get_accelerator().FloatTensor([float(total_norm)]) - dist.all_reduce(total_dev_norm, op=dist.ReduceOp.SUM, group=self.dp_process_group) - - self._model_parallel_all_reduce(tensor=total_dev_norm, op=dist.ReduceOp.SUM) - - total_norm = total_dev_norm[0].item()**(1. / norm_type) - - if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1.0 - - return torch.tensor(total_norm, device=self.device, dtype=torch.float) + return local_sq_norm ############################################################################################ def copy_grads_in_partition(self, param): @@ -1872,41 +1863,21 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): Returns: Total norm of the parameters (viewed as a single vector). """ - norm_type = float(norm_type) - all_norms = [] - if norm_type == inf: - for g in gradients: - all_norms.append(g.data.abs().max().float()) - total_norm = torch.stack(all_norms).max() - dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=self.dp_process_group) - - # Take max across all GPUs. - self._model_parallel_all_reduce(tensor=total_norm, op=dist.ReduceOp.MAX) - else: - # if dist.get_rank() == 0: - # logger.info(f"Total Norm beginning {total_norm}") - for g, p in zip(gradients, params): - # Pipeline parallelism may replicate parameters. Avoid multi-counting. - if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: - continue - if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): - all_norms.append( - torch.linalg.vector_norm(g.data.double().detach(), - ord=norm_type).to(get_accelerator().current_device_name())) - if len(all_norms) > 0: - total_norm = torch.stack(all_norms).square().sum().float() - else: - total_norm = torch.tensor(0.0, dtype=torch.float32).to(self.device) - # Sum across all model parallel Device. - dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=self.dp_process_group) + assert norm_type == 2, "only L2 norm supported" - self._model_parallel_all_reduce(tensor=total_norm, op=dist.ReduceOp.SUM) + local_sq_norm = torch.zeros(1, device=self.device, dtype=self.gradient_accumulation_dtype) - total_norm = total_norm.pow(1. / norm_type) + for g, p in zip(gradients, params): + # Pipeline parallelism may replicate parameters. Avoid multi-counting. + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: + continue - mask_nan_or_inf_with_val_inplace(total_norm, device=self.device) + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + if g is None: + continue + local_sq_norm += torch.sum(g.data.double() * g.data.double()) - return total_norm + return local_sq_norm def get_all_grad_tensors(self, tensor_list, dtype): all_grad_tensors = [] @@ -2015,19 +1986,32 @@ def override_loss_scale(self, loss_scale): def scaled_global_norm(self, norm_type=2): assert norm_type == 2, "only L2 norm supported" - norm_groups = [] - for i, group in enumerate(self.bit16_groups): + local_total_sq_norm = torch.zeros(1, device=self.device, dtype=self.gradient_accumulation_dtype) + for i, _ in enumerate(self.bit16_groups): if self.cpu_offload: - norm = self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i]) - norm_groups.append(norm) + group_sq_norm = self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i]) else: - norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i])) + group_sq_norm = self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i]) + local_total_sq_norm += group_sq_norm if self.has_moe_layers: - self._average_expert_grad_norms(norm_groups) + self._average_expert_grad_norms(local_total_sq_norm) + + local_total_sq_norm = local_total_sq_norm.to(torch.cuda.current_device()) + dist.all_reduce( + local_total_sq_norm, + op=dist.ReduceOp.SUM, + group=self.dp_process_group, + ) + self._model_parallel_all_reduce( + tensor=local_total_sq_norm, + op=dist.ReduceOp.SUM, + ) + total_norm = torch.sqrt(local_total_sq_norm) # calculating L2 norm - return torch.linalg.vector_norm(torch.stack(norm_groups), ord=norm_type) + mask_nan_or_inf_with_val_inplace(total_norm, device=self.device) + return total_norm def get_bit16_param_group(self, group_no): bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]