From 1dc4122545d3371140d9606d278254b38104cbd2 Mon Sep 17 00:00:00 2001 From: leejianwoo-collab Date: Fri, 23 Jan 2026 02:59:30 -0500 Subject: [PATCH 1/4] fix: Ensure full gradient reduction for Muon with reduce_scatter Signed-off-by: leejianwoo-collab --- deepspeed/runtime/zero/stage_1_and_2.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) mode change 100755 => 100644 deepspeed/runtime/zero/stage_1_and_2.py diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py old mode 100755 new mode 100644 index 107e47a44042..42ae3c1d3017 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1187,7 +1187,11 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - if not self.reduce_scatter: + # Check if any parameter uses Muon optimizer (needs full gradient for orthogonalization) + uses_muon = any(getattr(param, 'use_muon', False) for group in self.bit16_groups for param in group) + + if not self.reduce_scatter or uses_muon: + # Force full all-reduce for Muon parameters even when reduce_scatter is enabled self.gradient_reduction_w_predivide(tensor, communication_data_type) return From a873854eef60ac7fb25fdd61203da592a18357f9 Mon Sep 17 00:00:00 2001 From: nathon Date: Sat, 24 Jan 2026 09:26:32 +0800 Subject: [PATCH 2/4] Update stage_1_and_2.py Signed-off-by: leejianwoo-collab --- deepspeed/runtime/zero/stage_1_and_2.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 42ae3c1d3017..97d097eb2163 100644 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -283,11 +283,15 @@ def _enforce_cpu_offload(): self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32 + # Check for Muon optimizer usage + self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params']) + if self.reduce_scatter and self.partition_gradients: valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" + assert not self.uses_muon, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer." # param flattened by groups self.bit16_groups = [] @@ -1187,10 +1191,8 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - # Check if any parameter uses Muon optimizer (needs full gradient for orthogonalization) - uses_muon = any(getattr(param, 'use_muon', False) for group in self.bit16_groups for param in group) - - if not self.reduce_scatter or uses_muon: + # Use pre-detected Muon flag from initialization + if not self.reduce_scatter or self.uses_muon: # Force full all-reduce for Muon parameters even when reduce_scatter is enabled self.gradient_reduction_w_predivide(tensor, communication_data_type) return From f6ddd7545d37937bf3f4716876d0d38b8b3c8018 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Fri, 23 Jan 2026 12:08:53 -0800 Subject: [PATCH 3/4] Fix ZeRO stage to choose BF16 optimizer in test (#7803) Use ZeRO stage 1 to use BF16 optimizer. (We should have switched to ZeRO1 in #7788, but I missed the change. @sfc-gh-truwase) - #7790 removed the fallback that allowed bf16 model + fp32 grad accumulation without ZeRO, so that combo now raises NotImplementedError. - #7788 changed test_bf16_optimizer_fragments to force BF16_Optimizer by setting grad_accum_dtype=fp32, but it kept ZeRO stage 0, which is now invalid after #7790. Signed-off-by: Masahiro Tanaka Signed-off-by: leejianwoo-collab --- tests/unit/runtime/zero/test_zero_tensor_fragment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/runtime/zero/test_zero_tensor_fragment.py b/tests/unit/runtime/zero/test_zero_tensor_fragment.py index 31d8b0990bf3..90e8e968abdf 100644 --- a/tests/unit/runtime/zero/test_zero_tensor_fragment.py +++ b/tests/unit/runtime/zero/test_zero_tensor_fragment.py @@ -179,7 +179,7 @@ def test_bf16_optimizer_fragments(self, frozen_weights): "grad_accum_dtype": "fp32" }, "zero_optimization": { - "stage": 0, + "stage": 1, } } From 15996a95a70a9d7e747de7b2d3c0784840693189 Mon Sep 17 00:00:00 2001 From: nathon Date: Sat, 24 Jan 2026 12:41:22 +0800 Subject: [PATCH 4/4] Update stage_1_and_2.py Signed-off-by: leejianwoo-collab --- deepspeed/runtime/zero/stage_1_and_2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 97d097eb2163..55ab80c6c994 100644 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -291,7 +291,10 @@ def _enforce_cpu_offload(): assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" - assert not self.uses_muon, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer." + + # Check for Muon optimizer compatibility with reduce_scatter (applies to both ZeRO-1 and ZeRO-2) + if self.reduce_scatter and self.uses_muon: + assert False, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer." # param flattened by groups self.bit16_groups = []