Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,32 @@ def independent_gradient_partition_epilogue(self):
self._clear_previous_reduced_grads()

if self.cpu_offload is False:
# Pre-compute gradient norm for Muon clipping if needed
grad_norm_for_muon = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you introduced a new instance variable, grad_norm_for_muon, that is set in the independent_gradient_partition_epilogue function and consumed in the get_flat_partition function. This could cause a data race because of the resulting implicit dependency on call order. One solution is to define grad_norm_for_muon explicitly outside the function, so we can ensure it is always valid before it is passed into these functions.

if self.is_gradient_accumulation_boundary:
# Check if any parameter group uses Muon
uses_muon = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main connection to #7808 is that use_muon should be an object field set in the constructor and defined per-param_group.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fy817 As @sfc-gh-truwase mentioned, use_muon should be an object field. This should be computed once and stored as a class field. Computing it repeatedly is inefficient and can cause errors.

for i, _ in enumerate(self.bit16_groups):
if len(self.params_in_partition[i]) > 0 and getattr(self.params_in_partition[i][0], 'use_muon', False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fy817 getattr(tensor_list[0], 'use_muon', False) needs to be set externally by the user. Otherwise there's no documentation and validation that this attribute would exist or be set properly.

uses_muon = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fy817 One suggestion is to add validation in __init__ of the DeepSpeedZeroOptimizer class that checks for use_muon attributes when a Muon optimizer is detected, and to add documentation about this requirement.

break

# Compute unscaled gradient norm if Muon is used and clipping is enabled
if uses_muon and self.clip_grad > 0.:
# Compute gradient norm before Muon update
norm_groups = []
for i, group in enumerate(self.bit16_groups):
if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
all_grad_tensors = self.get_all_grad_tensors(self.params_in_partition[i],
dtype=self.gradient_accumulation_dtype)
else:
all_grad_tensors = self.all_grad_tensors[i]
if all_grad_tensors is not None:
norm_groups.append(self.get_grad_norm_direct(all_grad_tensors, self.params_in_partition[i]))

if len(norm_groups) > 0:
grad_norm_for_muon = torch.linalg.vector_norm(torch.stack(norm_groups), ord=2)

for i, _ in enumerate(self.bit16_groups):
if i not in self.all_grad_tensors or self.all_grad_tensors[i] is None:
self.all_grad_tensors[i] = self.get_all_grad_tensors(self.params_in_partition[i],
Expand All @@ -876,6 +902,8 @@ def independent_gradient_partition_epilogue(self):
dtype=self.gradient_accumulation_dtype,
device=get_accelerator().current_device_name(),
param_group_idx=i,
return_tensor_list=True,
grad_norm=grad_norm_for_muon)
return_tensor_list=True)
# Clear all_grad_tensors after use. With reentrant checkpointing,
# the epilogue may run multiple times per backward pass. Each time,
Expand Down Expand Up @@ -1925,7 +1953,8 @@ def get_flat_partition(self,
dtype,
device,
param_group_idx,
return_tensor_list=False):
return_tensor_list=False,
grad_norm=None):
if len(tensor_list) == 0:
# This condition can fire when we have small parameters and many ranks.
zero_buffer = torch.zeros(int(partition_size), dtype=dtype, device=device)
Expand All @@ -1947,11 +1976,22 @@ def get_flat_partition(self,
flatten_bf_list = [torch.zeros([total_size], dtype=dtype, device=device)]
self.optimizer.state[flatten_copy]["momentum_buffer"] = self.flatten(flatten_bf_list)

# Calculate clip factor if gradient clipping is enabled and grad_norm is provided
clip_factor = 1.0
if self.clip_grad > 0. and grad_norm is not None:
# grad_norm is already unscaled (divided by loss_scale)
clip_factor = max(1.0, grad_norm / self.clip_grad)

buffer_idx = 0
for i, tensor in enumerate(tensor_list):
grad_accum = self.all_grad_tensors[param_group_idx][i]
if getattr(tensor, 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}"

# Apply gradient clipping before muon_update
if clip_factor > 1.0:
grad_accum = grad_accum / clip_factor

buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
tensor.numel()).view(tensor.size())
grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum'])
Expand Down Expand Up @@ -2089,15 +2129,20 @@ def step(self, closure=None):
see_memory_usage('Before norm calculation')
scaled_global_grad_norm = self.scaled_global_norm()
self._global_grad_norm = scaled_global_grad_norm / prev_scale
unscaled_grad_norm = self._global_grad_norm # Store unscaled norm for use in get_flat_partition
see_memory_usage('After norm before optimizer')

# Step 2:- run optimizer and upscaling simultaneously
for i, group in enumerate(self.bit16_groups):
self.timers(OPTIMIZER_GRADIENTS_TIMER).start()
partition_id = dist.get_rank(group=self.real_dp_process_group[i])

# Check if this param group uses Muon (clipping already done in get_flat_partition)
uses_muon = len(self.params_in_partition[i]) > 0 and getattr(self.params_in_partition[i][0], 'use_muon', False)

if self.cpu_offload:
single_grad_partition = self.single_partition_of_fp32_groups[i].grad
self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm, skip_clipping=uses_muon)

self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()
self.timers(OPTIMIZER_STEP_TIMER).start()
Expand Down Expand Up @@ -2139,7 +2184,7 @@ def step(self, closure=None):

self.averaged_gradients[i] = None
self.all_grad_tensors[i] = None
self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm, skip_clipping=uses_muon)

self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()

Expand Down Expand Up @@ -2199,10 +2244,10 @@ def _average_expert_grad_norms(self, norm_groups):
dist.all_reduce(scaled_norm_tensor, group=self.real_dp_process_group[i])
norm_groups[i] = scaled_norm_tensor.to(self.device)

def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
def unscale_and_clip_grads(self, grad_groups_flat, total_norm, skip_clipping=False):
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self.clip_grad > 0.:
if self.clip_grad > 0. and not skip_clipping:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
clip = torch.clamp(clip, min=1.0)
Expand Down
Loading