# Fix Muon optimizer conflict with gradient clipping in ZeRO 1/2 (#7776)
Base: `master` · Changes from all commits: `1a693b5`, `aa13d03`, `1e8b95c`, `2ce3503`, `b9d09ad`, `fad0bed`
```diff
@@ -859,6 +859,32 @@ def independent_gradient_partition_epilogue(self):
         self._clear_previous_reduced_grads()

         if self.cpu_offload is False:
+            # Pre-compute gradient norm for Muon clipping if needed
+            grad_norm_for_muon = None
+            if self.is_gradient_accumulation_boundary:
+                # Check if any parameter group uses Muon
+                uses_muon = False
```
> **Collaborator:** The main connection to #7808 is that …
>
> **Collaborator:** @fy817 Like @sfc-gh-truwase mentioned, …
```diff
+                for i, _ in enumerate(self.bit16_groups):
+                    if len(self.params_in_partition[i]) > 0 and getattr(self.params_in_partition[i][0], 'use_muon', False):
```
> **Collaborator:** @fy817 …
```diff
+                        uses_muon = True
```
> **Collaborator:** @fy817 One suggestion is to add validation in …
```diff
+                        break
+
+                # Compute unscaled gradient norm if Muon is used and clipping is enabled
+                if uses_muon and self.clip_grad > 0.:
+                    # Compute gradient norm before Muon update
+                    norm_groups = []
+                    for i, group in enumerate(self.bit16_groups):
+                        if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
+                            all_grad_tensors = self.get_all_grad_tensors(self.params_in_partition[i],
+                                                                         dtype=self.gradient_accumulation_dtype)
+                        else:
+                            all_grad_tensors = self.all_grad_tensors[i]
+                        if all_grad_tensors is not None:
+                            norm_groups.append(self.get_grad_norm_direct(all_grad_tensors, self.params_in_partition[i]))
+
+                    if len(norm_groups) > 0:
+                        grad_norm_for_muon = torch.linalg.vector_norm(torch.stack(norm_groups), ord=2)
+
             for i, _ in enumerate(self.bit16_groups):
                 if i not in self.all_grad_tensors or self.all_grad_tensors[i] is None:
                     self.all_grad_tensors[i] = self.get_all_grad_tensors(self.params_in_partition[i],
```
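The heart of this hunk is folding per-group gradient norms into a single global norm before `muon_update` mutates the gradients. A minimal sketch of just that step, with illustrative values rather than real gradients:

```python
import torch

# Per-group L2 norms (0-dim tensors), as get_grad_norm_direct would produce.
norm_groups = [torch.tensor(3.0), torch.tensor(4.0)]

# The global L2 norm over all parameters equals the L2 norm of the vector
# of per-group L2 norms, so stacking and re-norming is exact, not an estimate.
grad_norm_for_muon = torch.linalg.vector_norm(torch.stack(norm_groups), ord=2)
print(grad_norm_for_muon)  # tensor(5.) == sqrt(3**2 + 4**2)
```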
```diff
@@ -876,6 +902,8 @@ def independent_gradient_partition_epilogue(self):
                     dtype=self.gradient_accumulation_dtype,
                     device=get_accelerator().current_device_name(),
                     param_group_idx=i,
-                    return_tensor_list=True)
+                    return_tensor_list=True,
+                    grad_norm=grad_norm_for_muon)
             # Clear all_grad_tensors after use. With reentrant checkpointing,
             # the epilogue may run multiple times per backward pass. Each time,
```
```diff
@@ -1925,7 +1953,8 @@ def get_flat_partition(self,
                            dtype,
                            device,
                            param_group_idx,
-                           return_tensor_list=False):
+                           return_tensor_list=False,
+                           grad_norm=None):
         if len(tensor_list) == 0:
             # This condition can fire when we have small parameters and many ranks.
             zero_buffer = torch.zeros(int(partition_size), dtype=dtype, device=device)
```
```diff
@@ -1947,11 +1976,22 @@ def get_flat_partition(self,
                 flatten_bf_list = [torch.zeros([total_size], dtype=dtype, device=device)]
                 self.optimizer.state[flatten_copy]["momentum_buffer"] = self.flatten(flatten_bf_list)

+        # Calculate clip factor if gradient clipping is enabled and grad_norm is provided
+        clip_factor = 1.0
+        if self.clip_grad > 0. and grad_norm is not None:
+            # grad_norm is already unscaled (divided by loss_scale)
+            clip_factor = max(1.0, grad_norm / self.clip_grad)
+
         buffer_idx = 0
         for i, tensor in enumerate(tensor_list):
             grad_accum = self.all_grad_tensors[param_group_idx][i]
             if getattr(tensor, 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
                 assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}"
+
+                # Apply gradient clipping before muon_update
+                if clip_factor > 1.0:
+                    grad_accum = grad_accum / clip_factor
+
                 buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
                                       tensor.numel()).view(tensor.size())
                 grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum'])
```
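The clip factor above is standard global-norm clipping written out by hand, since the usual `torch.nn.utils.clip_grad_norm_` cannot see ZeRO's partitioned gradients. A self-contained sketch of the same arithmetic (the function name and values are illustrative):

```python
import torch

def clip_by_precomputed_norm(grad: torch.Tensor, grad_norm: float, clip_grad: float) -> torch.Tensor:
    # clip_factor exceeds 1.0 only when the global norm is above the
    # threshold, so in-budget gradients pass through untouched.
    clip_factor = max(1.0, grad_norm / clip_grad)
    return grad / clip_factor

g = torch.full((2, 2), 5.0)  # a 2-D, Muon-eligible gradient
print(clip_by_precomputed_norm(g, grad_norm=10.0, clip_grad=1.0))  # all elements / 10
```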
```diff
@@ -2089,15 +2129,20 @@ def step(self, closure=None):
         see_memory_usage('Before norm calculation')
         scaled_global_grad_norm = self.scaled_global_norm()
         self._global_grad_norm = scaled_global_grad_norm / prev_scale
+        unscaled_grad_norm = self._global_grad_norm  # Store unscaled norm for use in get_flat_partition
         see_memory_usage('After norm before optimizer')

         # Step 2:- run optimizer and upscaling simultaneously
         for i, group in enumerate(self.bit16_groups):
             self.timers(OPTIMIZER_GRADIENTS_TIMER).start()
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])

+            # Check if this param group uses Muon (clipping already done in get_flat_partition)
+            uses_muon = len(self.params_in_partition[i]) > 0 and getattr(self.params_in_partition[i][0], 'use_muon', False)
+
             if self.cpu_offload:
                 single_grad_partition = self.single_partition_of_fp32_groups[i].grad
-                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm, skip_clipping=uses_muon)

                 self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()
                 self.timers(OPTIMIZER_STEP_TIMER).start()
@@ -2139,7 +2184,7 @@ def step(self, closure=None):
                 self.averaged_gradients[i] = None
                 self.all_grad_tensors[i] = None
-                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm, skip_clipping=uses_muon)

                 self.timers(OPTIMIZER_GRADIENTS_TIMER).stop()
```
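The `uses_muon` probe relies on parameters carrying a `use_muon` attribute; `getattr(..., False)` lets untagged groups fall back to the normal clipped path. A hedged sketch of that convention — the tagging rule below is an assumption for illustration (Muon conventionally targets ≥2-D weight matrices, which matches the `tensor.ndim > 1` assert in this PR), not code quoted from DeepSpeed:

```python
import torch

# Hypothetical tagging pass when building the optimizer's param groups.
model = torch.nn.Linear(4, 4)
for p in model.parameters():
    p.use_muon = p.ndim > 1  # weight matrix -> True, 1-D bias -> False

# Detection mirrors the diff: probe the first parameter of the partition
# and default to False so non-Muon groups keep ordinary clipping.
params = list(model.parameters())
uses_muon = len(params) > 0 and getattr(params[0], 'use_muon', False)
print(uses_muon)  # True -- nn.Linear yields its weight before its bias
```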
```diff
@@ -2199,10 +2244,10 @@ def _average_expert_grad_norms(self, norm_groups):
             dist.all_reduce(scaled_norm_tensor, group=self.real_dp_process_group[i])
             norm_groups[i] = scaled_norm_tensor.to(self.device)

-    def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
+    def unscale_and_clip_grads(self, grad_groups_flat, total_norm, skip_clipping=False):
         # compute combined scale factor for this group
         combined_scale = self.loss_scale
-        if self.clip_grad > 0.:
+        if self.clip_grad > 0. and not skip_clipping:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
             clip = torch.clamp(clip, min=1.0)
```
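With `skip_clipping=True` only the loss-scale division survives, which keeps Muon gradients, already clipped inside `get_flat_partition`, from being clipped twice. A sketch of the combined-scale arithmetic, assuming the same semantics as the method above (the free function is illustrative):

```python
def combined_scale(loss_scale: float, total_norm: float, clip_grad: float,
                   skip_clipping: bool) -> float:
    # total_norm is the *scaled* norm (true norm * loss_scale), matching the
    # "norm is in fact norm*scale" comment in the diff.
    scale = loss_scale
    if clip_grad > 0. and not skip_clipping:
        clip = max(1.0, ((total_norm / loss_scale) + 1e-6) / clip_grad)
        scale = loss_scale * clip  # dividing by this both unscales and clips
    return scale

# Non-Muon group: unscaled norm 8 vs clip_grad 1 -> divisor ~8x the loss scale.
print(combined_scale(loss_scale=2.0, total_norm=16.0, clip_grad=1.0, skip_clipping=False))
# Muon group: clipping already happened earlier, so only the loss scale remains.
print(combined_scale(loss_scale=2.0, total_norm=16.0, clip_grad=1.0, skip_clipping=True))
```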
> **Collaborator:** Here you introduced a new instance variable `grad_norm_for_muon` that's set in the `independent_gradient_partition_epilogue` function and consumed in the `get_flat_partition` function. This could potentially cause a data race due to the resulting implicit dependency on call order. One solution is to define `grad_norm_for_muon` explicitly outside the function; then we can ensure that it's always valid before being passed into these functions.
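A hedged sketch of the restructuring this review seems to suggest: let the caller compute the norm once and thread it through as an argument, so there is no hidden ordering between producer and consumer. The class and method names here are illustrative, not DeepSpeed APIs:

```python
import torch

class Stepper:
    """Toy stand-in for the ZeRO optimizer's step path."""

    def __init__(self, grads: list, clip_grad: float):
        self.grads = grads
        self.clip_grad = clip_grad

    def global_grad_norm(self) -> torch.Tensor:
        # Computed once, in one place, by the caller of apply_clipping.
        return torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(g) for g in self.grads]))

    def apply_clipping(self, grad_norm: torch.Tensor) -> None:
        # grad_norm arrives as a parameter, not hidden instance state, so it
        # is guaranteed valid here regardless of call order.
        clip_factor = torch.clamp(grad_norm / self.clip_grad, min=1.0)
        for g in self.grads:
            g.div_(clip_factor)

    def step(self) -> None:
        norm = self.global_grad_norm()  # explicit producer...
        self.apply_clipping(norm)       # ...feeds explicit consumer

s = Stepper([torch.full((2,), 3.0), torch.full((2,), 4.0)], clip_grad=1.0)
s.step()
print(s.grads[0])  # scaled down by the global norm (~7.07)
```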