From 1a693b52964c53883f954f612c4d1996e006bed7 Mon Sep 17 00:00:00 2001
From: fy817 <277645218@qq.com>
Date: Mon, 12 Jan 2026 19:38:45 +0800
Subject: [PATCH] Fix Muon optimizer conflict with gradient clipping in ZeRO 1/2

Signed-off-by: fy817 <277645218@qq.com>
---
 deepspeed/runtime/zero/stage_1_and_2.py | 56 ++++++++++++++++++++++---
 1 file changed, 50 insertions(+), 6 deletions(-)

diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 86f50a1a0c0b..7bf2236bd7b1 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -860,6 +860,32 @@ def independent_gradient_partition_epilogue(self):
         self._clear_previous_reduced_grads()
 
         if self.cpu_offload is False:
+            # Pre-compute gradient norm for Muon clipping if needed
+            grad_norm_for_muon = None
+            if self.is_gradient_accumulation_boundary:
+                # Check if any parameter group uses Muon
+                uses_muon = False
+                for i, _ in enumerate(self.bit16_groups):
+                    if len(self.params_in_partition[i]) > 0 and getattr(self.params_in_partition[i][0], 'use_muon', False):
+                        uses_muon = True
+                        break
+
+                # Compute unscaled gradient norm if Muon is used and clipping is enabled
+                if uses_muon and self.clip_grad > 0.:
+                    # Compute gradient norm before Muon update
+                    norm_groups = []
+                    for i, group in enumerate(self.bit16_groups):
+                        if i not in self.all_grad_tensors or self.all_grad_tensors[i] is None:
+                            all_grad_tensors = self.get_all_grad_tensors(self.params_in_partition[i],
+                                                                         dtype=self.gradient_accumulation_dtype)
+                        else:
+                            all_grad_tensors = self.all_grad_tensors[i]
+                        if all_grad_tensors is not None:
+                            norm_groups.append(self.get_grad_norm_direct(all_grad_tensors, self.params_in_partition[i]))
+
+                    if len(norm_groups) > 0:
+                        grad_norm_for_muon = torch.linalg.vector_norm(torch.stack(norm_groups), ord=2)
+
         for i, _ in enumerate(self.bit16_groups):
             if i not in self.all_grad_tensors or self.all_grad_tensors[i] is None:
                 self.all_grad_tensors[i] = self.get_all_grad_tensors(self.params_in_partition[i],
@@ -877,7 +903,8 @@ def independent_gradient_partition_epilogue(self):
                                                                      dtype=self.gradient_accumulation_dtype,
                                                                      device=get_accelerator().current_device_name(),
                                                                      param_group_idx=i,
-                                                                     return_tensor_list=True)
+                                                                     return_tensor_list=True,
+                                                                     grad_norm=grad_norm_for_muon)
                 self.all_grad_tensors[i] = None
 
         self._release_ipg_buffers()
@@ -1894,7 +1921,8 @@ def get_flat_partition(self,
                           dtype,
                           device,
                           param_group_idx,
-                          return_tensor_list=False):
+                          return_tensor_list=False,
+                          grad_norm=None):
         if len(tensor_list) == 0:
             # This condition can fire when we have small parameteters and many ranks.
             zero_buffer = torch.zeros(int(partition_size), dtype=dtype, device=device)
@@ -1916,11 +1944,22 @@ def get_flat_partition(self,
             flatten_bf_list = [torch.zeros([total_size], dtype=dtype, device=device)]
             self.optimizer.state[flatten_copy]["momentum_buffer"] = self.flatten(flatten_bf_list)
 
+        # Calculate clip factor if gradient clipping is enabled and grad_norm is provided
+        clip_factor = 1.0
+        if self.clip_grad > 0. and grad_norm is not None:
+            # grad_norm is already unscaled (divided by loss_scale)
+            clip_factor = max(1.0, grad_norm / self.clip_grad)
+
         buffer_idx = 0
         for i, tensor in enumerate(tensor_list):
             grad_accum = self.all_grad_tensors[param_group_idx][i]
             if getattr(tensor, 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
                 assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}"
+
+                # Apply gradient clipping before muon_update
+                if clip_factor > 1.0:
+                    grad_accum = grad_accum / clip_factor
+
                 buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
                                       tensor.numel()).view(tensor.size())
                 grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum'])
@@ -2058,15 +2097,20 @@ def step(self, closure=None):
             see_memory_usage('Before norm calculation')
             scaled_global_grad_norm = self.scaled_global_norm()
             self._global_grad_norm = scaled_global_grad_norm / prev_scale
+            unscaled_grad_norm = self._global_grad_norm  # Store unscaled norm for use in get_flat_partition
             see_memory_usage('After norm before optimizer')
 
         # Step 2:- run optimizer and upscaling simultaneously
         for i, group in enumerate(self.bit16_groups):
             self.timers(OPTIMIZER_GRADIENTS_TIMER).start()
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
+
+            # Check if this param group uses Muon (clipping already done in get_flat_partition)
+            uses_muon = len(self.params_in_partition[i]) > 0 and getattr(self.params_in_partition[i][0], 'use_muon', False)
+
             if self.cpu_offload:
                 single_grad_partition = self.single_partition_of_fp32_groups[i].grad
-                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm, skip_clipping=uses_muon)
 
                 self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()
                 self.timers(OPTIMIZER_STEP_TIMER).start()
@@ -2108,7 +2152,7 @@ def step(self, closure=None):
                     self.averaged_gradients[i] = None
                     self.all_grad_tensors[i] = None
 
-                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm, skip_clipping=uses_muon)
 
                 self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()
 
@@ -2168,10 +2212,10 @@ def _average_expert_grad_norms(self, norm_groups):
                 dist.all_reduce(scaled_norm_tensor, group=self.real_dp_process_group[i])
             norm_groups[i] = scaled_norm_tensor.to(self.device)
 
-    def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
+    def unscale_and_clip_grads(self, grad_groups_flat, total_norm, skip_clipping=False):
         # compute combined scale factor for this group
         combined_scale = self.loss_scale
-        if self.clip_grad > 0.:
+        if self.clip_grad > 0. and not skip_clipping:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
             clip = torch.clamp(clip, min=1.0)
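
Why the clip must precede muon_update: Muon replaces the raw gradient with an
orthogonalized momentum matrix (via a Newton-Schulz iteration), so the norm of
the update the optimizer finally applies no longer reflects the raw gradient's
norm. Clipping afterwards in unscale_and_clip_grads would misscale that
orthogonalized update while the unclipped spike had already entered the
momentum buffer; hence the patch clips the raw gradient inside
get_flat_partition and passes skip_clipping=uses_muon to disable the second
clip. The sketch below illustrates only this ordering and is not DeepSpeed
code: newton_schulz_orthogonalize is a toy stand-in for the orthogonalization
inside the real muon_update, and clipped_muon_step mirrors just the
clip_factor = max(1.0, grad_norm / clip_grad) logic from the patch.

    import torch

    def newton_schulz_orthogonalize(m, steps=5):
        # Toy quintic Newton-Schulz iteration (coefficients as in common Muon
        # implementations): pushes m toward an orthogonal matrix, so the
        # output's norm is ~sqrt(rank) no matter how large m was on entry.
        a, b, c = 3.4445, -4.7750, 2.0315
        x = m / (m.norm() + 1e-7)
        for _ in range(steps):
            s = x @ x.T
            x = a * x + (b * s + c * s @ s) @ x
        return x

    def clipped_muon_step(grad, buf, total_norm, clip_grad, momentum=0.95):
        # Mirror the patch: scale the raw gradient down *before* the update,
        # so a spike never contaminates the momentum buffer.
        clip_factor = max(1.0, float(total_norm) / clip_grad)
        if clip_factor > 1.0:
            grad = grad / clip_factor
        buf.mul_(momentum).add_(grad)            # momentum accumulation
        return newton_schulz_orthogonalize(buf)  # update actually applied

    g = torch.randn(64, 64) * 1e3   # simulated gradient spike
    buf = torch.zeros_like(g)
    update = clipped_muon_step(g, buf, total_norm=g.norm(), clip_grad=1.0)
    print(buf.norm())     # <= clip_grad: the spike never reaches the momentum
    print(update.norm())  # ~sqrt(64): set by orthogonalization, not by g

This is also why skip_clipping exists in unscale_and_clip_grads: for Muon
parameter groups the flat partition already holds post-muon_update values, so
applying the norm-based clip factor there a second time would shrink an update
whose scale is intentionally decoupled from the gradient norm.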