Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepspeed/runtime/pipe/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,7 +1291,7 @@ def _allocate_or_extend_buffers(self, idx, shape, dtype):
self._grad_layer_buf[idx] = new_buf
return self._grad_layer_buf[idx]
else:
return self._grad_layer_buf[idx].flatten()[:numel].view(shape)
return self._grad_layer_buf[idx].flatten()[:numel].view(shape).to(dtype)

def forward(self, *args, **kwargs):
"""Disabled for pipeline parallel training. See ``train_batch()``. """
Expand Down
96 changes: 40 additions & 56 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter,
align_dense_tensors, all_gather_dp_groups, mask_nan_or_inf_with_val_inplace,
from deepspeed.runtime.utils import (empty_cache, see_memory_usage, is_model_parallel_parameter, align_dense_tensors,
all_gather_dp_groups, mask_nan_or_inf_with_val_inplace,
count_used_parameters_in_backward)
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
Expand Down Expand Up @@ -1420,8 +1420,11 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):
self.clear_grad_attribute(param) #offload only

def complete_grad_norm_calculation_for_cpu_offload(self, params):
total_norm = 0.0
norm_type = 2.0
"""
Compute local squared L2 norm of gradients for CPU-offloaded parameters.
No cross-rank communication is performed here.
"""
local_sq_norm = torch.zeros(1, device=self.device, dtype=self.gradient_accumulation_dtype)
for p in params:
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
Expand All @@ -1434,7 +1437,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
# so they have no norm_for_param_grads
if param_id in self.norm_for_param_grads:
param_norm = self.norm_for_param_grads[param_id]
total_norm += param_norm.item()**2
local_sq_norm += param_norm**2
else:
# As unused parameters in modules may not be expected sometimes,
# add an explicit error msg when it occurred and an option to
Expand All @@ -1447,19 +1450,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
(2) making sure all trainable parameters and `forward` function
outputs participate in calculating loss.
"""

# Sum across all model parallel GPUs.
total_dev_norm = get_accelerator().FloatTensor([float(total_norm)])
dist.all_reduce(total_dev_norm, op=dist.ReduceOp.SUM, group=self.dp_process_group)

self._model_parallel_all_reduce(tensor=total_dev_norm, op=dist.ReduceOp.SUM)

total_norm = total_dev_norm[0].item()**(1. / norm_type)

if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1.0

return torch.tensor(total_norm, device=self.device, dtype=torch.float)
return local_sq_norm

############################################################################################
def copy_grads_in_partition(self, param):
Expand Down Expand Up @@ -1872,41 +1863,21 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
Returns:
Total norm of the parameters (viewed as a single vector).
"""
norm_type = float(norm_type)
all_norms = []
if norm_type == inf:
for g in gradients:
all_norms.append(g.data.abs().max().float())
total_norm = torch.stack(all_norms).max()
dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=self.dp_process_group)

# Take max across all GPUs.
self._model_parallel_all_reduce(tensor=total_norm, op=dist.ReduceOp.MAX)
else:
# if dist.get_rank() == 0:
# logger.info(f"Total Norm beginning {total_norm}")
for g, p in zip(gradients, params):
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
continue
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
all_norms.append(
torch.linalg.vector_norm(g.data.double().detach(),
ord=norm_type).to(get_accelerator().current_device_name()))
if len(all_norms) > 0:
total_norm = torch.stack(all_norms).square().sum().float()
else:
total_norm = torch.tensor(0.0, dtype=torch.float32).to(self.device)
# Sum across all model parallel Device.
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=self.dp_process_group)
assert norm_type == 2, "only L2 norm supported"

self._model_parallel_all_reduce(tensor=total_norm, op=dist.ReduceOp.SUM)
local_sq_norm = torch.zeros(1, device=self.device, dtype=self.gradient_accumulation_dtype)

total_norm = total_norm.pow(1. / norm_type)
for g, p in zip(gradients, params):
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
continue

mask_nan_or_inf_with_val_inplace(total_norm, device=self.device)
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
if g is None:
continue
local_sq_norm += torch.sum(g.data.double() * g.data.double())

return total_norm
return local_sq_norm

def get_all_grad_tensors(self, tensor_list, dtype):
all_grad_tensors = []
Expand Down Expand Up @@ -2015,19 +1986,32 @@ def override_loss_scale(self, loss_scale):

def scaled_global_norm(self, norm_type=2):
    """Compute the global L2 gradient norm across all bit16 parameter groups.

    Each group contributes its *local squared* L2 norm — computed via
    ``complete_grad_norm_calculation_for_cpu_offload`` when gradients are
    CPU-offloaded, otherwise via ``get_grad_norm_direct`` on the averaged
    gradients. The squared norms are summed, all-reduced across the
    data-parallel group and then the model-parallel group, and the square
    root of the total is returned.

    Args:
        norm_type: only ``2`` (L2 norm) is supported.

    Returns:
        Scalar tensor on ``self.device`` holding the global gradient norm,
        with NaN/Inf masked in place.
    """
    assert norm_type == 2, "only L2 norm supported"

    # Keep the squared norm of every parameter group separately: the MoE
    # expert-norm averaging below addresses groups by index, so it needs
    # the per-group collection rather than a pre-summed total.
    group_sq_norms = []
    for i, _ in enumerate(self.bit16_groups):
        if self.cpu_offload:
            group_sq_norms.append(self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i]))
        else:
            group_sq_norms.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i]))

    if self.has_moe_layers:
        # _average_expert_grad_norms iterates by group index and checks
        # self.is_moe_param_group[i]; passing a single accumulated tensor
        # would mis-scale expert vs. non-expert contributions.
        self._average_expert_grad_norms(group_sq_norms)

    # Sum per-group squared norms. The helpers allocate on self.device, so
    # stay there — hard-coding torch.cuda.current_device() would raise on
    # non-CUDA accelerator backends before the all_reduce even runs.
    local_total_sq_norm = torch.stack(group_sq_norms).sum()

    # Reduce across data-parallel ranks, then across the model-parallel group.
    dist.all_reduce(
        local_total_sq_norm,
        op=dist.ReduceOp.SUM,
        group=self.dp_process_group,
    )
    self._model_parallel_all_reduce(
        tensor=local_total_sq_norm,
        op=dist.ReduceOp.SUM,
    )

    total_norm = torch.sqrt(local_total_sq_norm)
    # Replace NaN/Inf (e.g. from fp16 overflow) in place before returning.
    mask_nan_or_inf_with_val_inplace(total_norm, device=self.device)
    return total_norm

def get_bit16_param_group(self, group_no):
bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
Expand Down