From df555af9816882e7e570e1f908c2b26205e26b96 Mon Sep 17 00:00:00 2001 From: AlAuAu <458134681@qq.com> Date: Fri, 23 Jan 2026 11:23:44 +0800 Subject: [PATCH] fix vpp overlap precision with norm_topk_prob --- src/paddlefleet/transformer/moe/moe_layer.py | 12 +- .../transformer/moe/token_dispatcher.py | 9 +- .../transformer/transformer_layer.py | 116 +++++++++++++----- 3 files changed, 98 insertions(+), 39 deletions(-) diff --git a/src/paddlefleet/transformer/moe/moe_layer.py b/src/paddlefleet/transformer/moe/moe_layer.py index edb543d1a..42d7d4529 100644 --- a/src/paddlefleet/transformer/moe/moe_layer.py +++ b/src/paddlefleet/transformer/moe/moe_layer.py @@ -533,17 +533,19 @@ def compute_gate(self, hidden_states): return self.gate(hidden_states) def dispatch_preprocess(self, args): - hidden_states, token_probs, token_indices = args + hidden_states, token_indices, token_weights, gates_masked, mask = args assert isinstance(self.token_dispatcher, MoEFlexTokenDispatcher) - hidden_states = self.token_dispatcher.dispatch_preprocess_overlap( - hidden_states, token_probs, token_indices + hidden_states, token_indices, token_weights, gates_masked, mask = ( + self.token_dispatcher.dispatch_preprocess_overlap( + hidden_states, token_indices, token_weights, gates_masked, mask + ) ) token_probs = self.token_dispatcher._comm_manager.token_probs token_indices = self.token_dispatcher._comm_manager.token_indices - return hidden_states, token_indices, token_probs + return hidden_states, token_indices, token_probs, gates_masked, mask def compute_dispatch(self, args, async_finish=False): - hidden_states, token_indices, token_weights = args + hidden_states, token_indices, token_weights, gates_masked, mask = args if self.moe_use_fusion_node: dispatched_hidden_states, fp8_dispatched_handle = ( self.token_dispatcher.token_dispatch_overlap( diff --git a/src/paddlefleet/transformer/moe/token_dispatcher.py b/src/paddlefleet/transformer/moe/token_dispatcher.py index abd0595da..52b72d99a 100644 --- a/src/paddlefleet/transformer/moe/token_dispatcher.py +++ b/src/paddlefleet/transformer/moe/token_dispatcher.py @@ -380,14 +380,15 @@ def dispatch_preprocess( def dispatch_preprocess_overlap( self, hidden_states: paddle.Tensor, - token_probs: paddle.Tensor, token_indices: paddle.Tensor, + token_weights: paddle.Tensor, + probs: paddle.Tensor, + routing_map: paddle.Tensor, ): self.hidden_shape = hidden_states.shape hidden_states = hidden_states.view([-1, self.hidden_shape[-1]]) - self._comm_manager.token_probs = token_probs - self._comm_manager.token_indices = token_indices - return hidden_states + self._comm_manager.setup_metadata(routing_map, probs) + return hidden_states, token_indices, token_weights, probs, routing_map def token_dispatch_overlap( self, diff --git a/src/paddlefleet/transformer/transformer_layer.py b/src/paddlefleet/transformer/transformer_layer.py index 5133d41de..078c988c6 100644 --- a/src/paddlefleet/transformer/transformer_layer.py +++ b/src/paddlefleet/transformer/transformer_layer.py @@ -656,18 +656,26 @@ def pre_process_compute(self, hidden_states): residuals, topk_weights, topk_indices, + gates_masked, + mask, aux_loss, ) def dispatch_preprocess_compute(self, args): - hidden_states, topk_weights, topk_indices = args + hidden_states, token_indices, token_weights, gates_masked, mask = args - hidden_states, token_indices, token_weights = ( + hidden_states, token_indices, token_weights, gates_masked, mask = ( self.mlp.dispatch_preprocess( - (hidden_states, topk_weights, topk_indices) + ( + hidden_states, + token_indices, + token_weights, + gates_masked, + mask, + ) ) ) - return hidden_states, token_indices, token_weights + return hidden_states, token_indices, token_weights, gates_masked, mask def post_process_compute(self, args, is_first_fwd=False): mlp_output, residual = args @@ -743,19 +751,33 @@ def forward(self, inputs): residual, hidden_states, residuals, - topk_weights, - topk_indices, + token_weights, + token_indices, + gates_masked, + mask, aux_loss, ) = self.pre_process_node.forward(hidden_states) - hidden_states, token_indices, token_weights = ( + hidden_states, token_indices, token_weights, gates_masked, mask = ( self.dispatch_preprocess_node.forward( - (hidden_states, topk_weights, topk_indices) + ( + hidden_states, + token_indices, + token_weights, + gates_masked, + mask, + ) ) ) hidden_states = self.dispatch_node.forward( - (hidden_states, token_indices, token_weights), + ( + hidden_states, + token_indices, + token_weights, + gates_masked, + mask, + ), async_finish=True, ) dispatch_fw_event = deep_ep.get_event_from_comm_stream( @@ -832,9 +854,12 @@ def backward(self, output_grad): combine_bw_event.calc_stream_wait(self.group_id) output_grad = self.mlp_node.backward(output_grad) - (output_grad, token_indices_grad, token_weights_grad) = ( - self.dispatch_node.backward(output_grad) - ) + ( + output_grad, + token_indices_grad, + token_weights_grad, + gates_masked_grad, + ) = self.dispatch_node.backward(output_grad) dispatch_bw_event = deep_ep.get_event_from_comm_stream( self.group_id ) @@ -842,10 +867,16 @@ def backward(self, output_grad): ( output_grad, - topk_weights_grad, - topk_indices_grad, + token_indices_grad, + token_weights_grad, + gates_masked_grad, ) = self.dispatch_preprocess_node.backward( - (output_grad, token_indices_grad, token_weights_grad) + ( + output_grad, + token_indices_grad, + token_weights_grad, + gates_masked_grad, + ) ) output_grad = self.pre_process_node.backward( @@ -853,8 +884,9 @@ def backward(self, output_grad): residual_grad, output_grad, residuals_grad, - topk_weights_grad, - topk_indices_grad, + token_weights_grad, + token_indices_grad, + gates_masked_grad, aux_loss_grad, ) ) @@ -943,20 +975,34 @@ def forward_backward(self, inputs, output_grad, split_bw=False): residual, hidden_states, residuals, - topk_weights, - topk_indices, + token_weights, + token_indices, + gates_masked, + mask, aux_loss, ) = self.forward_node.pre_process_node.forward(hidden_states) - hidden_states, token_indices, token_weights = ( + hidden_states, token_indices, token_weights, gates_masked, mask = ( self.forward_node.dispatch_preprocess_node.forward( - (hidden_states, topk_weights, topk_indices) + ( + hidden_states, + token_indices, + token_weights, + gates_masked, + mask, + ) ) ) # 4. DISPATCH(F) hidden_states = self.forward_node.dispatch_node.forward( - (hidden_states, token_indices, token_weights), + ( + hidden_states, + token_indices, + token_weights, + gates_masked, + mask, + ), async_finish=True, ) dispatch_fw_event = deep_ep.get_event_from_comm_stream( @@ -968,9 +1014,12 @@ def forward_backward(self, inputs, output_grad, split_bw=False): output_grad = self.backward_node.mlp_node.backward(output_grad) # 6. DISPATCH(B) - output_grad, token_indices_grad, token_weights_grad = ( - self.backward_node.dispatch_node.backward(output_grad) - ) + ( + output_grad, + token_indices_grad, + token_weights_grad, + gates_masked_grad, + ) = self.backward_node.dispatch_node.backward(output_grad) dispatch_bw_event = deep_ep.get_event_from_comm_stream( self.backward_node.group_id ) @@ -996,10 +1045,16 @@ def forward_backward(self, inputs, output_grad, split_bw=False): dispatch_bw_event.calc_stream_wait(self.backward_node.group_id) ( output_grad, - topk_weights_grad, - topk_indices_grad, + token_indices_grad, + token_weights_grad, + gates_masked_grad, ) = self.backward_node.dispatch_preprocess_node.backward( - (output_grad, token_indices_grad, token_weights_grad) + ( + output_grad, + token_indices_grad, + token_weights_grad, + gates_masked_grad, + ) ) output_grad = self.backward_node.pre_process_node.backward( @@ -1007,8 +1062,9 @@ def forward_backward(self, inputs, output_grad, split_bw=False): residual_grad, output_grad, residuals_grad, - topk_weights_grad, - topk_indices_grad, + token_weights_grad, + token_indices_grad, + gates_masked_grad, aux_loss_grad, ) )